diff --git a/.cirrus.yml b/.cirrus.yml new file mode 100644 index 00000000..568c34d7 --- /dev/null +++ b/.cirrus.yml @@ -0,0 +1,31 @@ +freebsd_task: + name: FreeBSD + + matrix: + - name: FreeBSD 14.3 + freebsd_instance: + image_family: freebsd-14-3 + + pkginstall_script: + - pkg update -f + - pkg install -y go125 + - pkg install -y git + + setup_script: + - ln -s /usr/local/bin/go125 /usr/local/bin/go + - pw groupadd sftpgo + - pw useradd sftpgo -g sftpgo -w none -m + - mkdir /home/sftpgo/sftpgo + - cp -R . /home/sftpgo/sftpgo + - chown -R sftpgo:sftpgo /home/sftpgo/sftpgo + + compile_script: + - su sftpgo -c 'cd ~/sftpgo && go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' + - su sftpgo -c 'cd ~/sftpgo/tests/eventsearcher && go build -trimpath -ldflags "-s -w" -o eventsearcher' + - su sftpgo -c 'cd ~/sftpgo/tests/ipfilter && go build -trimpath -ldflags "-s -w" -o ipfilter' + + check_script: + - su sftpgo -c 'cd ~/sftpgo && ./sftpgo initprovider && ./sftpgo resetprovider --force' + + test_script: + - su sftpgo -c 'cd ~/sftpgo && go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 20m ./... 
-coverprofile=coverage.txt -covermode=atomic' diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..edfab1aa --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: [drakkan] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 00000000..934cc584 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,110 @@ +name: Open Source Bug Report +description: "Submit a report and help us improve SFTPGo" +title: "[Bug]: " +labels: ["bug"] +body: + - type: markdown + attributes: + value: | + ### 👍 Thank you for contributing to our project! + Before asking for help please check our [support policy](https://github.com/drakkan/sftpgo?tab=readme-ov-file#support). + If you are a [commercial user](https://sftpgo.com/) please contact us using the dedicated [email address](mailto:support@sftpgo.com). + If you'd like to contribute code, please make sure to read and understand our [Contributor License Agreement (CLA)](https://sftpgo.com/cla.html). + You’ll be asked to accept it when submitting a pull request. 
+ - type: checkboxes + id: before-posting + attributes: + label: "⚠️ This issue respects the following points: ⚠️" + description: All conditions are **required**. + options: + - label: This is a **bug**, not a question or a configuration issue. + required: true + - label: This issue is **not** already reported on Github _(I've searched it)_. + required: true + - type: textarea + id: bug-description + attributes: + label: Bug description + description: | + Provide a description of the bug you're experiencing. + Don't just expect someone will guess what your specific problem is and provide full details. + validations: + required: true + - type: textarea + id: reproduce + attributes: + label: Steps to reproduce + description: | + Describe the steps to reproduce the bug. + The better your description is the fastest you'll get an _(accurate)_ answer. + value: | + 1. + 2. + 3. + validations: + required: true + - type: textarea + id: expected-behavior + attributes: + label: Expected behavior + description: Describe what you expected to happen instead. + validations: + required: true + - type: input + id: version + attributes: + label: SFTPGo version + validations: + required: true + - type: input + id: data-provider + attributes: + label: Data provider + validations: + required: true + - type: dropdown + id: install-method + attributes: + label: Installation method + description: | + Select installation method you've used. 
+ _Describe the method in the "Additional info" section if you chose "Other"._ + options: + - "Community Docker image" + - "Community Deb package" + - "Community RPM package" + - "Other" + validations: + required: true + - type: textarea + attributes: + label: Configuration + description: "Describe your customizations to the configuration: both config file changes and overrides via environment variables" + value: config + validations: + required: true + - type: textarea + id: logs + attributes: + label: Relevant log output + description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. + render: shell + - type: dropdown + id: usecase + attributes: + label: What are you using SFTPGo for? + description: We'd like to understand your SFTPGo usecase more + multiple: true + options: + - "Private user, home usecase (home backup/VPS)" + - "Professional user, 1 person business" + - "Small business (3-person firm with file exchange?)" + - "Medium business" + - "Enterprise" + validations: + required: true + - type: textarea + id: additional-info + attributes: + label: Additional info + description: Any additional information related to the issue. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..1b3b236a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,9 @@ +blank_issues_enabled: false +contact_links: + - name: Commercial Support + url: https://sftpgo.com/ + about: > + If you need Professional support, so your reports are prioritized and resolved more quickly. + - name: GitHub Community Discussions + url: https://github.com/drakkan/sftpgo/discussions + about: Please ask and answer questions here. 
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 00000000..5fd83037 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,50 @@ +name: 🚀 Feature request +description: Suggest an idea for SFTPGo +labels: ["suggestion"] +body: + - type: markdown + attributes: + value: | + ### 👍 Thank you for contributing to our project! + Before asking for help please check our [support policy](https://github.com/drakkan/sftpgo?tab=readme-ov-file#support). + If you are a [commercial user](https://sftpgo.com/) please contact us using the dedicated [email address](mailto:support@sftpgo.com). + If you'd like to contribute code, please make sure to read and understand our [Contributor License Agreement (CLA)](https://sftpgo.com/cla.html). + You’ll be asked to accept it when submitting a pull request. + - type: textarea + attributes: + label: Is your feature request related to a problem? Please describe. + description: A clear and concise description of what the problem is. + validations: + required: false + - type: textarea + attributes: + label: Describe the solution you'd like + description: A clear and concise description of what you want to happen. + validations: + required: true + - type: textarea + attributes: + label: Describe alternatives you've considered + description: A clear and concise description of any alternative solutions or features you've considered. + validations: + required: false + - type: dropdown + id: usecase + attributes: + label: What are you using SFTPGo for? 
+ description: We'd like to understand your SFTPGo usecase more + multiple: true + options: + - "Private user, home usecase (home backup/VPS)" + - "Professional user, 1 person business" + - "Small business (3-person firm with file exchange?)" + - "Medium business" + - "Enterprise" + validations: + required: true + - type: textarea + attributes: + label: Additional context + description: Add any other context or screenshots about the feature request here. + validations: + required: false \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..44438dd6 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,5 @@ +# Checklist for Pull Requests + +- [ ] Have you signed the [Contributor License Agreement](https://sftpgo.com/cla.html)? + +--- diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..0c90b458 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,20 @@ +version: 2 + +updates: + #- package-ecosystem: "gomod" + # directory: "/" + # schedule: + # interval: "weekly" + # open-pull-requests-limit: 2 + + - package-ecosystem: "docker" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 2 + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 2 diff --git a/.github/workflows/.editorconfig b/.github/workflows/.editorconfig new file mode 100644 index 00000000..7bd3346f --- /dev/null +++ b/.github/workflows/.editorconfig @@ -0,0 +1,2 @@ +[*.yml] +indent_size = 2 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 00000000..db26ed74 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,36 @@ +name: "Code scanning - action" + +on: + push: + pull_request: + schedule: + - cron: '30 1 * * 6' + +jobs: + CodeQL-Build: + runs-on: ubuntu-latest + + permissions: + security-events: write + + steps: + 
- name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: '1.25' + + - name: Initialize CodeQL + uses: github/codeql-action/init@v4 + with: + languages: go + + - name: Autobuild + uses: github/codeql-action/autobuild@v4 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v4 \ No newline at end of file diff --git a/.github/workflows/development.yml b/.github/workflows/development.yml new file mode 100644 index 00000000..45769de3 --- /dev/null +++ b/.github/workflows/development.yml @@ -0,0 +1,562 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +permissions: + id-token: write + contents: read + +jobs: + test-deploy: + name: Test and deploy + runs-on: ${{ matrix.os }} + strategy: + matrix: + go: ['1.26'] + os: [ubuntu-latest, macos-latest] + + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: ${{ matrix.go }} + + - name: Build for Linux/macOS x86_64 + run: | + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo + cd tests/eventsearcher + go build -trimpath -ldflags "-s -w" -o eventsearcher + cd - + cd tests/ipfilter + go build -trimpath -ldflags "-s -w" -o ipfilter + cd - + ./sftpgo initprovider + ./sftpgo resetprovider --force + + - name: Build for macOS arm64 + if: startsWith(matrix.os, 'macos-') == true + run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u 
+%FT%TZ`" -o sftpgo_arm64 + + - name: Run test cases using SQLite provider + run: go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... -coverprofile=coverage.txt -covermode=atomic + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 + with: + files: ./coverage.txt + fail_ci_if_error: false + token: ${{ secrets.CODECOV_TOKEN }} + + - name: Run test cases using bolt provider + run: | + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/config -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/common -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/httpd -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 8m ./internal/sftpd -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/ftpd -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/webdavd -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/telemetry -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/mfa -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/command -covermode=atomic + env: + SFTPGO_DATA_PROVIDER__DRIVER: bolt + SFTPGO_DATA_PROVIDER__NAME: 'sftpgo_bolt.db' + + - name: Run test cases using memory provider + run: go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... 
-covermode=atomic + env: + SFTPGO_DATA_PROVIDER__DRIVER: memory + SFTPGO_DATA_PROVIDER__NAME: '' + + - name: Prepare build artifact for macOS + if: startsWith(matrix.os, 'macos-') == true + run: | + mkdir -p output/{init,bash_completion,zsh_completion} + cp sftpgo output/sftpgo_x86_64 + cp sftpgo_arm64 output/ + cp sftpgo.json output/ + cp -r templates output/ + cp -r static output/ + cp -r openapi output/ + cp init/com.github.drakkan.sftpgo.plist output/init/ + ./sftpgo gen completion bash > output/bash_completion/sftpgo + ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo + ./sftpgo gen man -d output/man/man1 + gzip output/man/man1/* + + - name: Upload build artifact + if: startsWith(matrix.os, 'ubuntu-') != true + uses: actions/upload-artifact@v7 + with: + name: sftpgo-${{ matrix.os }}-go-${{ matrix.go }} + path: output + + test-deploy-windows: + name: Test and deploy Windows + runs-on: windows-latest + + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: '1.26' + + - name: Run test cases using SQLite provider + run: | + cd tests/eventsearcher + go build -trimpath -ldflags "-s -w" -o eventsearcher.exe + cd ../.. + cd tests/ipfilter + go build -trimpath -ldflags "-s -w" -o ipfilter.exe + cd ../.. + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... 
-coverprofile=coverage.txt -covermode=atomic + + - name: Run test cases using bolt provider + run: | + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/config -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/common -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/httpd -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 8m ./internal/sftpd -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/ftpd -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 5m ./internal/webdavd -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/telemetry -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/mfa -covermode=atomic + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 2m ./internal/command -covermode=atomic + env: + SFTPGO_DATA_PROVIDER__DRIVER: bolt + SFTPGO_DATA_PROVIDER__NAME: 'sftpgo_bolt.db' + + - name: Run test cases using memory provider + run: go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... -covermode=atomic + env: + SFTPGO_DATA_PROVIDER__DRIVER: memory + SFTPGO_DATA_PROVIDER__NAME: '' + + - name: Build + run: | + $GIT_COMMIT = (git describe --always --abbrev=8 --dirty) | Out-String + $DATE_TIME = ([datetime]::Now.ToUniversalTime().toString("yyyy-MM-ddTHH:mm:ssZ")) | Out-String + $LATEST_TAG = ((git describe --tags $(git rev-list --tags --max-count=1)) | Out-String).Trim() + $REV_LIST=$LATEST_TAG+"..HEAD" + $COMMITS_FROM_TAG= ((git rev-list $REV_LIST --count) | Out-String).Trim() + $FILE_VERSION = $LATEST_TAG.substring(1) + "." 
+ $COMMITS_FROM_TAG + go install github.com/tc-hib/go-winres@latest + go-winres simply --arch amd64 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o sftpgo.exe + mkdir arm64 + $Env:CGO_ENABLED='0' + $Env:GOOS='windows' + $Env:GOARCH='arm64' + go-winres simply --arch arm64 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nosqlite -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe + mkdir x86 + $Env:GOARCH='386' + go-winres simply --arch 386 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nosqlite -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\x86\sftpgo.exe + Remove-Item Env:\CGO_ENABLED + Remove-Item Env:\GOOS + Remove-Item Env:\GOARCH + + - name: Initialize data provider + run: | + rm sftpgo.db + ./sftpgo initprovider + shell: bash + + - name: Prepare Windows installers + if: ${{ github.event_name != 
'pull_request' }} + run: | + choco install innosetup + Remove-Item -LiteralPath "output" -Force -Recurse -ErrorAction Ignore + mkdir output + copy .\sftpgo.exe .\output + copy .\sftpgo.json .\output + copy .\sftpgo.db .\output + copy .\LICENSE .\output\LICENSE.txt + copy .\NOTICE .\output\NOTICE.txt + mkdir output\templates + xcopy .\templates .\output\templates\ /E + mkdir output\static + xcopy .\static .\output\static\ /E + mkdir output\openapi + xcopy .\openapi .\output\openapi\ /E + $LATEST_TAG = ((git describe --tags $(git rev-list --tags --max-count=1)) | Out-String).Trim() + $REV_LIST=$LATEST_TAG+"..HEAD" + $COMMITS_FROM_TAG= ((git rev-list $REV_LIST --count) | Out-String).Trim() + $Env:SFTPGO_ISS_DEV_VERSION = $LATEST_TAG + "." + $COMMITS_FROM_TAG + iscc .\windows-installer\sftpgo.iss + + rm .\output\sftpgo.exe + rm .\output\sftpgo.db + copy .\arm64\sftpgo.exe .\output + (Get-Content .\output\sftpgo.json).replace('"sqlite"', '"bolt"') | Set-Content .\output\sftpgo.json + $Env:SFTPGO_DATA_PROVIDER__DRIVER='bolt' + $Env:SFTPGO_DATA_PROVIDER__NAME='.\output\sftpgo.db' + .\sftpgo.exe initprovider + Remove-Item Env:\SFTPGO_DATA_PROVIDER__DRIVER + Remove-Item Env:\SFTPGO_DATA_PROVIDER__NAME + $Env:SFTPGO_ISS_ARCH='arm64' + iscc .\windows-installer\sftpgo.iss + + rm .\output\sftpgo.exe + copy .\x86\sftpgo.exe .\output + $Env:SFTPGO_ISS_ARCH='x86' + iscc .\windows-installer\sftpgo.iss + + - name: Upload Windows installer x86_64 artifact + if: ${{ github.event_name != 'pull_request' }} + uses: actions/upload-artifact@v7 + with: + name: sftpgo_windows_installer_x86_64 + path: ./sftpgo_windows_x86_64.exe + + - name: Upload Windows installer arm64 artifact + if: ${{ github.event_name != 'pull_request' }} + uses: actions/upload-artifact@v7 + with: + name: sftpgo_windows_installer_arm64 + path: ./sftpgo_windows_arm64.exe + + - name: Upload Windows installer x86 artifact + if: ${{ github.event_name != 'pull_request' }} + uses: actions/upload-artifact@v7 + with: + name: 
sftpgo_windows_installer_x86 + path: ./sftpgo_windows_x86.exe + + - name: Prepare build artifact for Windows + run: | + Remove-Item -LiteralPath "output" -Force -Recurse -ErrorAction Ignore + mkdir output + copy .\sftpgo.exe .\output + mkdir output\arm64 + copy .\arm64\sftpgo.exe .\output\arm64 + mkdir output\x86 + copy .\x86\sftpgo.exe .\output\x86 + copy .\sftpgo.json .\output + (Get-Content .\output\sftpgo.json).replace('"sqlite"', '"bolt"') | Set-Content .\output\sftpgo.json + mkdir output\templates + xcopy .\templates .\output\templates\ /E + mkdir output\static + xcopy .\static .\output\static\ /E + mkdir output\openapi + xcopy .\openapi .\output\openapi\ /E + + - name: Upload build artifact + uses: actions/upload-artifact@v7 + with: + name: sftpgo-windows-portable + path: output + + test-build-flags: + name: Test build flags + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v6 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: '1.26' + + - name: Build + run: | + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nogcs,nos3,noportable,nobolt,nomysql,nopgsql,nosqlite,nometrics,noazblob,unixcrypt -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo + ./sftpgo -v + cp -r openapi static templates internal/bundle/ + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,bundle -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo + ./sftpgo -v + + test-postgresql-mysql-crdb: + name: Test with PgSQL/MySQL/Cockroach + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:latest + env: + POSTGRES_PASSWORD: postgres + POSTGRES_DB: sftpgo + options: >- + --health-cmd pg_isready + 
--health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + + mariadb: + image: mariadb:latest + env: + MYSQL_ROOT_PASSWORD: mysql + MYSQL_DATABASE: sftpgo + MYSQL_USER: sftpgo + MYSQL_PASSWORD: sftpgo + options: >- + --health-cmd "mariadb-admin status -h 127.0.0.1 -P 3306 -u root -p$MYSQL_ROOT_PASSWORD" + --health-interval 10s + --health-timeout 5s + --health-retries 6 + ports: + - 3307:3306 + + mysql: + image: mysql:latest + env: + MYSQL_ROOT_PASSWORD: mysql + MYSQL_DATABASE: sftpgo + MYSQL_USER: sftpgo + MYSQL_PASSWORD: sftpgo + options: >- + --health-cmd "mysqladmin status -h 127.0.0.1 -P 3306 -u root -p$MYSQL_ROOT_PASSWORD" + --health-interval 10s + --health-timeout 5s + --health-retries 6 + ports: + - 3308:3306 + + steps: + - uses: actions/checkout@v6 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: '1.26' + + - name: Build + run: | + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo + cd tests/eventsearcher + go build -trimpath -ldflags "-s -w" -o eventsearcher + cd - + cd tests/ipfilter + go build -trimpath -ldflags "-s -w" -o ipfilter + cd - + + - name: Run tests using MySQL provider + run: | + ./sftpgo initprovider + ./sftpgo resetprovider --force + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... -covermode=atomic + env: + SFTPGO_DATA_PROVIDER__DRIVER: mysql + SFTPGO_DATA_PROVIDER__NAME: sftpgo + SFTPGO_DATA_PROVIDER__HOST: localhost + SFTPGO_DATA_PROVIDER__PORT: 3308 + SFTPGO_DATA_PROVIDER__USERNAME: sftpgo + SFTPGO_DATA_PROVIDER__PASSWORD: sftpgo + + - name: Run tests using PostgreSQL provider + run: | + ./sftpgo initprovider + ./sftpgo resetprovider --force + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... 
-covermode=atomic + env: + SFTPGO_DATA_PROVIDER__DRIVER: postgresql + SFTPGO_DATA_PROVIDER__NAME: sftpgo + SFTPGO_DATA_PROVIDER__HOST: localhost + SFTPGO_DATA_PROVIDER__PORT: 5432 + SFTPGO_DATA_PROVIDER__USERNAME: postgres + SFTPGO_DATA_PROVIDER__PASSWORD: postgres + + - name: Run tests using MariaDB provider + run: | + ./sftpgo initprovider + ./sftpgo resetprovider --force + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... -covermode=atomic + env: + SFTPGO_DATA_PROVIDER__DRIVER: mysql + SFTPGO_DATA_PROVIDER__NAME: sftpgo + SFTPGO_DATA_PROVIDER__HOST: localhost + SFTPGO_DATA_PROVIDER__PORT: 3307 + SFTPGO_DATA_PROVIDER__USERNAME: sftpgo + SFTPGO_DATA_PROVIDER__PASSWORD: sftpgo + SFTPGO_DATA_PROVIDER__SQL_TABLES_PREFIX: prefix_ + + - name: Run tests using CockroachDB provider + run: | + docker run --rm --name crdb --health-cmd "curl -I http://127.0.0.1:8080" --health-interval 10s --health-timeout 5s --health-retries 6 -p 26257:26257 -d cockroachdb/cockroach:latest start-single-node --insecure --listen-addr :26257 + sleep 10 + docker exec crdb cockroach sql --insecure -e 'create database "sftpgo"' + ./sftpgo initprovider + ./sftpgo resetprovider --force + go test -v -tags nopgxregisterdefaulttypes,disable_grpc_modules -p 1 -timeout 15m ./... 
-covermode=atomic + docker stop crdb + env: + SFTPGO_DATA_PROVIDER__DRIVER: cockroachdb + SFTPGO_DATA_PROVIDER__NAME: sftpgo + SFTPGO_DATA_PROVIDER__HOST: localhost + SFTPGO_DATA_PROVIDER__PORT: 26257 + SFTPGO_DATA_PROVIDER__USERNAME: root + SFTPGO_DATA_PROVIDER__PASSWORD: + SFTPGO_DATA_PROVIDER__TARGET_SESSION_ATTRS: any + SFTPGO_DATA_PROVIDER__SQL_TABLES_PREFIX: prefix_ + + build-linux-packages: + name: Build Linux packages + runs-on: ubuntu-latest + strategy: + matrix: + include: + - arch: amd64 + distro: ubuntu:18.04 + go: latest + go-arch: amd64 + - arch: aarch64 + distro: ubuntu18.04 + go: latest + go-arch: arm64 + - arch: ppc64le + distro: ubuntu18.04 + go: latest + go-arch: ppc64le + - arch: armv7 + distro: ubuntu18.04 + go: latest + go-arch: arm7 + steps: + - uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Get commit SHA + id: get_commit + run: echo "COMMIT=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT + shell: bash + + - name: Build on amd64 + if: ${{ matrix.arch == 'amd64' }} + run: | + echo '#!/bin/bash' > build.sh + echo '' >> build.sh + echo 'set -e' >> build.sh + echo 'apt-get update -q -y' >> build.sh + echo 'apt-get install -q -y curl gcc' >> build.sh + if [ ${{ matrix.go }} == 'latest' ] + then + echo 'GO_VERSION=$(curl -L https://go.dev/VERSION?m=text | head -n 1)' >> build.sh + else + echo 'GO_VERSION=${{ matrix.go }}' >> build.sh + fi + echo 'GO_DOWNLOAD_ARCH=${{ matrix.go-arch }}' >> build.sh + echo 'curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/${GO_VERSION}.linux-${GO_DOWNLOAD_ARCH}.tar.gz' >> build.sh + echo 'tar -C /usr/local -xzf go.tar.gz' >> build.sh + echo 'export PATH=$PATH:/usr/local/go/bin' >> build.sh + echo 'go version' >> build.sh + echo 'cd /usr/local/src' >> build.sh + echo 'go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_commit.outputs.COMMIT }} -X 
github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' >> build.sh + + chmod 755 build.sh + docker run --rm --name ubuntu-build --mount type=bind,source=`pwd`,target=/usr/local/src ${{ matrix.distro }} /usr/local/src/build.sh + mkdir -p output/{init,bash_completion,zsh_completion} + cp sftpgo.json output/ + cp -r templates output/ + cp -r static output/ + cp -r openapi output/ + cp init/sftpgo.service output/init/ + ./sftpgo gen completion bash > output/bash_completion/sftpgo + ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo + ./sftpgo gen man -d output/man/man1 + gzip output/man/man1/* + cp sftpgo output/ + + - uses: uraimo/run-on-arch-action@v3 + if: ${{ matrix.arch != 'amd64' }} + name: Build for ${{ matrix.arch }} + id: build + with: + arch: ${{ matrix.arch }} + distro: ${{ matrix.distro }} + setup: | + mkdir -p "${PWD}/output" + dockerRunArgs: | + --volume "${PWD}/output:/output" + shell: /bin/bash + install: | + apt-get update -q -y + apt-get install -q -y curl gcc + if [ ${{ matrix.go }} == 'latest' ] + then + GO_VERSION=$(curl -L https://go.dev/VERSION?m=text | head -n 1) + else + GO_VERSION=${{ matrix.go }} + fi + GO_DOWNLOAD_ARCH=${{ matrix.go-arch }} + if [ ${{ matrix.arch}} == 'armv7' ] + then + GO_DOWNLOAD_ARCH=armv6l + fi + curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/${GO_VERSION}.linux-${GO_DOWNLOAD_ARCH}.tar.gz + tar -C /usr/local -xzf go.tar.gz + run: | + export PATH=$PATH:/usr/local/go/bin + go version + if [ ${{ matrix.arch}} == 'armv7' ] + then + export GOARM=7 + fi + go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_commit.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo + mkdir -p output/{init,bash_completion,zsh_completion} + cp sftpgo.json output/ + cp -r templates output/ + cp -r static 
output/ + cp -r openapi output/ + cp init/sftpgo.service output/init/ + ./sftpgo gen completion bash > output/bash_completion/sftpgo + ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo + ./sftpgo gen man -d output/man/man1 + gzip output/man/man1/* + cp sftpgo output/ + + - name: Upload build artifact + uses: actions/upload-artifact@v7 + with: + name: sftpgo-linux-${{ matrix.arch }}-go-${{ matrix.go }} + path: output + + - name: Build Packages + id: build_linux_pkgs + run: | + export NFPM_ARCH=${{ matrix.go-arch }} + cd pkgs + ./build.sh + PKG_VERSION=$(cat dist/version) + echo "pkg-version=${PKG_VERSION}" >> $GITHUB_OUTPUT + + - name: Upload Debian Package + uses: actions/upload-artifact@v7 + with: + name: sftpgo-${{ steps.build_linux_pkgs.outputs.pkg-version }}-${{ matrix.go-arch }}-deb + path: pkgs/dist/deb/* + + - name: Upload RPM Package + uses: actions/upload-artifact@v7 + with: + name: sftpgo-${{ steps.build_linux_pkgs.outputs.pkg-version }}-${{ matrix.go-arch }}-rpm + path: pkgs/dist/rpm/* + + golangci-lint: + name: golangci-lint + runs-on: ubuntu-latest + steps: + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: '1.26' + - uses: actions/checkout@v6 + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v9 + with: + version: latest diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 00000000..822c916e --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,188 @@ +name: Docker + +on: + #schedule: + # - cron: '0 4 * * *' # everyday at 4:00 AM UTC + push: + branches: + - main + tags: + - v* + pull_request: + +jobs: + build: + name: Build + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - ubuntu-latest + docker_pkg: + - debian + - alpine + optional_deps: + - true + - false + include: + - os: ubuntu-latest + docker_pkg: distroless + optional_deps: false + - os: ubuntu-latest + docker_pkg: debian-plugins + optional_deps: true + steps: + - name: Checkout + uses: 
actions/checkout@v6 + + - name: Gather image information + id: info + run: | + VERSION=noop + DOCKERFILE=Dockerfile + MINOR="" + MAJOR="" + FEATURES="nopgxregisterdefaulttypes,disable_grpc_modules" + if [ "${{ github.event_name }}" = "schedule" ]; then + VERSION=nightly + elif [[ $GITHUB_REF == refs/tags/* ]]; then + VERSION=${GITHUB_REF#refs/tags/} + elif [[ $GITHUB_REF == refs/heads/* ]]; then + VERSION=$(echo ${GITHUB_REF#refs/heads/} | sed -r 's#/+#-#g') + if [ "${{ github.event.repository.default_branch }}" = "$VERSION" ]; then + VERSION=edge + fi + elif [[ $GITHUB_REF == refs/pull/* ]]; then + VERSION=pr-${{ github.event.number }} + fi + if [[ $VERSION =~ ^v[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$ ]]; then + MINOR=${VERSION%.*} + MAJOR=${MINOR%.*} + fi + VERSION_SLIM="${VERSION}-slim" + if [[ $DOCKER_PKG == alpine ]]; then + VERSION="${VERSION}-alpine" + VERSION_SLIM="${VERSION}-slim" + DOCKERFILE=Dockerfile.alpine + elif [[ $DOCKER_PKG == distroless ]]; then + VERSION="${VERSION}-distroless" + VERSION_SLIM="${VERSION}-slim" + DOCKERFILE=Dockerfile.distroless + FEATURES="${FEATURES},nosqlite" + elif [[ $DOCKER_PKG == debian-plugins ]]; then + VERSION="${VERSION}-plugins" + VERSION_SLIM="${VERSION}-slim" + FEATURES="${FEATURES},unixcrypt" + elif [[ $DOCKER_PKG == debian ]]; then + FEATURES="${FEATURES},unixcrypt" + fi + DOCKER_IMAGES=("drakkan/sftpgo" "ghcr.io/drakkan/sftpgo") + TAGS="${DOCKER_IMAGES[0]}:${VERSION}" + TAGS_SLIM="${DOCKER_IMAGES[0]}:${VERSION_SLIM}" + + for DOCKER_IMAGE in ${DOCKER_IMAGES[@]}; do + if [[ ${DOCKER_IMAGE} != ${DOCKER_IMAGES[0]} ]]; then + TAGS="${TAGS},${DOCKER_IMAGE}:${VERSION}" + TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${VERSION_SLIM}" + fi + if [[ $GITHUB_REF == refs/tags/* ]]; then + if [[ $DOCKER_PKG == debian ]]; then + if [[ -n $MAJOR && -n $MINOR ]]; then + TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR},${DOCKER_IMAGE}:${MAJOR}" + TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-slim,${DOCKER_IMAGE}:${MAJOR}-slim" + fi + 
TAGS="${TAGS},${DOCKER_IMAGE}:latest" + TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:slim" + elif [[ $DOCKER_PKG == distroless ]]; then + if [[ -n $MAJOR && -n $MINOR ]]; then + TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-distroless,${DOCKER_IMAGE}:${MAJOR}-distroless" + TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-distroless-slim,${DOCKER_IMAGE}:${MAJOR}-distroless-slim" + fi + TAGS="${TAGS},${DOCKER_IMAGE}:distroless" + TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:distroless-slim" + elif [[ $DOCKER_PKG == debian-plugins ]]; then + if [[ -n $MAJOR && -n $MINOR ]]; then + TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-plugins,${DOCKER_IMAGE}:${MAJOR}-plugins" + TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-plugins-slim,${DOCKER_IMAGE}:${MAJOR}-plugins-slim" + fi + TAGS="${TAGS},${DOCKER_IMAGE}:plugins" + TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:plugins-slim" + else + if [[ -n $MAJOR && -n $MINOR ]]; then + TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-alpine,${DOCKER_IMAGE}:${MAJOR}-alpine" + TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-alpine-slim,${DOCKER_IMAGE}:${MAJOR}-alpine-slim" + fi + TAGS="${TAGS},${DOCKER_IMAGE}:alpine" + TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:alpine-slim" + fi + fi + done + + if [[ $OPTIONAL_DEPS == true ]]; then + echo "version=${VERSION}" >> $GITHUB_OUTPUT + echo "tags=${TAGS}" >> $GITHUB_OUTPUT + echo "full=true" >> $GITHUB_OUTPUT + else + echo "version=${VERSION_SLIM}" >> $GITHUB_OUTPUT + echo "tags=${TAGS_SLIM}" >> $GITHUB_OUTPUT + echo "full=false" >> $GITHUB_OUTPUT + fi + if [[ $DOCKER_PKG == debian-plugins ]]; then + echo "plugins=true" >> $GITHUB_OUTPUT + else + echo "plugins=false" >> $GITHUB_OUTPUT + fi + echo "dockerfile=${DOCKERFILE}" >> $GITHUB_OUTPUT + echo "features=${FEATURES}" >> $GITHUB_OUTPUT + echo "created=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT + echo "sha=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT + env: + DOCKER_PKG: ${{ matrix.docker_pkg }} + OPTIONAL_DEPS: ${{ matrix.optional_deps }} + + - name: Set up QEMU + uses: 
docker/setup-qemu-action@v4 + + - name: Set up builder + uses: docker/setup-buildx-action@v4 + id: builder + + - name: Login to Docker Hub + uses: docker/login-action@v4 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + if: ${{ github.event_name != 'pull_request' }} + + - name: Login to GitHub Container Registry + uses: docker/login-action@v4 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + if: ${{ github.event_name != 'pull_request' }} + + - name: Build and push + uses: docker/build-push-action@v7 + with: + context: . + builder: ${{ steps.builder.outputs.name }} + file: ./${{ steps.info.outputs.dockerfile }} + platforms: linux/amd64,linux/arm64,linux/ppc64le,linux/arm/v7 + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.info.outputs.tags }} + build-args: | + COMMIT_SHA=${{ steps.info.outputs.sha }} + INSTALL_OPTIONAL_PACKAGES=${{ steps.info.outputs.full }} + DOWNLOAD_PLUGINS=${{ steps.info.outputs.plugins }} + FEATURES=${{ steps.info.outputs.features }} + labels: | + org.opencontainers.image.title=SFTPGo + org.opencontainers.image.description=Full-featured and highly configurable file transfer server: SFTP, HTTP/S,FTP/S, WebDAV + org.opencontainers.image.url=https://github.com/drakkan/sftpgo + org.opencontainers.image.documentation=https://github.com/drakkan/sftpgo/blob/${{ github.sha }}/docker/README.md + org.opencontainers.image.source=https://github.com/drakkan/sftpgo + org.opencontainers.image.version=${{ steps.info.outputs.version }} + org.opencontainers.image.created=${{ steps.info.outputs.created }} + org.opencontainers.image.revision=${{ github.sha }} + org.opencontainers.image.licenses=AGPL-3.0-only \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..9ac10ee5 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,594 @@ +name: 
Release + +on: + push: + tags: 'v*' + +permissions: + id-token: write + contents: write + +env: + GO_VERSION: 1.25.8 + +jobs: + prepare-sources-with-deps: + name: Prepare sources with deps + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Get SFTPGo version + id: get_version + run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT + + - name: Prepare release + run: | + go mod vendor + echo "${SFTPGO_VERSION}" > VERSION.txt + echo "${GITHUB_SHA::8}" >> VERSION.txt + tar cJvf sftpgo_${SFTPGO_VERSION}_src_with_deps.tar.xz * + env: + SFTPGO_VERSION: ${{ steps.get_version.outputs.VERSION }} + + - name: Upload build artifact + uses: actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.get_version.outputs.VERSION }}_src_with_deps.tar.xz + path: ./sftpgo_${{ steps.get_version.outputs.VERSION }}_src_with_deps.tar.xz + retention-days: 1 + + prepare-windows: + name: Prepare Windows binaries + runs-on: windows-2022 + + steps: + - uses: actions/checkout@v6 + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Get SFTPGo version + id: get_version + run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT + shell: bash + + - name: Build + run: | + $GIT_COMMIT = (git describe --always --abbrev=8 --dirty) | Out-String + $DATE_TIME = ([datetime]::Now.ToUniversalTime().toString("yyyy-MM-ddTHH:mm:ssZ")) | Out-String + $FILE_VERSION = $Env:SFTPGO_VERSION.substring(1) + ".0" + go install github.com/tc-hib/go-winres@latest + go-winres simply --arch amd64 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X 
github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o sftpgo.exe + mkdir arm64 + $Env:CGO_ENABLED='0' + $Env:GOOS='windows' + $Env:GOARCH='arm64' + go-winres simply --arch arm64 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nosqlite -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe + mkdir x86 + $Env:GOARCH='386' + go-winres simply --arch 386 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version "$FILE_VERSION" --file-description "SFTPGo server" --product-name SFTPGo --copyright "2019-2025 Nicola Murino" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico + go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules,nosqlite -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\x86\sftpgo.exe + Remove-Item Env:\CGO_ENABLED + Remove-Item Env:\GOOS + Remove-Item Env:\GOARCH + env: + SFTPGO_VERSION: ${{ steps.get_version.outputs.VERSION }} + + - name: Initialize data provider + run: ./sftpgo initprovider + shell: bash + + - name: Prepare Release + run: | + mkdir output + copy .\sftpgo.exe .\output + copy .\sftpgo.json .\output + copy .\sftpgo.db .\output + copy .\LICENSE .\output\LICENSE.txt + copy .\NOTICE .\output\NOTICE.txt + mkdir output\templates + xcopy .\templates .\output\templates\ /E + mkdir output\static + xcopy .\static .\output\static\ /E + mkdir output\openapi + xcopy .\openapi .\output\openapi\ /E + iscc .\windows-installer\sftpgo.iss + rm 
.\output\sftpgo.exe + rm .\output\sftpgo.db + copy .\arm64\sftpgo.exe .\output + (Get-Content .\output\sftpgo.json).replace('"sqlite"', '"bolt"') | Set-Content .\output\sftpgo.json + $Env:SFTPGO_DATA_PROVIDER__DRIVER='bolt' + $Env:SFTPGO_DATA_PROVIDER__NAME='.\output\sftpgo.db' + .\sftpgo.exe initprovider + Remove-Item Env:\SFTPGO_DATA_PROVIDER__DRIVER + Remove-Item Env:\SFTPGO_DATA_PROVIDER__NAME + $Env:SFTPGO_ISS_ARCH='arm64' + iscc .\windows-installer\sftpgo.iss + + rm .\output\sftpgo.exe + copy .\x86\sftpgo.exe .\output + $Env:SFTPGO_ISS_ARCH='x86' + iscc .\windows-installer\sftpgo.iss + env: + SFTPGO_ISS_VERSION: ${{ steps.get_version.outputs.VERSION }} + + - name: Prepare Portable Release + run: | + mkdir win-portable + copy .\sftpgo.exe .\win-portable + mkdir win-portable\arm64 + copy .\arm64\sftpgo.exe .\win-portable\arm64 + mkdir win-portable\x86 + copy .\x86\sftpgo.exe .\win-portable\x86 + copy .\sftpgo.json .\win-portable + (Get-Content .\win-portable\sftpgo.json).replace('"sqlite"', '"bolt"') | Set-Content .\win-portable\sftpgo.json + copy .\output\sftpgo.db .\win-portable + copy .\LICENSE .\win-portable\LICENSE.txt + copy .\NOTICE .\win-portable\NOTICE.txt + mkdir win-portable\templates + xcopy .\templates .\win-portable\templates\ /E + mkdir win-portable\static + xcopy .\static .\win-portable\static\ /E + mkdir win-portable\openapi + xcopy .\openapi .\win-portable\openapi\ /E + Compress-Archive .\win-portable\* sftpgo_portable.zip + + - name: Upload Windows installer x86_64 artifact + uses: actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.get_version.outputs.VERSION }}_windows_x86_64.exe + path: ./sftpgo_windows_x86_64.exe + retention-days: 1 + + - name: Upload Windows installer arm64 artifact + uses: actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.get_version.outputs.VERSION }}_windows_arm64.exe + path: ./sftpgo_windows_arm64.exe + retention-days: 1 + + - name: Upload Windows installer x86 artifact + uses: 
actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.get_version.outputs.VERSION }}_windows_x86.exe + path: ./sftpgo_windows_x86.exe + retention-days: 1 + + - name: Upload Windows portable artifact + uses: actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.get_version.outputs.VERSION }}_windows_portable.zip + path: ./sftpgo_portable.zip + retention-days: 1 + + prepare-mac: + name: Prepare macOS binaries + runs-on: macos-14 + + steps: + - uses: actions/checkout@v6 + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Get SFTPGo version + id: get_version + run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT + shell: bash + + - name: Build for macOS x86_64 + run: go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo + + - name: Build for macOS arm64 + run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64 + + - name: Initialize data provider + run: ./sftpgo initprovider + shell: bash + + - name: Prepare Release + run: | + mkdir -p output/{init,sqlite,bash_completion,zsh_completion} + echo "For documentation please take a look here:" > output/README.txt + echo "" >> output/README.txt + echo "https://docs.sftpgo.com" >> output/README.txt + cp LICENSE output/ + cp NOTICE output/ + cp sftpgo output/ + cp sftpgo.json output/ + cp sftpgo.db output/sqlite/ + cp -r static output/ + cp -r openapi output/ + cp -r templates output/ + cp init/com.github.drakkan.sftpgo.plist 
output/init/ + ./sftpgo gen completion bash > output/bash_completion/sftpgo + ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo + ./sftpgo gen man -d output/man/man1 + gzip output/man/man1/* + cd output + tar cJvf ../sftpgo_${SFTPGO_VERSION}_macOS_x86_64.tar.xz * + cd .. + cp sftpgo_arm64 output/sftpgo + cd output + tar cJvf ../sftpgo_${SFTPGO_VERSION}_macOS_arm64.tar.xz * + cd .. + env: + SFTPGO_VERSION: ${{ steps.get_version.outputs.VERSION }} + + - name: Upload macOS x86_64 artifact + uses: actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.get_version.outputs.VERSION }}_macOS_x86_64.tar.xz + path: ./sftpgo_${{ steps.get_version.outputs.VERSION }}_macOS_x86_64.tar.xz + retention-days: 1 + + - name: Upload macOS arm64 artifact + uses: actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.get_version.outputs.VERSION }}_macOS_arm64.tar.xz + path: ./sftpgo_${{ steps.get_version.outputs.VERSION }}_macOS_arm64.tar.xz + retention-days: 1 + + prepare-linux: + name: Prepare Linux binaries + runs-on: ubuntu-latest + strategy: + matrix: + include: + - arch: amd64 + distro: ubuntu:18.04 + go-arch: amd64 + deb-arch: amd64 + rpm-arch: x86_64 + tar-arch: x86_64 + - arch: aarch64 + distro: ubuntu18.04 + go-arch: arm64 + deb-arch: arm64 + rpm-arch: aarch64 + tar-arch: arm64 + - arch: ppc64le + distro: ubuntu18.04 + go-arch: ppc64le + deb-arch: ppc64el + rpm-arch: ppc64le + tar-arch: ppc64le + - arch: armv7 + distro: ubuntu18.04 + go-arch: arm7 + deb-arch: armhf + rpm-arch: armv7hl + tar-arch: armv7 + + steps: + - uses: actions/checkout@v6 + + - name: Get versions + id: get_version + run: | + echo "SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT + echo "GO_VERSION=${GO_VERSION}" >> $GITHUB_OUTPUT + echo "COMMIT=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT + shell: bash + env: + GO_VERSION: ${{ env.GO_VERSION }} + + - name: Build on amd64 + if: ${{ matrix.arch == 'amd64' }} + run: | + echo '#!/bin/bash' > build.sh + echo '' >> build.sh + echo 'set -e' 
>> build.sh + echo 'apt-get update -q -y' >> build.sh + echo 'apt-get install -q -y curl gcc' >> build.sh + echo 'curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/go${{ steps.get_version.outputs.GO_VERSION }}.linux-${{ matrix.go-arch }}.tar.gz' >> build.sh + echo 'tar -C /usr/local -xzf go.tar.gz' >> build.sh + echo 'export PATH=$PATH:/usr/local/go/bin' >> build.sh + echo 'go version' >> build.sh + echo 'cd /usr/local/src' >> build.sh + echo 'go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_version.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' >> build.sh + + chmod 755 build.sh + docker run --rm --name ubuntu-build --mount type=bind,source=`pwd`,target=/usr/local/src ${{ matrix.distro }} /usr/local/src/build.sh + mkdir -p output/{init,sqlite,bash_completion,zsh_completion} + echo "For documentation please take a look here:" > output/README.txt + echo "" >> output/README.txt + echo "https://github.com/drakkan/sftpgo/blob/${SFTPGO_VERSION}/README.md" >> output/README.txt + cp LICENSE output/ + cp NOTICE output/ + cp sftpgo.json output/ + cp -r templates output/ + cp -r static output/ + cp -r openapi output/ + cp init/sftpgo.service output/init/ + ./sftpgo initprovider + ./sftpgo gen completion bash > output/bash_completion/sftpgo + ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo + ./sftpgo gen man -d output/man/man1 + gzip output/man/man1/* + cp sftpgo output/ + cp sftpgo.db output/sqlite/ + cd output + tar cJvf sftpgo_${SFTPGO_VERSION}_linux_${{ matrix.tar-arch }}.tar.xz * + cd .. 
+ env: + SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }} + + - uses: uraimo/run-on-arch-action@v3 + if: ${{ matrix.arch != 'amd64' }} + name: Build for ${{ matrix.arch }} + id: build + with: + arch: ${{ matrix.arch }} + distro: ${{ matrix.distro }} + setup: | + mkdir -p "${PWD}/output" + dockerRunArgs: | + --volume "${PWD}/output:/output" + shell: /bin/bash + install: | + apt-get update -q -y + apt-get install -q -y curl gcc xz-utils + GO_DOWNLOAD_ARCH=${{ matrix.go-arch }} + if [ ${{ matrix.arch}} == 'armv7' ] + then + GO_DOWNLOAD_ARCH=armv6l + fi + curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/go${{ steps.get_version.outputs.GO_VERSION }}.linux-${GO_DOWNLOAD_ARCH}.tar.gz + tar -C /usr/local -xzf go.tar.gz + run: | + export PATH=$PATH:/usr/local/go/bin + go version + go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes,disable_grpc_modules -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_version.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo + mkdir -p output/{init,sqlite,bash_completion,zsh_completion} + echo "For documentation please take a look here:" > output/README.txt + echo "" >> output/README.txt + echo "https://github.com/drakkan/sftpgo/blob/${{ steps.get_version.outputs.SFTPGO_VERSION }}/README.md" >> output/README.txt + cp LICENSE output/ + cp NOTICE output/ + cp sftpgo.json output/ + cp -r templates output/ + cp -r static output/ + cp -r openapi output/ + cp init/sftpgo.service output/init/ + ./sftpgo initprovider + ./sftpgo gen completion bash > output/bash_completion/sftpgo + ./sftpgo gen completion zsh > output/zsh_completion/_sftpgo + ./sftpgo gen man -d output/man/man1 + gzip output/man/man1/* + cp sftpgo output/ + cp sftpgo.db output/sqlite/ + cd output + tar cJvf sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_${{ matrix.tar-arch }}.tar.xz * + cd .. 
+ + - name: Upload build artifact for ${{ matrix.arch }} + uses: actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_${{ matrix.tar-arch }}.tar.xz + path: ./output/sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_${{ matrix.tar-arch }}.tar.xz + retention-days: 1 + + - name: Build Packages + id: build_linux_pkgs + run: | + export NFPM_ARCH=${{ matrix.go-arch }} + cd pkgs + ./build.sh + PKG_VERSION=${SFTPGO_VERSION:1} + echo "pkg-version=${PKG_VERSION}" >> $GITHUB_OUTPUT + env: + SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }} + + - name: Upload Deb Package + uses: actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.build_linux_pkgs.outputs.pkg-version }}-1_${{ matrix.deb-arch}}.deb + path: ./pkgs/dist/deb/sftpgo_${{ steps.build_linux_pkgs.outputs.pkg-version }}-1_${{ matrix.deb-arch}}.deb + retention-days: 1 + + - name: Upload RPM Package + uses: actions/upload-artifact@v7 + with: + name: sftpgo-${{ steps.build_linux_pkgs.outputs.pkg-version }}-1.${{ matrix.rpm-arch}}.rpm + path: ./pkgs/dist/rpm/sftpgo-${{ steps.build_linux_pkgs.outputs.pkg-version }}-1.${{ matrix.rpm-arch}}.rpm + retention-days: 1 + + prepare-linux-bundle: + name: Prepare Linux bundle + needs: prepare-linux + runs-on: ubuntu-latest + + steps: + - name: Get versions + id: get_version + run: | + echo "SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT + shell: bash + + - name: Download amd64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_x86_64.tar.xz + + - name: Download arm64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_arm64.tar.xz + + - name: Download ppc64le artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_ppc64le.tar.xz + + - name: Download armv7 artifact + uses: 
actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_armv7.tar.xz + + - name: Build bundle + shell: bash + run: | + mkdir -p bundle/{arm64,ppc64le,armv7} + cd bundle + tar xvf ../sftpgo_${SFTPGO_VERSION}_linux_x86_64.tar.xz + cd arm64 + tar xvf ../../sftpgo_${SFTPGO_VERSION}_linux_arm64.tar.xz sftpgo + cd ../ppc64le + tar xvf ../../sftpgo_${SFTPGO_VERSION}_linux_ppc64le.tar.xz sftpgo + cd ../armv7 + tar xvf ../../sftpgo_${SFTPGO_VERSION}_linux_armv7.tar.xz sftpgo + cd .. + tar cJvf sftpgo_${SFTPGO_VERSION}_linux_bundle.tar.xz * + cd .. + env: + SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }} + + - name: Upload Linux bundle + uses: actions/upload-artifact@v7 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_bundle.tar.xz + path: ./bundle/sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_bundle.tar.xz + retention-days: 1 + + create-release: + name: Release + needs: [prepare-linux-bundle, prepare-sources-with-deps, prepare-mac, prepare-windows] + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v6 + - name: Get versions + id: get_version + run: | + SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//} + PKG_VERSION=${SFTPGO_VERSION:1} + echo "SFTPGO_VERSION=${SFTPGO_VERSION}" >> $GITHUB_OUTPUT + echo "PKG_VERSION=${PKG_VERSION}" >> $GITHUB_OUTPUT + shell: bash + + - name: Download amd64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_x86_64.tar.xz + + - name: Download arm64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_arm64.tar.xz + + - name: Download ppc64le artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_ppc64le.tar.xz + + - name: Download armv7 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ 
steps.get_version.outputs.SFTPGO_VERSION }}_linux_armv7.tar.xz + + - name: Download Linux bundle artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_linux_bundle.tar.xz + + - name: Download Deb amd64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.PKG_VERSION }}-1_amd64.deb + + - name: Download Deb arm64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.PKG_VERSION }}-1_arm64.deb + + - name: Download Deb ppc64le artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.PKG_VERSION }}-1_ppc64el.deb + + - name: Download Deb armv7 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.PKG_VERSION }}-1_armhf.deb + + - name: Download RPM x86_64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo-${{ steps.get_version.outputs.PKG_VERSION }}-1.x86_64.rpm + + - name: Download RPM aarch64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo-${{ steps.get_version.outputs.PKG_VERSION }}-1.aarch64.rpm + + - name: Download RPM ppc64le artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo-${{ steps.get_version.outputs.PKG_VERSION }}-1.ppc64le.rpm + + - name: Download RPM armv7 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo-${{ steps.get_version.outputs.PKG_VERSION }}-1.armv7hl.rpm + + - name: Download macOS x86_64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_macOS_x86_64.tar.xz + + - name: Download macOS arm64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_macOS_arm64.tar.xz + + - name: Download Windows installer x86_64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ 
steps.get_version.outputs.SFTPGO_VERSION }}_windows_x86_64.exe + + - name: Download Windows installer arm64 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_windows_arm64.exe + + - name: Download Windows installer x86 artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_windows_x86.exe + + - name: Download Windows portable artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_windows_portable.zip + + - name: Download source with deps artifact + uses: actions/download-artifact@v8 + with: + name: sftpgo_${{ steps.get_version.outputs.SFTPGO_VERSION }}_src_with_deps.tar.xz + + - name: Create release + run: | + mv sftpgo_windows_x86_64.exe sftpgo_${SFTPGO_VERSION}_windows_x86_64.exe + mv sftpgo_windows_arm64.exe sftpgo_${SFTPGO_VERSION}_windows_arm64.exe + mv sftpgo_windows_x86.exe sftpgo_${SFTPGO_VERSION}_windows_x86.exe + mv sftpgo_portable.zip sftpgo_${SFTPGO_VERSION}_windows_portable.zip + gh release create "${SFTPGO_VERSION}" -t "${SFTPGO_VERSION}" + gh release upload "${SFTPGO_VERSION}" sftpgo_*.xz --clobber + gh release upload "${SFTPGO_VERSION}" sftpgo-*.rpm --clobber + gh release upload "${SFTPGO_VERSION}" sftpgo_*.deb --clobber + gh release upload "${SFTPGO_VERSION}" sftpgo_*.exe --clobber + gh release upload "${SFTPGO_VERSION}" sftpgo_*.zip --clobber + gh release view "${SFTPGO_VERSION}" + env: + GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} + SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..f61b3864 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +# compilation output +sftpgo +sftpgo.exe diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..14f9c1f9 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,66 @@ +version: "2" +run: + 
issues-exit-code: 1 + tests: true +linters: + enable: + - bodyclose + - dogsled + - dupl + - goconst + - gocyclo + - misspell + - revive + - rowserrcheck + - unconvert + - unparam + - whitespace + settings: + dupl: + threshold: 150 + errcheck: + check-type-assertions: false + check-blank: false + goconst: + min-len: 3 + min-occurrences: 3 + gocyclo: + min-complexity: 15 + # https://golangci-lint.run/usage/linters/#revive + revive: + rules: + - name: var-naming + severity: warning + disabled: true + exclude: [""] + arguments: + - ["ID"] # AllowList + - ["VM"] # DenyList + - - upper-case-const: true + - - skip-package-name-checks: true + exclusions: + generated: lax + presets: + - common-false-positives + - legacy + - std-error-handling + paths: + - third_party$ + - builtin$ + - examples$ +formatters: + enable: + - gofmt + - goimports + settings: + gofmt: + simplify: true + goimports: + local-prefixes: + - github.com/drakkan/sftpgo + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 8beb780f..00000000 --- a/.travis.yml +++ /dev/null @@ -1,23 +0,0 @@ -language: go - -os: - - linux - - osx - -go: - - "1.12.x" - -env: - - GO111MODULE=on - -before_script: - - sqlite3 sftpgo.db 'CREATE TABLE "users" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255) NOT NULL UNIQUE, "password" varchar(255) NULL, "public_keys" text NULL, "home_dir" varchar(255) NOT NULL, "uid" integer NOT NULL, "gid" integer NOT NULL, "max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL);' - -install: - - go get -v -t ./... - -script: - - go test -v ./... 
-coverprofile=coverage.txt -covermode=atomic - -after_success: - - bash <(curl -s https://codecov.io/bash) \ No newline at end of file diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..2d0e89f2 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @drakkan diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..46619004 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. 
+ +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. 
+ +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +support@sftpgo.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. 
Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..65371870 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,67 @@ +FROM golang:1.26-trixie AS builder + +ENV GOFLAGS="-mod=readonly" + +RUN apt-get update && apt-get -y upgrade && rm -rf /var/lib/apt/lists/* + +RUN mkdir -p /workspace +WORKDIR /workspace + +ARG GOPROXY + +COPY go.mod go.sum ./ +RUN go mod download && go mod verify + +ARG COMMIT_SHA + +# This ARG allows to disable some optional features and it might be useful if you build the image yourself. +# For example you can disable S3 and GCS support like this: +# --build-arg FEATURES=nos3,nogcs +ARG FEATURES + +COPY . . 
+ +RUN set -xe && \ + export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \ + go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo + +# Set to "true" to download the "official" plugins in /usr/local/bin +ARG DOWNLOAD_PLUGINS=false + +RUN if [ "${DOWNLOAD_PLUGINS}" = "true" ]; then apt-get update && apt-get install --no-install-recommends -y curl && ./docker/scripts/download-plugins.sh; fi + +FROM debian:trixie-slim + +# Set to "true" to install jq +ARG INSTALL_OPTIONAL_PACKAGES=false + +RUN apt-get update && apt-get -y upgrade && apt-get install --no-install-recommends -y ca-certificates media-types && rm -rf /var/lib/apt/lists/* + +RUN if [ "${INSTALL_OPTIONAL_PACKAGES}" = "true" ]; then apt-get update && apt-get install --no-install-recommends -y jq && rm -rf /var/lib/apt/lists/*; fi + +RUN mkdir -p /etc/sftpgo /var/lib/sftpgo /usr/share/sftpgo /srv/sftpgo/data /srv/sftpgo/backups + +RUN groupadd --system -g 1000 sftpgo && \ + useradd --system --gid sftpgo --no-create-home \ + --home-dir /var/lib/sftpgo --shell /usr/sbin/nologin \ + --comment "SFTPGo user" --uid 1000 sftpgo + +COPY --from=builder /workspace/sftpgo.json /etc/sftpgo/sftpgo.json +COPY --from=builder /workspace/templates /usr/share/sftpgo/templates +COPY --from=builder /workspace/static /usr/share/sftpgo/static +COPY --from=builder /workspace/openapi /usr/share/sftpgo/openapi +COPY --from=builder /workspace/sftpgo /usr/local/bin/sftpgo-plugin-* /usr/local/bin/ + +# Log to the stdout so the logs will be available using docker logs +ENV SFTPGO_LOG_FILE_PATH="" + +# Modify the default configuration file +RUN sed -i 's|"users_base_dir": "",|"users_base_dir": "/srv/sftpgo/data",|' /etc/sftpgo/sftpgo.json && \ + sed -i 's|"backups"|"/srv/sftpgo/backups"|' /etc/sftpgo/sftpgo.json + 
+RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups + +WORKDIR /var/lib/sftpgo +USER 1000:1000 + +CMD ["sftpgo", "serve"] diff --git a/Dockerfile.alpine b/Dockerfile.alpine new file mode 100644 index 00000000..494442dc --- /dev/null +++ b/Dockerfile.alpine @@ -0,0 +1,60 @@ +FROM golang:1.26-alpine3.23 AS builder + +ENV GOFLAGS="-mod=readonly" + +RUN apk -U upgrade --no-cache && apk add --update --no-cache bash ca-certificates curl git gcc g++ + +RUN mkdir -p /workspace +WORKDIR /workspace + +ARG GOPROXY + +COPY go.mod go.sum ./ +RUN go mod download && go mod verify + +ARG COMMIT_SHA + +# This ARG allows to disable some optional features and it might be useful if you build the image yourself. +# For example you can disable S3 and GCS support like this: +# --build-arg FEATURES=nos3,nogcs +ARG FEATURES + +COPY . . + +RUN set -xe && \ + export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \ + go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo + +FROM alpine:3.23 + +# Set to "true" to install jq +ARG INSTALL_OPTIONAL_PACKAGES=false + +RUN apk -U upgrade --no-cache && apk add --update --no-cache ca-certificates tzdata mailcap + +RUN if [ "${INSTALL_OPTIONAL_PACKAGES}" = "true" ]; then apk add --update --no-cache jq; fi + +RUN mkdir -p /etc/sftpgo /var/lib/sftpgo /usr/share/sftpgo /srv/sftpgo/data /srv/sftpgo/backups + +RUN addgroup -g 1000 -S sftpgo && \ + adduser -u 1000 -h /var/lib/sftpgo -s /sbin/nologin -G sftpgo -S -D -H -g "SFTPGo user" sftpgo + +COPY --from=builder /workspace/sftpgo.json /etc/sftpgo/sftpgo.json +COPY --from=builder /workspace/templates /usr/share/sftpgo/templates +COPY --from=builder /workspace/static /usr/share/sftpgo/static +COPY --from=builder 
/workspace/openapi /usr/share/sftpgo/openapi +COPY --from=builder /workspace/sftpgo /usr/local/bin/ + +# Log to the stdout so the logs will be available using docker logs +ENV SFTPGO_LOG_FILE_PATH="" + +# Modify the default configuration file +RUN sed -i 's|"users_base_dir": "",|"users_base_dir": "/srv/sftpgo/data",|' /etc/sftpgo/sftpgo.json && \ + sed -i 's|"backups"|"/srv/sftpgo/backups"|' /etc/sftpgo/sftpgo.json + +RUN chown -R sftpgo:sftpgo /etc/sftpgo /srv/sftpgo && chown sftpgo:sftpgo /var/lib/sftpgo && chmod 700 /srv/sftpgo/backups + +WORKDIR /var/lib/sftpgo +USER 1000:1000 + +CMD ["sftpgo", "serve"] diff --git a/Dockerfile.distroless b/Dockerfile.distroless new file mode 100644 index 00000000..36bc1406 --- /dev/null +++ b/Dockerfile.distroless @@ -0,0 +1,57 @@ +FROM golang:1.26-trixie AS builder + +ENV CGO_ENABLED=0 GOFLAGS="-mod=readonly" + +RUN apt-get update && apt-get -y upgrade && apt-get install --no-install-recommends -y media-types && rm -rf /var/lib/apt/lists/* + +RUN mkdir -p /workspace +WORKDIR /workspace + +ARG GOPROXY + +COPY go.mod go.sum ./ +RUN go mod download && go mod verify + +ARG COMMIT_SHA + +# This ARG allows to disable some optional features and it might be useful if you build the image yourself. +# For this variant we disable SQLite support since it requires CGO and so a C runtime which is not installed +# in distroless/static-* images +ARG FEATURES + +COPY . . 
+ +RUN set -xe && \ + export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \ + go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo + +# Modify the default configuration file +RUN sed -i 's|"users_base_dir": "",|"users_base_dir": "/srv/sftpgo/data",|' sftpgo.json && \ + sed -i 's|"backups"|"/srv/sftpgo/backups"|' sftpgo.json && \ + sed -i 's|"sqlite"|"bolt"|' sftpgo.json + +RUN mkdir /etc/sftpgo /var/lib/sftpgo /srv/sftpgo + +FROM gcr.io/distroless/static-debian13 + +COPY --from=builder --chown=1000:1000 /etc/sftpgo /etc/sftpgo +COPY --from=builder --chown=1000:1000 /srv/sftpgo /srv/sftpgo +COPY --from=builder --chown=1000:1000 /var/lib/sftpgo /var/lib/sftpgo +COPY --from=builder --chown=1000:1000 /workspace/sftpgo.json /etc/sftpgo/sftpgo.json +COPY --from=builder /workspace/templates /usr/share/sftpgo/templates +COPY --from=builder /workspace/static /usr/share/sftpgo/static +COPY --from=builder /workspace/openapi /usr/share/sftpgo/openapi +COPY --from=builder /workspace/sftpgo /usr/local/bin/ +COPY --from=builder /etc/mime.types /etc/mime.types + +# Log to the stdout so the logs will be available using docker logs +ENV SFTPGO_LOG_FILE_PATH="" +# These env vars are required to avoid the following error when calling user.Current(): +# unable to get the current user: user: Current requires cgo or $USER set in environment +ENV USER=sftpgo +ENV HOME=/var/lib/sftpgo + +WORKDIR /var/lib/sftpgo +USER 1000:1000 + +CMD ["sftpgo", "serve"] \ No newline at end of file diff --git a/LICENSE b/LICENSE index f288702d..29ebfa54 100644 --- a/LICENSE +++ b/LICENSE @@ -1,5 +1,5 @@ - GNU GENERAL PUBLIC LICENSE - Version 3, 29 June 2007 + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 Copyright (C) 2007 Free Software Foundation, Inc. 
Everyone is permitted to copy and distribute verbatim copies @@ -7,17 +7,15 @@ Preamble - The GNU General Public License is a free, copyleft license for -software and other kinds of works. + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, -the GNU General Public License is intended to guarantee your freedom to +our General Public Licenses are intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free -software for all its users. We, the Free Software Foundation, use the -GNU General Public License for most of our software; it applies also to -any other work released this way by its authors. You can apply it to -your programs, too. +software for all its users. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you @@ -26,44 +24,34 @@ them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. - To protect your rights, we need to prevent others from denying you -these rights or asking you to surrender the rights. Therefore, you have -certain responsibilities if you distribute copies of the software, or if -you modify it: responsibilities to respect the freedom of others. + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. 
- For example, if you distribute copies of such a program, whether -gratis or for a fee, you must pass on to the recipients the same -freedoms that you received. You must make sure that they, too, receive -or can get the source code. And you must show them these terms so they -know their rights. + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. - Developers that use the GNU GPL protect your rights with two steps: -(1) assert copyright on the software, and (2) offer you this License -giving you legal permission to copy, distribute and/or modify it. + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. - For the developers' and authors' protection, the GPL clearly explains -that there is no warranty for this free software. For both users' and -authors' sake, the GPL requires that modified versions be marked as -changed, so that their problems will not be attributed erroneously to -authors of previous versions. - - Some devices are designed to deny users access to install or run -modified versions of the software inside them, although the manufacturer -can do so. 
This is fundamentally incompatible with the aim of -protecting users' freedom to change the software. The systematic -pattern of such abuse occurs in the area of products for individuals to -use, which is precisely where it is most unacceptable. Therefore, we -have designed this version of the GPL to prohibit the practice for those -products. If such problems arise substantially in other domains, we -stand ready to extend this provision to those domains in future versions -of the GPL, as needed to protect the freedom of users. - - Finally, every program is threatened constantly by software patents. -States should not allow patents to restrict development and use of -software on general-purpose computers, but in those that do, we wish to -avoid the special danger that patents applied to a free program could -make it effectively proprietary. To prevent this, the GPL assures that -patents cannot be used to render the program non-free. + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. The precise terms and conditions for copying, distribution and modification follow. @@ -72,7 +60,7 @@ modification follow. 0. Definitions. - "This License" refers to version 3 of the GNU General Public License. + "This License" refers to version 3 of the GNU Affero General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. @@ -549,35 +537,45 @@ to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. - 13. Use with the GNU Affero General Public License. + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed -under version 3 of the GNU Affero General Public License into a single +under version 3 of the GNU General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, -but the special requirements of the GNU Affero General Public License, -section 13, concerning interaction through a network will apply to the -combination as such. +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of -the GNU General Public License from time to time. Such new versions will -be similar in spirit to the present version, but may differ in detail to +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. 
If the -Program specifies that a certain numbered version of the GNU General +Program specifies that a certain numbered version of the GNU Affero General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the -GNU General Public License, you may choose any version ever published +GNU Affero General Public License, you may choose any version ever published by the Free Software Foundation. If the Program specifies that a proxy can decide which future -versions of the GNU General Public License can be used, that proxy's +versions of the GNU Affero General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. @@ -635,40 +633,29 @@ the "copyright" line and a pointer to where the full notice is found. Copyright (C) This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. + GNU Affero General Public License for more details. - You should have received a copy of the GNU General Public License + You should have received a copy of the GNU Affero General Public License along with this program. If not, see . Also add information on how to contact you by electronic and paper mail. 
- If the program does terminal interaction, make it output a short -notice like this when it starts in an interactive mode: - - Copyright (C) - This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. - This is free software, and you are welcome to redistribute it - under certain conditions; type `show c' for details. - -The hypothetical commands `show w' and `show c' should show the appropriate -parts of the General Public License. Of course, your program's commands -might be different; for a GUI interface, you would use an "about box". + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. -For more information on this, and how to apply and follow the GNU GPL, see -. - - The GNU General Public License does not permit incorporating your program -into proprietary programs. If your program is a subroutine library, you -may consider it more useful to permit linking proprietary applications with -the library. If this is what you want to do, use the GNU Lesser General -Public License instead of this License. But first, please read -. +For more information on this, and how to apply and follow the GNU AGPL, see +. 
\ No newline at end of file diff --git a/NOTICE b/NOTICE new file mode 100644 index 00000000..f683d86f --- /dev/null +++ b/NOTICE @@ -0,0 +1,12 @@ +Additional terms under GNU AGPL version 3 section 7.3(b) and 13.1: + +If you have included SFTPGo so that it is offered through any network +interactions, including by means of an external user interface, or +any other integration, even without modifying its source code and then +SFTPGo is partially, fully or optionally configured via your frontend, +you must provide reasonable but clear attribution to the SFTPGo project +and its author(s), not imply any endorsement by or affiliation with the +SFTPGo project, and you must prominently offer all users interacting +with it remotely through a computer network an opportunity to receive +the Corresponding Source of the SFTPGo version you include by providing +a link to the Corresponding Source in the SFTPGo source code repository. diff --git a/README.md b/README.md index 1c7d787b..bee0bc4f 100644 --- a/README.md +++ b/README.md @@ -1,344 +1,131 @@ # SFTPGo -[![Build Status](https://travis-ci.org/drakkan/sftpgo.svg?branch=master)](https://travis-ci.org/drakkan/sftpgo) [![Code Coverage](https://codecov.io/gh/drakkan/sftpgo/branch/master/graph/badge.svg)](https://codecov.io/gh/drakkan/sftpgo/branch/master) [![Go Report Card](https://goreportcard.com/badge/github.com/drakkan/sftpgo)](https://goreportcard.com/report/github.com/drakkan/sftpgo) [![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0) [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) -Full featured and highly configurable SFTP server +[![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg)](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg) +[![License: AGPL-3.0-only](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0) +[![Mentioned in 
Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) -## Features +Full-featured and highly configurable event-driven file transfer solution. Server protocols: SFTP, HTTP/S, FTP/S, WebDAV. Storage backends: local filesystem, encrypted local filesystem, S3 (compatible) Object Storage, Google Cloud Storage, Azure Blob Storage, other SFTP servers. -- Each account is chrooted to his Home Dir. -- SFTP accounts are virtual accounts stored in a "data provider". -- SQLite, MySQL, PostgreSQL and bbolt (key/value store in pure Go) data providers are supported. -- Public key and password authentication. Multiple public keys per user are supported. -- Quota support: accounts can have individual quota expressed as max total size and/or max number of files. -- Bandwidth throttling is supported, with distinct settings for upload and download. -- Per user maximum concurrent sessions. -- Per user permissions: list directories content, upload, download, delete, rename, create directories, create symlinks can be enabled or disabled. -- Per user files/folders ownership: you can map all the users to the system account that runs SFTPGo (all platforms are supported) or you can run SFTPGo as root user and map each user or group of users to a different system account (*NIX only). -- Configurable custom commands and/or HTTP notifications on upload, download, delete or rename. -- Automatically terminating idle connections. -- Atomic uploads are configurable. -- Optional SCP support. -- REST API for users and quota management and real time reports for the active connections with possibility of forcibly closing a connection. -- Configuration is a your choice: JSON, TOML, YAML, HCL, envfile are supported. -- Log files are accurate and they are saved in the easily parsable JSON format. 
+With SFTPGo you can leverage local and cloud storage backends for exchanging and storing files internally or with business partners using the same tools and processes you are already familiar with. -## Platforms +## Project Status & Editions -SFTPGo is developed and tested on Linux. After each commit the code is automatically built and tested on Linux and macOS using Travis CI. -Regularly the test cases are manually executed and pass on Windows. Other UNIX variants such as *BSD should work too. +SFTPGo is an open-source project with a sustainable business model. We offer two editions to suit different requirements, ensuring the project remains healthy and maintained for everyone. -## Requirements +### Open Source (Community) -- Go 1.12 or higher. -- A suitable SQL server or key/value store to use as data provider: PostreSQL 9.4+ or MySQL 5.6+ or SQLite 3.x or bbolt 1.3.x +Free, Copyleft (AGPLv3), Community Supported. The Community edition is a fully functional, production-ready solution widely adopted worldwide. It includes all the core protocols, storage backends, and the WebAdmin/WebClient UIs. It is ideal for: -## Installation +- Standard file transfer needs. +- Integrating storage backends (S3, GCS, Azure Blob) with legacy protocols. +- Projects that are comfortable with AGPLv3 licensing. -Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell: +### SFTPGo Enterprise -```bash -$ go get -u github.com/drakkan/sftpgo -``` +Commercial License, Professional Support, ISO 27001 Vendor. The Enterprise edition is built on the same core but extends it for mission-critical environments, compliance-heavy industries, and advanced workflows. It is a drop-in replacement (seamless upgrade). -Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`. 
+| Feature | Open Source (Community) | Enterprise Edition | +| :--- | :--- | :--- | +| **License Type** | AGPLv3 (Copyleft) | **Commercial License**
Proprietary/No Copyleft | +| **Vendor Compliance** | Not Applicable
Community Project | **Certified Vendor**
ISO 27001 & Supply Chain Validation | +| **Support** | Community (GitHub) | **Direct from Authors** | +| **Cloud Storage Engine** | Standard | **High Performance & Scalable**
In-memory streaming (no local temp files) and up to 70% faster | +| **High Availability (HA)** | Standard
Shared DB & Storage | **Advanced**
Enhanced event handling and optimized instance coordination | +| **Automation Logic** | Simple Placeholders | **Dynamic Logic & Virtual Folders**
Conditions, loops, route data across storage backends | +| **Data Lifecycle** | Delete / Retain | **Smart Archiving**
Move data to external Cloud/SFTP storage via Virtual Folders | +| **Email Data Ingestion** | - | **Native IMAP Integration**
Auto-extract attachments from email to storage | +| **Public Sharing** | Standard Links | **Advanced & Collaborative**
Email Authentication & Group Delegation | +| **Data Protection** | - | **Encryption & Scanning**
Automated PGP, Antivirus & DLP via ICAP | +| **Advanced Identity (SSO)** | Standard | **Extended Controls**
Advanced Single Sign-On parameters | +| **Document Editing** | - | **Included**
View, edit, and co-author in browser | -SFTPGo depends on [go-sqlite3](https://github.com/mattn/go-sqlite3) that is a CGO package and so it requires a `C` compiler at build time. -On Linux and macOS a compiler is easy to install or already installed, on Windows you need to download [MinGW-w64](https://sourceforge.net/projects/mingw-w64/files/) and build SFTPGo from its command prompt. - -The compiler is a build time only dependency, it is not not required at runtime. - -If you don't need SQLite, you can also get/build SFTPGo setting the environment variable `GCO_ENABLED` to 0, this way SQLite support will be disabled but PostgreSQL, MySQL and bbolt will work and you don't need a `C` compiler for building. - -Version info, such as git commit and build date, can be embedded setting the following string variables at build time: - -- `github.com/drakkan/sftpgo/utils.commit` -- `github.com/drakkan/sftpgo/utils.date` - -For example you can build using the following command: - -```bash -go build -i -ldflags "-s -w -X github.com/drakkan/sftpgo/utils.commit=`git describe --tags --always --dirty` -X github.com/drakkan/sftpgo/utils.date=`date -u +%FT%TZ`" -o sftpgo -``` - -and you will get a version that includes git commit and build date like this one: - -```bash -sftpgo -v -SFTPGo version: 0.9.0-dev-90607d4-dirty-2019-08-08T19:28:36Z -``` - -For Linux, a systemd sample [service](https://github.com/drakkan/sftpgo/tree/master/init/sftpgo.service "systemd service") can be found inside the source tree. - -Alternately you can use distro packages: - -- Arch Linux PKGBUILD is available on [AUR](https://aur.archlinux.org/packages/sftpgo-git/ "SFTPGo") - -## Configuration - -The `sftpgo` executable can be used this way: - -```bash -Usage: - sftpgo [command] - -Available Commands: - help Help about any command - serve Start the SFTP Server - -Flags: - -h, --help help for sftpgo - -v, --version -``` - -The `serve` subcommand supports the following flags: - -- `--config-dir` string. 
Location of the config dir. This directory should contain the `sftpgo` configuration file and is used as the base for files with a relative path (eg. the private keys for the SFTP server, the SQLite or bblot database if you use SQLite or bbolt as data provider). The default value is "." or the value of `SFTPGO_CONFIG_DIR` environment variable. -- `--config-file` string. Name of the configuration file. It must be the name of a file stored in config-dir not the absolute path to the configuration file. The specified file name must have no extension we automatically load JSON, YAML, TOML, HCL and Java properties. The default value is "sftpgo" (and therefore `sftpgo.json`, `sftpgo.yaml` and so on are searched) or the value of `SFTPGO_CONFIG_FILE` environment variable. -- `--log-compress` boolean. Determine if the rotated log files should be compressed using gzip. Default `false` or the value of `SFTPGO_LOG_COMPRESS` environment variable (1 or `true`, 0 or `false`). -- `--log-file-path` string. Location for the log file, default "sftpgo.log" or the value of `SFTPGO_LOG_FILE_PATH` environment variable. -- `--log-max-age` int. Maximum number of days to retain old log files. Default 28 or the value of `SFTPGO_LOG_MAX_AGE` environment variable. -- `--log-max-backups` int. Maximum number of old log files to retain. Default 5 or the value of `SFTPGO_LOG_MAX_BACKUPS` environment variable. -- `--log-max-size` int. Maximum size in megabytes of the log file before it gets rotated. Default 10 or the value of `SFTPGO_LOG_MAX_SIZE` environment variable. -- `--log-verbose` boolean. Enable verbose logs. Default `true` or the value of `SFTPGO_LOG_VERBOSE` environment variable (1 or `true`, 0 or `false`). - -If you don't configure any private host keys, the daemon will use `id_rsa` in the configuration directory. If that file doesn't exist, the daemon will attempt to autogenerate it (if the user that executes SFTPGo has write access to the config-dir). 
The server supports any private key format supported by [`crypto/ssh`](https://github.com/golang/crypto/blob/master/ssh/keys.go#L32). - -Before starting `sftpgo` a dataprovider must be configured. - -Sample SQL scripts to create the required database structure can be found inside the source tree [sql](https://github.com/drakkan/sftpgo/tree/master/sql "sql") directory. The SQL scripts filename's is, by convention, the date as `YYYYMMDD` and the suffix `.sql`. You need to apply all the SQL scripts for your database ordered by name, for example `20190706.sql` must be applied before `20190728.sql` and so on. - -The `sftpgo` configuration file contains the following sections: - -- **"sftpd"**, the configuration for the SFTP server - - `bind_port`, integer. The port used for serving SFTP requests. Default: 2022 - - `bind_address`, string. Leave blank to listen on all available network interfaces. Default: "" - - `idle_timeout`, integer. Time in minutes after which an idle client will be disconnected. Default: 15 - - `max_auth_tries` integer. Maximum number of authentication attempts permitted per connection. If set to a negative number, the number of attempts are unlimited. If set to zero, the number of attempts are limited to 6. - - `umask`, string. Umask for the new files and directories. This setting has no effect on Windows. Default: "0022" - - `banner`, string. Identification string used by the server. Default "SFTPGo" - - `upload_mode` integer. 0 means standard, the files are uploaded directly to the requested path. 1 means atomic: files are uploaded to a temporary path and renamed to the requested path when the client ends the upload. Atomic mode avoids problems such as a web server that serves partial files when the files are being uploaded - - `actions`, struct. It contains the command to execute and/or the HTTP URL to notify and the trigger conditions - - `execute_on`, list of strings. Valid values are `download`, `upload`, `delete`, `rename`. 
On folder deletion a `delete` notification will be sent for each deleted file. Leave empty to disable actions. - - `command`, string. Absolute path to the command to execute. Leave empty to disable. The command is invoked with the following arguments: - - `action`, any valid `execute_on` string - - `username`, user who did the action - - `path` to the affected file. For `rename` action this is the old file name - - `target_path`, non empty for `rename` action, this is the new file name - - `http_notification_url`, a valid URL. An HTTP GET request will be executed to this URL. Leave empty to disable. The query string will contain the following parameters that have the same meaning of the command's arguments: - - `action` - - `username` - - `path` - - `target_path`, added for `rename` action only - - `keys`, struct array. It contains the daemon's private keys. If empty or missing the daemon will search or try to generate `id_rsa` in the configuration directory. - - `private_key`, path to the private key file. It can be a path relative to the config dir or an absolute one. - - `enable_scp`, boolean. Default disabled. Set to `true` to enable SCP support. SCP is an experimental feature, we have our own SCP implementation since we can't rely on `scp` system command to proper handle permissions, quota and user's home dir restrictions. The SCP protocol is quite simple but there is no official docs about it, so we need more testing and feedbacks before enabling it by default. We may not handle some borderline cases or have sneaky bugs. Please do accurate tests yourself before enabling SCP and let us known if something does not work as expected for your use cases. SCP between two remote hosts is supported using the `-3` scp option. -- **"data_provider"**, the configuration for the data provider - - `driver`, string. Supported drivers are `sqlite`, `mysql`, `postgresql`, `bolt` - - `name`, string. Database name. 
For driver `sqlite` this can be the database name relative to the config dir or the absolute path to the SQLite database. - - `host`, string. Database host. Leave empty for driver `sqlite` and `bolt` - - `port`, integer. Database port. Leave empty for driver `sqlite` and `bolt` - - `username`, string. Database user. Leave empty for driver `sqlite` and `bolt` - - `password`, string. Database password. Leave empty for driver `sqlite` and `bolt` - - `sslmode`, integer. Used for drivers `mysql` and `postgresql`. 0 disable SSL/TLS connections, 1 require ssl, 2 set ssl mode to `verify-ca` for driver `postgresql` and `skip-verify` for driver `mysql`, 3 set ssl mode to `verify-full` for driver `postgresql` and `preferred` for driver `mysql` - - `connectionstring`, string. Provide a custom database connection string. If not empty this connection string will be used instead of build one using the previous parameters. Leave empty for driver `bolt` - - `users_table`, string. Database table for SFTP users - - `manage_users`, integer. Set to 0 to disable users management, 1 to enable - - `track_quota`, integer. Set the preferred way to track users quota between the following choices: - - 0, disable quota tracking. REST API to scan user dir and update quota will do nothing - - 1, quota is updated each time a user upload or delete a file even if the user has no quota restrictions - - 2, quota is updated each time a user upload or delete a file but only for users with quota restrictions. With this configuration the "quota scan" REST API can still be used to periodically update space usage for users without quota restrictions -- **"httpd"**, the configuration for the HTTP server used to serve REST API - - `bind_port`, integer. The port used for serving HTTP requests. Set to 0 to disable HTTP server. Default: 8080 - - `bind_address`, string. Leave blank to listen on all available network interfaces. 
Default: "127.0.0.1" - -Here is a full example showing the default config in JSON format: - -```json -{ - "sftpd": { - "bind_port": 2022, - "bind_address": "", - "idle_timeout": 15, - "max_auth_tries": 0, - "umask": "0022", - "banner": "SFTPGo", - "actions": { - "execute_on": [], - "command": "", - "http_notification_url": "" - }, - "keys": [], - "enable_scp": false - }, - "data_provider": { - "driver": "sqlite", - "name": "sftpgo.db", - "host": "", - "port": 5432, - "username": "", - "password": "", - "sslmode": 0, - "connection_string": "", - "users_table": "users", - "manage_users": 1, - "track_quota": 2 - }, - "httpd": { - "bind_port": 8080, - "bind_address": "127.0.0.1" - } -} -``` - -If you want to use a private key that use an algorithm different from RSA or more than one private key then replace the empty `keys` array with something like this: - -```json -"keys": [ - { - "private_key": "id_rsa" - }, - { - "private_key": "id_ecdsa" - } -] -``` - -The configuration can be read from JSON, TOML, YAML, HCL, envfile and Java properties config files, if your `config-file` flag is set to `sftpgo` (default value) you need to create a configuration file called `sftpgo.json` or `sftpgo.yaml` and so on inside `config-dir`. - -You can also configure all the available options using environment variables, sftpgo will check for environment variables with a name matching the key uppercased and prefixed with the `SFTPGO_`. You need to use `__` to traverse a struct. 
- -Let's see some examples: - -- To set sftpd `bind_port` you need to define the env var `SFTPGO_SFTPD__BIND_PORT` -- To set the `execute_on` actions you need to define the env var `SFTPGO_SFTPD__ACTIONS__EXECUTE_ON` for example `SFTPGO_SFTPD__ACTIONS__EXECUTE_ON=upload,download` - -To start the SFTP Server with the default values for the command line flags simply use: - -```bash -sftpgo serve -``` - -## Account's configuration properties - -For each account the following properties can be configured: - -- `username` -- `password` used for password authentication. For users created using SFTPGo REST API if the password has no known hashing algo prefix it will be stored using argon2id. SFTPGo supports checking passwords stored with bcrypt and pbkdf2 too. For pbkdf2 the supported format is `$$$$`, where algo is `pbkdf2-sha1` or `pbkdf2-sha256` or `pbkdf2-sha512`. For example the `pbkdf2-sha256` of the word `password` using 150000 iterations and `E86a9YMX3zC7` as salt must be stored as `$pbkdf2-sha256$150000$E86a9YMX3zC7$R5J62hsSq+pYw00hLLPKBbcGXmq7fj5+/M0IFoYtZbo=`. For bcrypt the format must be the one supported by golang's [crypto/bcrypt](https://godoc.org/golang.org/x/crypto/bcrypt) package, for example the password `secret` with cost `14` must be stored as `$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK`. Using the REST API you can send a password hashed as bcrypt or pbkdf2 and it will be stored as is. -- `public_keys` array of public keys. At least one public key or the password is mandatory. -- `home_dir` The user cannot upload or download files outside this directory. Must be an absolute path -- `uid`, `gid`. If sftpgo runs as root system user then the created files and directories will be assigned to this system uid/gid. Ignored on windows and if sftpgo runs as non root user: in this case files and directories for all SFTP users will be owned by the system user that runs sftpgo. -- `max_sessions` maximum concurrent sessions. 
0 means unlimited -- `quota_size` maximum size allowed as bytes. 0 means unlimited -- `quota_files` maximum number of files allowed. 0 means unlimited -- `permissions` the following permissions are supported: - - `*` all permission are granted - - `list` list items is allowed - - `download` download files is allowed - - `upload` upload files is allowed - - `delete` delete files or directories is allowed - - `rename` rename files or directories is allowed - - `create_dirs` create directories is allowed - - `create_symlinks` create symbolic links is allowed -- `upload_bandwidth` maximum upload bandwidth as KB/s, 0 means unlimited -- `download_bandwidth` maximum download bandwidth as KB/s, 0 means unlimited - -These properties are stored inside the data provider. If you want to use your existing accounts, you can create a database view. Since a view is read only, you have to disable user management and quota tracking so SFTPGo will never try to write to the view. - -## REST API - -SFTPGo exposes REST API to manage users and quota and to get real time reports for the active connections with possibility of forcibly closing a connection. - -If quota tracking is enabled in `sftpgo` configuration file, then the used size and number of files are updated each time a file is added/removed. If files are added/removed not using SFTP or if you change `track_quota` from `2` to `1`, you can rescan the user home dir and update the used quota using the REST API. - -REST API is designed to run on localhost or on a trusted network, if you need HTTPS or authentication you can setup a reverse proxy using an HTTP Server such as Apache or NGNIX. 
- -For example you can keep SFTPGo listening on localhost and expose it externally configuring a reverse proxy using Apache HTTP Server this way: - -``` -ProxyPass /api/v1 http://127.0.0.1:8080/api/v1 -ProxyPassReverse /api/v1 http://127.0.0.1:8080/api/v1 -``` - -and you can add authentication with something like this: - -``` - - AuthType Digest - AuthName "Private" - AuthDigestDomain "/api/v1" - AuthDigestProvider file - AuthUserFile "/etc/httpd/conf/auth_digest" - Require valid-user - -``` - -and, of course, you can configure the web server to use HTTPS. - -The OpenAPI 3 schema for the exposed API can be found inside the source tree: [openapi.yaml](https://github.com/drakkan/sftpgo/tree/master/api/schema/openapi.yaml "OpenAPI 3 specs"). - -A sample CLI client for the REST API can be found inside the source tree [scripts](https://github.com/drakkan/sftpgo/tree/master/scripts "scripts") directory. - -You can also generate your own REST client, in your preferred programming language or even bash scripts, using an OpenAPI generator such as [swagger-codegen](https://github.com/swagger-api/swagger-codegen) or [OpenAPI Generator](https://openapi-generator.tech/) - -## Logs - -Inside the log file each line is a JSON struct, each struct has a `sender` fields that identify the log type. - -The logs can be divided into the following categories: - -- **"app logs"**, internal logs used to debug `sftpgo`: - - `sender` string. This is generally the package name that emits the log - - `time` string. Date/time with millisecond precision - - `level` string - - `message` string -- **"transfer logs"**, SFTP/SCP transfer logs: - - `sender` string. `Upload` or `Download` - - `time` string. Date/time with millisecond precision - - `level` string - - `elapsed_ms`, int64. Elapsed time, as milliseconds, for the upload/download - - `size_bytes`, int64. Size, as bytes, of the download/upload - - `username`, string - - `file_path` string - - `connection_id` string. 
Unique connection identifier - - `protocol` string. `SFTP` or `SCP` -- **"command logs"**, SFTP/SCP command logs: - - `sender` string. `Rename`, `Rmdir`, `Mkdir`, `Symlink`, `Remove` - - `level` string - - `username`, string - - `file_path` string - - `target_path` string - - `connection_id` string. Unique connection identifier - - `protocol` string. `SFTP` or `SCP` -- **"http logs"**, REST API logs: - - `sender` string. `httpd` - - `level` string - - `remote_addr` string. IP and port of the remote client - - `proto` string, for example `HTTP/1.1` - - `method` string. HTTP method (`GET`, `POST`, `PUT`, `DELETE` etc.) - - `user_agent` string - - `uri` string. Full uri - - `resp_status` integer. HTTP response status code - - `resp_size` integer. Size in bytes of the HTTP response - - `elapsed_ms` int64. Elapsed time, as milliseconds, to complete the request - - `request_id` string. Unique request identifier +**Note**: We are committed to keeping the Open Source edition powerful and maintained. The Enterprise edition helps fund the development of the entire SFTPGo ecosystem. + +## Sponsors + +If you rely on SFTPGo in your projects, consider becoming a [sponsor](https://github.com/sponsors/drakkan). + +Your sponsorship helps cover maintenance, security updates and ongoing development of the open-source edition. + +### Thank you to our sponsors + +#### Platinum sponsors + +[Aledade logo](https://www.aledade.com/) +

+[Jump Trading logo](https://www.jumptrading.com/) +

+[WP Engine logo](https://wpengine.com/) + +#### Silver sponsors + +[IDCS logo](https://idcs.ip-paris.fr/) + +#### Bronze sponsors + +[7digital logo](https://www.7digital.com/) +

+[servinga logo](https://servinga.com/) +

+[ReUI logo](https://www.reui.io/) + +## Documentation + +You can explore all supported features and configuration options at [docs.sftpgo.com](https://docs.sftpgo.com/latest/). + +**Note:** The link above refers to the **Community Edition**. +For details on **Enterprise Edition**, please refer to the [Enterprise Documentation](https://docs.sftpgo.com/enterprise/). + +## Support + +- **Community Support**: use [GitHub Discussions](https://github.com/drakkan/sftpgo/discussions) to ask questions, share feedback, and engage with other users. +- **Commercial Support**: If you require guaranteed SLAs, expert guidance, or the advanced features listed above, check out [SFTPGo Enterprise](https://sftpgo.com). + +SFTPGo Enterprise is available as: + +- On-premises: Full control on your infrastructure. More details: [sftpgo.com/on-premises](https://sftpgo.com/on-premises) +- Fully managed SaaS: We handle the infrastructure. More details: [sftpgo.com/saas](https://sftpgo.com/saas) + +## Internationalization + +The translations are available via [Crowdin](https://crowdin.com/project/sftpgo), who have granted us an open source license. + +Before translating please take a look at our contribution [guidelines](https://docs.sftpgo.com/latest/web-interfaces/#internationalization). + +## Release Cadence + +SFTPGo follows a feature-driven release cycle. + +- Enterprise Edition: Receives major new features first and follows a faster [release cadence](https://docs.sftpgo.com/enterprise/changelog/). +- Community Edition: Remains maintained, receiving bug fixes, security updates, and updates to core features. 
## Acknowledgements -- [pkg/sftp](https://github.com/pkg/sftp) -- [go-chi](https://github.com/go-chi/chi) -- [zerolog](https://github.com/rs/zerolog) -- [lumberjack](https://gopkg.in/natefinch/lumberjack.v2) -- [argon2id](https://github.com/alexedwards/argon2id) -- [go-sqlite3](https://github.com/mattn/go-sqlite3) -- [go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) -- [bbolt](https://github.com/etcd-io/bbolt) -- [lib/pq](https://github.com/lib/pq) -- [viper](https://github.com/spf13/viper) -- [cobra](https://github.com/spf13/cobra) -- [xid](https://github.com/rs/xid) +SFTPGo makes use of the third party libraries listed inside [go.mod](./go.mod). -Some code was initially taken from [Pterodactyl sftp server](https://github.com/pterodactyl/sftp-server) +We are very grateful to all the people who contributed with ideas and/or pull requests. + +Thank you to [ysura](https://www.ysura.com/) for granting us stable access to a test AWS S3 account. + +Thank you to [KeenThemes](https://keenthemes.com/) for granting us a custom license to use their amazing [themes](https://keenthemes.com/bootstrap-templates) for the SFTPGo WebAdmin and WebClient user interfaces, across both the Open Source and Open Core versions. + +Thank you to [Crowdin](https://crowdin.com/) for granting us an Open Source License. + +Thank you to [Incode](https://www.incode.it/) for helping us to improve the UI/UX. ## License -GNU GPLv3 +SFTPGo source code is licensed under the GNU AGPL-3.0-only with [additional terms](./NOTICE). + +The [theme](https://keenthemes.com/bootstrap-templates) used in WebAdmin and WebClient user interfaces is proprietary, this means: + +- KeenThemes HTML/CSS/JS components are allowed for use only within the SFTPGo product and restricted to be used in a resealable HTML template that can compete with KeenThemes products anyhow. 
+- The SFTPGo WebAdmin and WebClient user interfaces (HTML, CSS and JS components) based on this theme are allowed for use only within the SFTPGo product and therefore cannot be used in derivative works/products without an explicit grant from the [SFTPGo Team](mailto:support@sftpgo.com). + +More information about [compliance](https://sftpgo.com/compliance.html). + +**Note:** We do not provide legal advice. If you have questions about license compliance or whether your use case is permitted under the license terms, please consult your legal team. + +## Copyright + +Copyright (C) 2019 - 2026 Nicola Murino diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..01650ca6 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,20 @@ +# Security Policy + +## Supported Versions + +We actively maintain the latest stable release of SFTPGo. While we strive to keep the Open Source version secure and up-to-date, maintenance is performed on a best-effort basis by the community and contributors. + +## Scope and Dependency Policy + +Our security advisories focus on vulnerabilities found within the **SFTPGo codebase itself**. + +To ensure the long-term sustainability of the project, we handle upstream dependencies (like the Go standard library, external packages, or Docker base images) as follows: + +- Community Updates: For the Open Source version, vulnerabilities in upstream components (such as the Go standard library or third-party packages) are addressed during our **regular release cycles**. We generally do not provide immediate, out-of-band or ad-hoc releases to address dependency-only CVEs. +- Empowering Users: One of the strengths of SFTPGo being open-source is that you have full control. If your security scanners require an immediate fix, you can always rebuild the project using the latest patched Go toolchain or updated dependencies. +- Compatibility: We are committed to keeping SFTPGo compatible with the latest stable Go compiler. 
If an upstream fix breaks SFTPGo, fixing that becomes a priority for us. +- Professional Needs: We understand that some organizations have strict compliance requirements or internal SLAs that require guaranteed, immediate response times and out-of-band patches. For these cases, we offer [SFTPGo Enterprise](https://sftpgo.com/on-premises) to cover the additional maintenance and support overhead. + +## Reporting a Vulnerability + +To report (possible) security issues in SFTPGo, please either send a mail to the [SFTPGo Team](mailto:support@sftpgo.com) or use GitHub's [private reporting feature](https://github.com/drakkan/sftpgo/security/advisories/new). diff --git a/api/api.go b/api/api.go deleted file mode 100644 index 2b280edc..00000000 --- a/api/api.go +++ /dev/null @@ -1,77 +0,0 @@ -// Package api implements REST API for sftpgo. -// REST API allows to manage users and quota and to get real time reports for the active connections -// with possibility of forcibly closing a connection. -// The OpenAPI 3 schema for the exposed API can be found inside the source tree: -// https://github.com/drakkan/sftpgo/tree/master/api/schema/openapi.yaml -package api - -import ( - "net/http" - - "github.com/drakkan/sftpgo/dataprovider" - "github.com/go-chi/chi" - "github.com/go-chi/render" -) - -const ( - logSender = "api" - activeConnectionsPath = "/api/v1/connection" - quotaScanPath = "/api/v1/quota_scan" - userPath = "/api/v1/user" - versionPath = "/api/v1/version" -) - -var ( - router *chi.Mux - dataProvider dataprovider.Provider -) - -// HTTPDConf httpd daemon configuration -type HTTPDConf struct { - // The port used for serving HTTP requests. 0 disable the HTTP server. Default: 8080 - BindPort int `json:"bind_port" mapstructure:"bind_port"` - // The address to listen on. A blank value means listen on all available network interfaces. 
Default: "127.0.0.1" - BindAddress string `json:"bind_address" mapstructure:"bind_address"` -} - -type apiResponse struct { - Error string `json:"error"` - Message string `json:"message"` - HTTPStatus int `json:"status"` -} - -func init() { - initializeRouter() -} - -// SetDataProvider sets the data provider to use to fetch the data about users -func SetDataProvider(provider dataprovider.Provider) { - dataProvider = provider -} - -func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) { - var errorString string - if err != nil { - errorString = err.Error() - } - resp := apiResponse{ - Error: errorString, - Message: message, - HTTPStatus: code, - } - if code != http.StatusOK { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(code) - } - render.JSON(w, r, resp) -} - -func getRespStatus(err error) int { - if _, ok := err.(*dataprovider.ValidationError); ok { - return http.StatusBadRequest - } - if _, ok := err.(*dataprovider.MethodDisabledError); ok { - return http.StatusForbidden - } - return http.StatusInternalServerError -} diff --git a/api/api_test.go b/api/api_test.go deleted file mode 100644 index b26138b0..00000000 --- a/api/api_test.go +++ /dev/null @@ -1,755 +0,0 @@ -package api_test - -import ( - "bytes" - "encoding/json" - "fmt" - "net" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "runtime" - "strconv" - "testing" - "time" - - "github.com/go-chi/render" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" - "github.com/rs/zerolog" - - "github.com/drakkan/sftpgo/api" - "github.com/drakkan/sftpgo/config" - "github.com/drakkan/sftpgo/dataprovider" - "github.com/drakkan/sftpgo/logger" - "github.com/drakkan/sftpgo/sftpd" -) - -const ( - defaultUsername = "test_user" - defaultPassword = "test_password" - testPubKey = "ssh-rsa 
AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" - logSender = "APITesting" - userPath = "/api/v1/user" - activeConnectionsPath = "/api/v1/connection" - quotaScanPath = "/api/v1/quota_scan" - versionPath = "/api/v1/version" -) - -var ( - defaultPerms = []string{dataprovider.PermAny} - homeBasePath string - testServer *httptest.Server -) - -func TestMain(m *testing.M) { - if runtime.GOOS == "windows" { - homeBasePath = "C:\\" - } else { - homeBasePath = "/tmp" - } - configDir := ".." - logfilePath := filepath.Join(configDir, "sftpgo_api_test.log") - logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel) - config.LoadConfig(configDir, "") - providerConf := config.GetProviderConf() - - err := dataprovider.Initialize(providerConf, configDir) - if err != nil { - logger.Warn(logSender, "error initializing data provider: %v", err) - os.Exit(1) - } - dataProvider := dataprovider.GetProvider() - httpdConf := config.GetHTTPDConfig() - router := api.GetHTTPRouter() - - httpdConf.BindPort = 8081 - api.SetBaseURL("http://127.0.0.1:8081") - - sftpd.SetDataProvider(dataProvider) - api.SetDataProvider(dataProvider) - - go func() { - logger.Debug(logSender, "initializing HTTP server with config %+v", httpdConf) - s := &http.Server{ - Addr: fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort), - Handler: router, - ReadTimeout: 300 * time.Second, - WriteTimeout: 300 * time.Second, - MaxHeaderBytes: 1 << 20, // 1MB - } - if err := s.ListenAndServe(); err != nil { 
- logger.Error(logSender, "could not start HTTP server: %v", err) - } - }() - - testServer = httptest.NewServer(api.GetHTTPRouter()) - defer testServer.Close() - - waitTCPListening(fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort)) - - exitCode := m.Run() - os.Remove(logfilePath) - os.Exit(exitCode) -} - -func TestBasicUserHandling(t *testing.T) { - user, _, err := api.AddUser(getTestUser(), http.StatusOK) - if err != nil { - t.Errorf("unable to add user: %v", err) - } - user.MaxSessions = 10 - user.QuotaSize = 4096 - user.QuotaFiles = 2 - user.UploadBandwidth = 128 - user.DownloadBandwidth = 64 - user, _, err = api.UpdateUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to update user: %v", err) - } - users, _, err := api.GetUsers(0, 0, defaultUsername, http.StatusOK) - if err != nil { - t.Errorf("unable to get users: %v", err) - } - if len(users) != 1 { - t.Errorf("number of users mismatch, expected: 1, actual: %v", len(users)) - } - _, err = api.RemoveUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to remove: %v", err) - } -} - -func TestAddUserNoCredentials(t *testing.T) { - u := getTestUser() - u.Password = "" - u.PublicKeys = []string{} - _, _, err := api.AddUser(u, http.StatusBadRequest) - if err != nil { - t.Errorf("unexpected error adding user with no credentials: %v", err) - } -} - -func TestAddUserNoUsername(t *testing.T) { - u := getTestUser() - u.Username = "" - _, _, err := api.AddUser(u, http.StatusBadRequest) - if err != nil { - t.Errorf("unexpected error adding user with no home dir: %v", err) - } -} - -func TestAddUserNoHomeDir(t *testing.T) { - u := getTestUser() - u.HomeDir = "" - _, _, err := api.AddUser(u, http.StatusBadRequest) - if err != nil { - t.Errorf("unexpected error adding user with no home dir: %v", err) - } -} - -func TestAddUserInvalidHomeDir(t *testing.T) { - u := getTestUser() - u.HomeDir = "relative_path" - _, _, err := api.AddUser(u, http.StatusBadRequest) - if err != nil { - 
t.Errorf("unexpected error adding user with invalid home dir: %v", err) - } -} - -func TestAddUserNoPerms(t *testing.T) { - u := getTestUser() - u.Permissions = []string{} - _, _, err := api.AddUser(u, http.StatusBadRequest) - if err != nil { - t.Errorf("unexpected error adding user with no perms: %v", err) - } -} - -func TestAddUserInvalidPerms(t *testing.T) { - u := getTestUser() - u.Permissions = []string{"invalidPerm"} - _, _, err := api.AddUser(u, http.StatusBadRequest) - if err != nil { - t.Errorf("unexpected error adding user with no perms: %v", err) - } -} - -func TestUserPublicKey(t *testing.T) { - u := getTestUser() - invalidPubKey := "invalid" - validPubKey := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" - u.PublicKeys = []string{invalidPubKey} - _, _, err := api.AddUser(u, http.StatusBadRequest) - if err != nil { - t.Errorf("unexpected error adding user with invalid pub key: %v", err) - } - u.PublicKeys = []string{validPubKey} - user, _, err := api.AddUser(u, http.StatusOK) - if err != nil { - t.Errorf("unable to add user: %v", err) - } - user.PublicKeys = []string{validPubKey, invalidPubKey} - _, _, err = api.UpdateUser(user, http.StatusBadRequest) - if err != nil { - t.Errorf("update user with invalid public key must fail: %v", err) - } - user.PublicKeys = []string{validPubKey, validPubKey, validPubKey} - _, _, err = api.UpdateUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to update user: %v", err) - } - _, err = 
api.RemoveUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to remove: %v", err) - } -} - -func TestUpdateUser(t *testing.T) { - user, _, err := api.AddUser(getTestUser(), http.StatusOK) - if err != nil { - t.Errorf("unable to add user: %v", err) - } - user.HomeDir = filepath.Join(homeBasePath, "testmod") - user.UID = 33 - user.GID = 101 - user.MaxSessions = 10 - user.QuotaSize = 4096 - user.QuotaFiles = 2 - user.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermDelete, dataprovider.PermDownload} - user.UploadBandwidth = 1024 - user.DownloadBandwidth = 512 - user, _, err = api.UpdateUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to update user: %v", err) - } - _, err = api.RemoveUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to remove: %v", err) - } -} - -func TestUpdateUserNoCredentials(t *testing.T) { - user, _, err := api.AddUser(getTestUser(), http.StatusOK) - if err != nil { - t.Errorf("unable to add user: %v", err) - } - user.Password = "" - user.PublicKeys = []string{} - // password and public key will be omitted from json serialization if empty and so they will remain unchanged - // and no validation error will be raised - _, _, err = api.UpdateUser(user, http.StatusOK) - if err != nil { - t.Errorf("unexpected error updating user with no credentials: %v", err) - } - _, err = api.RemoveUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to remove: %v", err) - } -} - -func TestUpdateUserEmptyHomeDir(t *testing.T) { - user, _, err := api.AddUser(getTestUser(), http.StatusOK) - if err != nil { - t.Errorf("unable to add user: %v", err) - } - user.HomeDir = "" - _, _, err = api.UpdateUser(user, http.StatusBadRequest) - if err != nil { - t.Errorf("unexpected error updating user with empty home dir: %v", err) - } - _, err = api.RemoveUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to remove: %v", err) - } -} - -func TestUpdateUserInvalidHomeDir(t *testing.T) { - user, _, err 
:= api.AddUser(getTestUser(), http.StatusOK) - if err != nil { - t.Errorf("unable to add user: %v", err) - } - user.HomeDir = "relative_path" - _, _, err = api.UpdateUser(user, http.StatusBadRequest) - if err != nil { - t.Errorf("unexpected error updating user with empty home dir: %v", err) - } - _, err = api.RemoveUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to remove: %v", err) - } -} - -func TestUpdateNonExistentUser(t *testing.T) { - _, _, err := api.UpdateUser(getTestUser(), http.StatusNotFound) - if err != nil { - t.Errorf("unable to update user: %v", err) - } -} - -func TestGetNonExistentUser(t *testing.T) { - _, _, err := api.GetUserByID(0, http.StatusNotFound) - if err != nil { - t.Errorf("unable to get user: %v", err) - } -} - -func TestDeleteNonExistentUser(t *testing.T) { - _, err := api.RemoveUser(getTestUser(), http.StatusNotFound) - if err != nil { - t.Errorf("unable to remove user: %v", err) - } -} - -func TestAddDuplicateUser(t *testing.T) { - user, _, err := api.AddUser(getTestUser(), http.StatusOK) - if err != nil { - t.Errorf("unable to add user: %v", err) - } - _, _, err = api.AddUser(getTestUser(), http.StatusInternalServerError) - if err != nil { - t.Errorf("unable to add second user: %v", err) - } - _, _, err = api.AddUser(getTestUser(), http.StatusOK) - if err == nil { - t.Errorf("adding a duplicate user must fail") - } - _, err = api.RemoveUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to remove user: %v", err) - } -} - -func TestGetUsers(t *testing.T) { - user1, _, err := api.AddUser(getTestUser(), http.StatusOK) - if err != nil { - t.Errorf("unable to add user: %v", err) - } - u := getTestUser() - u.Username = defaultUsername + "1" - user2, _, err := api.AddUser(u, http.StatusOK) - if err != nil { - t.Errorf("unable to add second user: %v", err) - } - users, _, err := api.GetUsers(0, 0, "", http.StatusOK) - if err != nil { - t.Errorf("unable to get users: %v", err) - } - if len(users) < 2 { - 
t.Errorf("at least 2 users are expected") - } - users, _, err = api.GetUsers(1, 0, "", http.StatusOK) - if err != nil { - t.Errorf("unable to get users: %v", err) - } - if len(users) != 1 { - t.Errorf("1 user is expected") - } - users, _, err = api.GetUsers(1, 1, "", http.StatusOK) - if err != nil { - t.Errorf("unable to get users: %v", err) - } - if len(users) != 1 { - t.Errorf("1 user is expected") - } - _, _, err = api.GetUsers(1, 1, "", http.StatusInternalServerError) - if err == nil { - t.Errorf("get users must succeed, we requested a fail for a good request") - } - _, err = api.RemoveUser(user1, http.StatusOK) - if err != nil { - t.Errorf("unable to remove user: %v", err) - } - _, err = api.RemoveUser(user2, http.StatusOK) - if err != nil { - t.Errorf("unable to remove user: %v", err) - } -} - -func TestGetQuotaScans(t *testing.T) { - _, _, err := api.GetQuotaScans(http.StatusOK) - if err != nil { - t.Errorf("unable to get quota scans: %v", err) - } - _, _, err = api.GetQuotaScans(http.StatusInternalServerError) - if err == nil { - t.Errorf("quota scan request must succeed, we requested to check a wrong status code") - } -} - -func TestStartQuotaScan(t *testing.T) { - user, _, err := api.AddUser(getTestUser(), http.StatusOK) - if err != nil { - t.Errorf("unable to add user: %v", err) - } - _, err = api.StartQuotaScan(user, http.StatusCreated) - if err != nil { - t.Errorf("unable to start quota scan: %v", err) - } - _, err = api.RemoveUser(user, http.StatusOK) - if err != nil { - t.Errorf("unable to remove user: %v", err) - } -} - -func TestGetVersion(t *testing.T) { - _, _, err := api.GetVersion(http.StatusOK) - if err != nil { - t.Errorf("unable to get sftp version: %v", err) - } - _, _, err = api.GetVersion(http.StatusInternalServerError) - if err == nil { - t.Errorf("get version request must succeed, we requested to check a wrong status code") - } -} - -func TestGetConnections(t *testing.T) { - _, _, err := api.GetConnections(http.StatusOK) - if err != nil 
{ - t.Errorf("unable to get sftp connections: %v", err) - } - _, _, err = api.GetConnections(http.StatusInternalServerError) - if err == nil { - t.Errorf("get sftp connections request must succeed, we requested to check a wrong status code") - } -} - -func TestCloseActiveConnection(t *testing.T) { - _, err := api.CloseConnection("non_existent_id", http.StatusNotFound) - if err != nil { - t.Errorf("unexpected error closing non existent sftp connection: %v", err) - } -} - -// test using mock http server - -func TestBasicUserHandlingMock(t *testing.T) { - user := getTestUser() - userAsJSON := getUserAsJSON(t, user) - req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) - rr := executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - err := render.DecodeJSON(rr.Body, &user) - if err != nil { - t.Errorf("Error get user: %v", err) - } - req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) - rr = executeRequest(req) - checkResponseCode(t, http.StatusInternalServerError, rr.Code) - user.MaxSessions = 10 - user.UploadBandwidth = 128 - userAsJSON = getUserAsJSON(t, user) - req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON)) - rr = executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - - req, _ = http.NewRequest(http.MethodGet, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - - var updatedUser dataprovider.User - err = render.DecodeJSON(rr.Body, &updatedUser) - if err != nil { - t.Errorf("Error decoding updated user: %v", err) - } - if user.MaxSessions != updatedUser.MaxSessions || user.UploadBandwidth != updatedUser.UploadBandwidth { - t.Errorf("Error modifying user actual: %v, %v", updatedUser.MaxSessions, updatedUser.UploadBandwidth) - } - req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) - rr = 
executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) -} - -func TestGetUserByIdInvalidParamsMock(t *testing.T) { - req, _ := http.NewRequest(http.MethodGet, userPath+"/0", nil) - rr := executeRequest(req) - checkResponseCode(t, http.StatusNotFound, rr.Code) - req, _ = http.NewRequest(http.MethodGet, userPath+"/a", nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) -} - -func TestAddUserNoUsernameMock(t *testing.T) { - user := getTestUser() - user.Username = "" - userAsJSON := getUserAsJSON(t, user) - req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) - rr := executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) -} - -func TestAddUserInvalidHomeDirMock(t *testing.T) { - user := getTestUser() - user.HomeDir = "relative_path" - userAsJSON := getUserAsJSON(t, user) - req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) - rr := executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) -} - -func TestAddUserInvalidPermsMock(t *testing.T) { - user := getTestUser() - user.Permissions = []string{} - userAsJSON := getUserAsJSON(t, user) - req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) - rr := executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) -} - -func TestAddUserInvalidJsonMock(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer([]byte("invalid json"))) - rr := executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) -} - -func TestUpdateUserInvalidJsonMock(t *testing.T) { - user := getTestUser() - userAsJSON := getUserAsJSON(t, user) - req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) - rr := executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - err := render.DecodeJSON(rr.Body, &user) - if err != nil { - t.Errorf("Error get user: %v", err) - } - req, _ = 
http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer([]byte("Invalid json"))) - rr = executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) - req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) -} - -func TestUpdateUserInvalidParamsMock(t *testing.T) { - user := getTestUser() - userAsJSON := getUserAsJSON(t, user) - req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) - rr := executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - err := render.DecodeJSON(rr.Body, &user) - if err != nil { - t.Errorf("Error get user: %v", err) - } - user.HomeDir = "" - userAsJSON = getUserAsJSON(t, user) - req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(user.ID, 10), bytes.NewBuffer(userAsJSON)) - rr = executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) - userID := user.ID - user.ID = 0 - userAsJSON = getUserAsJSON(t, user) - req, _ = http.NewRequest(http.MethodPut, userPath+"/"+strconv.FormatInt(userID, 10), bytes.NewBuffer(userAsJSON)) - rr = executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) - user.ID = userID - req, _ = http.NewRequest(http.MethodPut, userPath+"/0", bytes.NewBuffer(userAsJSON)) - rr = executeRequest(req) - checkResponseCode(t, http.StatusNotFound, rr.Code) - req, _ = http.NewRequest(http.MethodPut, userPath+"/a", bytes.NewBuffer(userAsJSON)) - rr = executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) - req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) -} - -func TestGetUsersMock(t *testing.T) { - user := getTestUser() - userAsJSON := getUserAsJSON(t, user) - req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) - 
rr := executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - err := render.DecodeJSON(rr.Body, &user) - if err != nil { - t.Errorf("Error get user: %v", err) - } - req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=510&offset=0&order=ASC&username="+defaultUsername, nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - var users []dataprovider.User - err = render.DecodeJSON(rr.Body, &users) - if err != nil { - t.Errorf("Error decoding users: %v", err) - } - if len(users) != 1 { - t.Errorf("1 user is expected") - } - req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=a&offset=0&order=ASC", nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) - req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=a&order=ASC", nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) - req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=0&order=ASCa", nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) - - req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) -} - -func TestDeleteUserInvalidParamsMock(t *testing.T) { - req, _ := http.NewRequest(http.MethodDelete, userPath+"/0", nil) - rr := executeRequest(req) - checkResponseCode(t, http.StatusNotFound, rr.Code) - req, _ = http.NewRequest(http.MethodDelete, userPath+"/a", nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) -} - -func TestGetQuotaScansMock(t *testing.T) { - req, err := http.NewRequest("GET", quotaScanPath, nil) - if err != nil { - t.Errorf("error get quota scan: %v", err) - } - rr := executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) -} - -func TestStartQuotaScanMock(t *testing.T) { - user := getTestUser() - userAsJSON := getUserAsJSON(t, user) - req, _ := 
http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) - rr := executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - err := render.DecodeJSON(rr.Body, &user) - if err != nil { - t.Errorf("Error get user: %v", err) - } - _, err = os.Stat(user.HomeDir) - if err == nil { - os.Remove(user.HomeDir) - } - // simulate a duplicate quota scan - userAsJSON = getUserAsJSON(t, user) - sftpd.AddQuotaScan(user.Username) - req, _ = http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON)) - rr = executeRequest(req) - checkResponseCode(t, http.StatusConflict, rr.Code) - sftpd.RemoveQuotaScan(user.Username) - - userAsJSON = getUserAsJSON(t, user) - req, _ = http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON)) - rr = executeRequest(req) - checkResponseCode(t, http.StatusCreated, rr.Code) - - req, _ = http.NewRequest(http.MethodGet, quotaScanPath, nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - var scans []sftpd.ActiveQuotaScan - err = render.DecodeJSON(rr.Body, &scans) - if err != nil { - t.Errorf("Error get active scans: %v", err) - } - for len(scans) > 0 { - req, _ = http.NewRequest(http.MethodGet, quotaScanPath, nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) - err = render.DecodeJSON(rr.Body, &scans) - if err != nil { - t.Errorf("Error get active scans: %v", err) - break - } - } - _, err = os.Stat(user.HomeDir) - if err != nil && os.IsNotExist(err) { - os.MkdirAll(user.HomeDir, 0777) - } - req, _ = http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON)) - rr = executeRequest(req) - checkResponseCode(t, http.StatusCreated, rr.Code) - req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+strconv.FormatInt(user.ID, 10), nil) - rr = executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) -} - -func TestStartQuotaScanBadUserMock(t *testing.T) { - user := getTestUser() - userAsJSON := getUserAsJSON(t, user) - req, 
_ := http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer(userAsJSON)) - rr := executeRequest(req) - checkResponseCode(t, http.StatusNotFound, rr.Code) -} - -func TestStartQuotaScanNonExistentUserMock(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, quotaScanPath, bytes.NewBuffer([]byte("invalid json"))) - rr := executeRequest(req) - checkResponseCode(t, http.StatusBadRequest, rr.Code) -} - -func TestGetVersionMock(t *testing.T) { - req, _ := http.NewRequest(http.MethodGet, versionPath, nil) - rr := executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) -} - -func TestGetConnectionsMock(t *testing.T) { - req, _ := http.NewRequest(http.MethodGet, activeConnectionsPath, nil) - rr := executeRequest(req) - checkResponseCode(t, http.StatusOK, rr.Code) -} - -func TestDeleteActiveConnectionMock(t *testing.T) { - req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil) - rr := executeRequest(req) - checkResponseCode(t, http.StatusNotFound, rr.Code) -} - -func TestNotFoundMock(t *testing.T) { - req, _ := http.NewRequest(http.MethodGet, "/non/existing/path", nil) - rr := executeRequest(req) - checkResponseCode(t, http.StatusNotFound, rr.Code) -} - -func TestMethodNotAllowedMock(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, activeConnectionsPath, nil) - rr := executeRequest(req) - checkResponseCode(t, http.StatusMethodNotAllowed, rr.Code) -} - -func waitTCPListening(address string) { - for { - conn, err := net.Dial("tcp", address) - if err != nil { - logger.WarnToConsole("tcp server %v not listening: %v\n", address, err) - time.Sleep(100 * time.Millisecond) - continue - } - logger.InfoToConsole("tcp server %v now listening\n", address) - defer conn.Close() - break - } -} - -func getTestUser() dataprovider.User { - return dataprovider.User{ - Username: defaultUsername, - Password: defaultPassword, - HomeDir: filepath.Join(homeBasePath, defaultUsername), - Permissions: defaultPerms, - } -} - -func 
getUserAsJSON(t *testing.T, user dataprovider.User) []byte { - json, err := json.Marshal(user) - if err != nil { - t.Errorf("error get user as json: %v", err) - return []byte("{}") - } - return json -} - -func executeRequest(req *http.Request) *httptest.ResponseRecorder { - rr := httptest.NewRecorder() - testServer.Config.Handler.ServeHTTP(rr, req) - return rr -} - -func checkResponseCode(t *testing.T, expected, actual int) { - if expected != actual { - t.Errorf("Expected response code %d. Got %d", expected, actual) - } -} diff --git a/api/api_utils.go b/api/api_utils.go deleted file mode 100644 index 14bba528..00000000 --- a/api/api_utils.go +++ /dev/null @@ -1,330 +0,0 @@ -package api - -import ( - "bytes" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/url" - "path" - "strconv" - "strings" - "time" - - "github.com/drakkan/sftpgo/dataprovider" - "github.com/drakkan/sftpgo/sftpd" - "github.com/drakkan/sftpgo/utils" - "github.com/go-chi/render" -) - -var ( - httpBaseURL = "http://127.0.0.1:8080" -) - -// SetBaseURL sets the base url to use for HTTP requests, default is "http://127.0.0.1:8080" -func SetBaseURL(url string) { - httpBaseURL = url -} - -// gets an HTTP Client with a timeout -func getHTTPClient() *http.Client { - return &http.Client{ - Timeout: 15 * time.Second, - } -} - -func buildURLRelativeToBase(paths ...string) string { - // we need to use path.Join and not filepath.Join - // since filepath.Join will use backslash separator on Windows - p := path.Join(paths...) - return fmt.Sprintf("%s/%s", strings.TrimRight(httpBaseURL, "/"), strings.TrimLeft(p, "/")) -} - -// AddUser adds a new user and checks the received HTTP Status code against expectedStatusCode. 
-func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) { - var newUser dataprovider.User - var body []byte - userAsJSON, err := json.Marshal(user) - if err != nil { - return newUser, body, err - } - resp, err := getHTTPClient().Post(buildURLRelativeToBase(userPath), "application/json", bytes.NewBuffer(userAsJSON)) - if err != nil { - return newUser, body, err - } - defer resp.Body.Close() - err = checkResponse(resp.StatusCode, expectedStatusCode) - if expectedStatusCode != http.StatusOK { - body, _ = getResponseBody(resp) - return newUser, body, err - } - if err == nil { - err = render.DecodeJSON(resp.Body, &newUser) - } else { - body, _ = getResponseBody(resp) - } - if err == nil { - err = checkUser(user, newUser) - } - return newUser, body, err -} - -// UpdateUser updates an existing user and checks the received HTTP Status code against expectedStatusCode. -func UpdateUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) { - var newUser dataprovider.User - var body []byte - userAsJSON, err := json.Marshal(user) - if err != nil { - return user, body, err - } - req, err := http.NewRequest(http.MethodPut, buildURLRelativeToBase(userPath, strconv.FormatInt(user.ID, 10)), - bytes.NewBuffer(userAsJSON)) - if err != nil { - return user, body, err - } - resp, err := getHTTPClient().Do(req) - if err != nil { - return user, body, err - } - defer resp.Body.Close() - body, _ = getResponseBody(resp) - err = checkResponse(resp.StatusCode, expectedStatusCode) - if expectedStatusCode != http.StatusOK { - return newUser, body, err - } - if err == nil { - newUser, body, err = GetUserByID(user.ID, expectedStatusCode) - } - if err == nil { - err = checkUser(user, newUser) - } - return newUser, body, err -} - -// RemoveUser removes an existing user and checks the received HTTP Status code against expectedStatusCode. 
-func RemoveUser(user dataprovider.User, expectedStatusCode int) ([]byte, error) { - var body []byte - req, err := http.NewRequest(http.MethodDelete, buildURLRelativeToBase(userPath, strconv.FormatInt(user.ID, 10)), nil) - if err != nil { - return body, err - } - resp, err := getHTTPClient().Do(req) - if err != nil { - return body, err - } - defer resp.Body.Close() - body, _ = getResponseBody(resp) - return body, checkResponse(resp.StatusCode, expectedStatusCode) -} - -// GetUserByID gets an user by database id and checks the received HTTP Status code against expectedStatusCode. -func GetUserByID(userID int64, expectedStatusCode int) (dataprovider.User, []byte, error) { - var user dataprovider.User - var body []byte - resp, err := getHTTPClient().Get(buildURLRelativeToBase(userPath, strconv.FormatInt(userID, 10))) - if err != nil { - return user, body, err - } - defer resp.Body.Close() - err = checkResponse(resp.StatusCode, expectedStatusCode) - if err == nil && expectedStatusCode == http.StatusOK { - err = render.DecodeJSON(resp.Body, &user) - } else { - body, _ = getResponseBody(resp) - } - return user, body, err -} - -// GetUsers allows to get a list of users and checks the received HTTP Status code against expectedStatusCode. -// The number of results can be limited specifying a limit. -// Some results can be skipped specifying an offset. 
-// The results can be filtered specifying an username, the username filter is an exact match -func GetUsers(limit int64, offset int64, username string, expectedStatusCode int) ([]dataprovider.User, []byte, error) { - var users []dataprovider.User - var body []byte - url, err := url.Parse(buildURLRelativeToBase(userPath)) - if err != nil { - return users, body, err - } - q := url.Query() - if limit > 0 { - q.Add("limit", strconv.FormatInt(limit, 10)) - } - if offset > 0 { - q.Add("offset", strconv.FormatInt(offset, 10)) - } - if len(username) > 0 { - q.Add("username", username) - } - url.RawQuery = q.Encode() - resp, err := getHTTPClient().Get(url.String()) - if err != nil { - return users, body, err - } - defer resp.Body.Close() - err = checkResponse(resp.StatusCode, expectedStatusCode) - if err == nil && expectedStatusCode == http.StatusOK { - err = render.DecodeJSON(resp.Body, &users) - } else { - body, _ = getResponseBody(resp) - } - return users, body, err -} - -// GetQuotaScans gets active quota scans and checks the received HTTP Status code against expectedStatusCode. -func GetQuotaScans(expectedStatusCode int) ([]sftpd.ActiveQuotaScan, []byte, error) { - var quotaScans []sftpd.ActiveQuotaScan - var body []byte - resp, err := getHTTPClient().Get(buildURLRelativeToBase(quotaScanPath)) - if err != nil { - return quotaScans, body, err - } - defer resp.Body.Close() - err = checkResponse(resp.StatusCode, expectedStatusCode) - if err == nil && expectedStatusCode == http.StatusOK { - err = render.DecodeJSON(resp.Body, "aScans) - } else { - body, _ = getResponseBody(resp) - } - return quotaScans, body, err -} - -// StartQuotaScan start a new quota scan for the given user and checks the received HTTP Status code against expectedStatusCode. 
-func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, error) { - var body []byte - userAsJSON, err := json.Marshal(user) - if err != nil { - return body, err - } - resp, err := getHTTPClient().Post(buildURLRelativeToBase(quotaScanPath), "application/json", bytes.NewBuffer(userAsJSON)) - if err != nil { - return body, err - } - defer resp.Body.Close() - body, _ = getResponseBody(resp) - return body, checkResponse(resp.StatusCode, expectedStatusCode) -} - -// GetConnections returns status and stats for active SFTP/SCP connections -func GetConnections(expectedStatusCode int) ([]sftpd.ConnectionStatus, []byte, error) { - var connections []sftpd.ConnectionStatus - var body []byte - resp, err := getHTTPClient().Get(buildURLRelativeToBase(activeConnectionsPath)) - if err != nil { - return connections, body, err - } - defer resp.Body.Close() - err = checkResponse(resp.StatusCode, expectedStatusCode) - if err == nil && expectedStatusCode == http.StatusOK { - err = render.DecodeJSON(resp.Body, &connections) - } else { - body, _ = getResponseBody(resp) - } - return connections, body, err -} - -// CloseConnection closes an active connection identified by connectionID -func CloseConnection(connectionID string, expectedStatusCode int) ([]byte, error) { - var body []byte - req, err := http.NewRequest(http.MethodDelete, buildURLRelativeToBase(activeConnectionsPath, connectionID), nil) - if err != nil { - return body, err - } - resp, err := getHTTPClient().Do(req) - if err != nil { - return body, err - } - defer resp.Body.Close() - err = checkResponse(resp.StatusCode, expectedStatusCode) - body, _ = getResponseBody(resp) - return body, err -} - -// GetVersion returns version details -func GetVersion(expectedStatusCode int) (utils.VersionInfo, []byte, error) { - var version utils.VersionInfo - var body []byte - resp, err := getHTTPClient().Get(buildURLRelativeToBase(versionPath)) - if err != nil { - return version, body, err - } - defer resp.Body.Close() - 
err = checkResponse(resp.StatusCode, expectedStatusCode) - if err == nil && expectedStatusCode == http.StatusOK { - err = render.DecodeJSON(resp.Body, &version) - } else { - body, _ = getResponseBody(resp) - } - return version, body, err -} - -func checkResponse(actual int, expected int) error { - if expected != actual { - return fmt.Errorf("wrong status code: got %v want %v", actual, expected) - } - return nil -} - -func getResponseBody(resp *http.Response) ([]byte, error) { - return ioutil.ReadAll(resp.Body) -} - -func checkUser(expected dataprovider.User, actual dataprovider.User) error { - if len(actual.Password) > 0 { - return errors.New("User password must not be visible") - } - if len(actual.PublicKeys) > 0 { - return errors.New("User public keys must not be visible") - } - if expected.ID <= 0 { - if actual.ID <= 0 { - return errors.New("actual user ID must be > 0") - } - } else { - if actual.ID != expected.ID { - return errors.New("user ID mismatch") - } - } - for _, v := range expected.Permissions { - if !utils.IsStringInSlice(v, actual.Permissions) { - return errors.New("Permissions contents mismatch") - } - } - return compareEqualsUserFields(expected, actual) -} - -func compareEqualsUserFields(expected dataprovider.User, actual dataprovider.User) error { - if expected.Username != actual.Username { - return errors.New("Username mismatch") - } - if expected.HomeDir != actual.HomeDir { - return errors.New("HomeDir mismatch") - } - if expected.UID != actual.UID { - return errors.New("UID mismatch") - } - if expected.GID != actual.GID { - return errors.New("GID mismatch") - } - if expected.MaxSessions != actual.MaxSessions { - return errors.New("MaxSessions mismatch") - } - if expected.QuotaSize != actual.QuotaSize { - return errors.New("QuotaSize mismatch") - } - if expected.QuotaFiles != actual.QuotaFiles { - return errors.New("QuotaFiles mismatch") - } - if len(expected.Permissions) != len(actual.Permissions) { - return errors.New("Permissions mismatch") - 
} - if expected.UploadBandwidth != actual.UploadBandwidth { - return errors.New("UploadBandwidth mismatch") - } - if expected.DownloadBandwidth != actual.DownloadBandwidth { - return errors.New("DownloadBandwidth mismatch") - } - return nil -} diff --git a/api/internal_test.go b/api/internal_test.go deleted file mode 100644 index 02c3d70d..00000000 --- a/api/internal_test.go +++ /dev/null @@ -1,228 +0,0 @@ -package api - -import ( - "context" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/drakkan/sftpgo/dataprovider" - "github.com/go-chi/chi" -) - -const ( - invalidURL = "http://foo\x7f.com/" - inactiveURL = "http://127.0.0.1:12345" -) - -func TestGetRespStatus(t *testing.T) { - var err error - err = &dataprovider.MethodDisabledError{} - respStatus := getRespStatus(err) - if respStatus != http.StatusForbidden { - t.Errorf("wrong resp status extected: %d got: %d", http.StatusForbidden, respStatus) - } - err = fmt.Errorf("generic error") - respStatus = getRespStatus(err) - if respStatus != http.StatusInternalServerError { - t.Errorf("wrong resp status extected: %d got: %d", http.StatusInternalServerError, respStatus) - } -} - -func TestCheckResponse(t *testing.T) { - err := checkResponse(http.StatusOK, http.StatusCreated) - if err == nil { - t.Errorf("check must fail") - } - err = checkResponse(http.StatusBadRequest, http.StatusBadRequest) - if err != nil { - t.Errorf("test must succeed, error: %v", err) - } -} - -func TestCheckUser(t *testing.T) { - expected := dataprovider.User{} - actual := dataprovider.User{} - actual.Password = "password" - err := checkUser(expected, actual) - if err == nil { - t.Errorf("actual password must be nil") - } - actual.Password = "" - actual.PublicKeys = []string{"pub key"} - err = checkUser(expected, actual) - if err == nil { - t.Errorf("actual public key must be nil") - } - actual.PublicKeys = []string{} - err = checkUser(expected, actual) - if err == nil { - t.Errorf("actual ID must be > 0") - } - 
expected.ID = 1 - actual.ID = 2 - err = checkUser(expected, actual) - if err == nil { - t.Errorf("actual ID must be equal to expected ID") - } - expected.ID = 2 - actual.ID = 2 - expected.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermDelete, dataprovider.PermDownload} - actual.Permissions = []string{dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks} - err = checkUser(expected, actual) - if err == nil { - t.Errorf("Permissions are not equal") - } - expected.Permissions = append(expected.Permissions, dataprovider.PermRename) - err = checkUser(expected, actual) - if err == nil { - t.Errorf("Permissions are not equal") - } -} - -func TestCompareUserFields(t *testing.T) { - expected := dataprovider.User{} - actual := dataprovider.User{} - expected.Username = "test" - err := compareEqualsUserFields(expected, actual) - if err == nil { - t.Errorf("Username does not match") - } - expected.Username = "" - expected.HomeDir = "homedir" - err = compareEqualsUserFields(expected, actual) - if err == nil { - t.Errorf("HomeDir does not match") - } - expected.HomeDir = "" - expected.UID = 1 - err = compareEqualsUserFields(expected, actual) - if err == nil { - t.Errorf("UID does not match") - } - expected.UID = 0 - expected.GID = 1 - err = compareEqualsUserFields(expected, actual) - if err == nil { - t.Errorf("GID does not match") - } - expected.GID = 0 - expected.MaxSessions = 2 - err = compareEqualsUserFields(expected, actual) - if err == nil { - t.Errorf("MaxSessions do not match") - } - expected.MaxSessions = 0 - expected.QuotaSize = 4096 - err = compareEqualsUserFields(expected, actual) - if err == nil { - t.Errorf("QuotaSize does not match") - } - expected.QuotaSize = 0 - expected.QuotaFiles = 2 - err = compareEqualsUserFields(expected, actual) - if err == nil { - t.Errorf("QuotaFiles do not match") - } - expected.QuotaFiles = 0 - expected.Permissions = []string{dataprovider.PermCreateDirs} - err = compareEqualsUserFields(expected, actual) - if 
err == nil { - t.Errorf("Permissions are not equal") - } - expected.Permissions = nil - expected.UploadBandwidth = 64 - err = compareEqualsUserFields(expected, actual) - if err == nil { - t.Errorf("UploadBandwidth does not match") - } - expected.UploadBandwidth = 0 - expected.DownloadBandwidth = 128 - err = compareEqualsUserFields(expected, actual) - if err == nil { - t.Errorf("DownloadBandwidth does not match") - } -} - -func TestApiCallsWithBadURL(t *testing.T) { - oldBaseURL := httpBaseURL - SetBaseURL(invalidURL) - u := dataprovider.User{} - _, _, err := UpdateUser(u, http.StatusBadRequest) - if err == nil { - t.Errorf("request with invalid URL must fail") - } - _, err = RemoveUser(u, http.StatusNotFound) - if err == nil { - t.Errorf("request with invalid URL must fail") - } - _, _, err = GetUsers(1, 0, "", http.StatusBadRequest) - if err == nil { - t.Errorf("request with invalid URL must fail") - } - _, err = CloseConnection("non_existent_id", http.StatusNotFound) - if err == nil { - t.Errorf("request with invalid URL must fail") - } - SetBaseURL(oldBaseURL) -} - -func TestApiCallToNotListeningServer(t *testing.T) { - oldBaseURL := httpBaseURL - SetBaseURL(inactiveURL) - u := dataprovider.User{} - _, _, err := AddUser(u, http.StatusBadRequest) - if err == nil { - t.Errorf("request to an inactive URL must fail") - } - _, _, err = UpdateUser(u, http.StatusNotFound) - if err == nil { - t.Errorf("request to an inactive URL must fail") - } - _, err = RemoveUser(u, http.StatusNotFound) - if err == nil { - t.Errorf("request to an inactive URL must fail") - } - _, _, err = GetUserByID(-1, http.StatusNotFound) - if err == nil { - t.Errorf("request to an inactive URL must fail") - } - _, _, err = GetUsers(100, 0, "", http.StatusOK) - if err == nil { - t.Errorf("request to an inactive URL must fail") - } - _, _, err = GetQuotaScans(http.StatusOK) - if err == nil { - t.Errorf("request to an inactive URL must fail") - } - _, err = StartQuotaScan(u, http.StatusNotFound) - 
if err == nil { - t.Errorf("request to an inactive URL must fail") - } - _, _, err = GetConnections(http.StatusOK) - if err == nil { - t.Errorf("request to an inactive URL must fail") - } - _, err = CloseConnection("non_existent_id", http.StatusNotFound) - if err == nil { - t.Errorf("request to an inactive URL must fail") - } - _, _, err = GetVersion(http.StatusOK) - if err == nil { - t.Errorf("request to an inactive URL must fail") - } - SetBaseURL(oldBaseURL) -} - -func TestCloseConnectionHandler(t *testing.T) { - req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil) - rctx := chi.NewRouteContext() - rctx.URLParams.Add("connectionID", "") - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) - rr := httptest.NewRecorder() - handleCloseConnection(rr, req) - if rr.Code != http.StatusBadRequest { - t.Errorf("Expected response code 400. Got %d", rr.Code) - } -} diff --git a/api/quota.go b/api/quota.go deleted file mode 100644 index c930f643..00000000 --- a/api/quota.go +++ /dev/null @@ -1,44 +0,0 @@ -package api - -import ( - "net/http" - - "github.com/drakkan/sftpgo/dataprovider" - "github.com/drakkan/sftpgo/logger" - "github.com/drakkan/sftpgo/sftpd" - "github.com/drakkan/sftpgo/utils" - "github.com/go-chi/render" -) - -func getQuotaScans(w http.ResponseWriter, r *http.Request) { - render.JSON(w, r, sftpd.GetQuotaScans()) -} - -func startQuotaScan(w http.ResponseWriter, r *http.Request) { - var u dataprovider.User - err := render.DecodeJSON(r.Body, &u) - if err != nil { - sendAPIResponse(w, r, err, "", http.StatusBadRequest) - return - } - user, err := dataprovider.UserExists(dataProvider, u.Username) - if err != nil { - sendAPIResponse(w, r, err, "", http.StatusNotFound) - return - } - if sftpd.AddQuotaScan(user.Username) { - sendAPIResponse(w, r, err, "Scan started", http.StatusCreated) - go func() { - numFiles, size, _, err := utils.ScanDirContents(user.HomeDir) - if err != nil { - 
logger.Warn(logSender, "error scanning user home dir %v: %v", user.HomeDir, err) - } else { - err := dataprovider.UpdateUserQuota(dataProvider, user, numFiles, size, true) - logger.Debug(logSender, "user dir scanned, user: %v, dir: %v, error: %v", user.Username, user.HomeDir, err) - } - sftpd.RemoveQuotaScan(user.Username) - }() - } else { - sendAPIResponse(w, r, err, "Another scan is already in progress", http.StatusConflict) - } -} diff --git a/api/router.go b/api/router.go deleted file mode 100644 index bd58668f..00000000 --- a/api/router.go +++ /dev/null @@ -1,86 +0,0 @@ -package api - -import ( - "net/http" - - "github.com/drakkan/sftpgo/logger" - "github.com/drakkan/sftpgo/sftpd" - "github.com/drakkan/sftpgo/utils" - "github.com/go-chi/chi" - "github.com/go-chi/chi/middleware" - "github.com/go-chi/render" -) - -// GetHTTPRouter returns the configured HTTP handler -func GetHTTPRouter() http.Handler { - return router -} - -func initializeRouter() { - router = chi.NewRouter() - router.Use(middleware.RequestID) - router.Use(middleware.RealIP) - router.Use(logger.NewStructuredLogger(logger.GetLogger())) - router.Use(middleware.Recoverer) - - router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) - })) - - router.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sendAPIResponse(w, r, nil, "Method not allowed", http.StatusMethodNotAllowed) - })) - - router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) { - render.JSON(w, r, utils.GetAppVersion()) - }) - - router.Get(activeConnectionsPath, func(w http.ResponseWriter, r *http.Request) { - render.JSON(w, r, sftpd.GetConnectionsStats()) - }) - - router.Delete(activeConnectionsPath+"/{connectionID}", func(w http.ResponseWriter, r *http.Request) { - handleCloseConnection(w, r) - }) - - router.Get(quotaScanPath, func(w http.ResponseWriter, r *http.Request) { - getQuotaScans(w, r) - }) - - 
router.Post(quotaScanPath, func(w http.ResponseWriter, r *http.Request) { - startQuotaScan(w, r) - }) - - router.Get(userPath, func(w http.ResponseWriter, r *http.Request) { - getUsers(w, r) - }) - - router.Post(userPath, func(w http.ResponseWriter, r *http.Request) { - addUser(w, r) - }) - - router.Get(userPath+"/{userID}", func(w http.ResponseWriter, r *http.Request) { - getUserByID(w, r) - }) - - router.Put(userPath+"/{userID}", func(w http.ResponseWriter, r *http.Request) { - updateUser(w, r) - }) - - router.Delete(userPath+"/{userID}", func(w http.ResponseWriter, r *http.Request) { - deleteUser(w, r) - }) -} - -func handleCloseConnection(w http.ResponseWriter, r *http.Request) { - connectionID := chi.URLParam(r, "connectionID") - if connectionID == "" { - sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest) - return - } - if sftpd.CloseActiveConnection(connectionID) { - sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK) - } else { - sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) - } -} diff --git a/api/schema/openapi.yaml b/api/schema/openapi.yaml deleted file mode 100644 index b79b9c31..00000000 --- a/api/schema/openapi.yaml +++ /dev/null @@ -1,689 +0,0 @@ -openapi: 3.0.1 -info: - title: SFTPGo - description: 'SFTPGo REST API' - version: 1.0.0 - -servers: -- url: /api/v1 -paths: - /version: - get: - tags: - - version - summary: Get version details - operationId: get_version - responses: - 200: - description: successful operation - content: - application/json: - schema: - type: array - items: - $ref : '#/components/schemas/VersionInfo' - /connection: - get: - tags: - - connections - summary: Get the active users and info about their uploads/downloads - operationId: get_connections - responses: - 200: - description: successful operation - content: - application/json: - schema: - type: array - items: - $ref : '#/components/schemas/ConnectionStatus' - /connection/{connectionID}: - delete: - tags: - - connections 
- summary: Terminate an active connection - operationId: close_connection - parameters: - - name: connectionID - in: path - description: ID of the connection to close - required: true - schema: - type: string - responses: - 200: - description: successful operation - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 200 - message: "Connection closed" - error: "" - 400: - description: Bad request - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 400 - message: "" - error: "Error description if any" - 404: - description: Not Found - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 404 - message: "" - error: "Error description if any" - 500: - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 500 - message: "" - error: "Error description if any" - /quota_scan: - get: - tags: - - quota - summary: Get the active quota scans - operationId: get_quota_scans - responses: - 200: - description: successful operation - content: - application/json: - schema: - type: array - items: - $ref : '#/components/schemas/QuotaScan' - post: - tags: - - quota - summary: start a new quota scan - description: A quota scan update the number of files and their total size for the given user - operationId: start_quota_scan - requestBody: - required: true - content: - application/json: - schema: - $ref : '#/components/schemas/User' - responses: - 201: - description: successful operation - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 201 - message: "Scan started" - error: "" - 400: - description: Bad request - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 400 - message: "" - error: "Error description if any" - 403: - description: Forbidden 
- content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 403 - message: "" - error: "Error description if any" - 404: - description: Not Found - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 404 - message: "" - error: "Error description if any" - 409: - description: Another scan is already in progress for this user - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 409 - message: "Another scan is already in progress" - error: "Error description if any" - 500: - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 500 - message: "" - error: "Error description if any" - /user: - get: - tags: - - users - summary: Returns an array with one or more users - description: For security reasons password and public key are empty in the response - operationId: get_users - parameters: - - in: query - name: offset - schema: - type: integer - minimum: 0 - default: 0 - required: false - - in: query - name: limit - schema: - type: integer - minimum: 1 - maximum: 500 - default: 100 - required: false - description: The maximum number of items to return. 
Max value is 500, default is 100 - - in: query - name: order - required: false - description: Ordering users by username - schema: - type: string - enum: - - ASC - - DESC - example: ASC - - in: query - name: username - required: false - description: Filter by username, extact match case sensitive - schema: - type: string - responses: - 200: - description: successful operation - content: - application/json: - schema: - type: array - items: - $ref : '#/components/schemas/User' - 400: - description: Bad request - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 400 - message: "" - error: "Error description if any" - 403: - description: Forbidden - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 403 - message: "" - error: "Error description if any" - 500: - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 500 - message: "" - error: "Error description if any" - post: - tags: - - users - summary: Adds a new SFTP/SCP user - operationId: add_user - requestBody: - required: true - content: - application/json: - schema: - $ref : '#/components/schemas/User' - responses: - 200: - description: successful operation - content: - application/json: - schema: - $ref : '#/components/schemas/User' - 400: - description: Bad request - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 400 - message: "" - error: "Error description if any" - 403: - description: Forbidden - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 403 - message: "" - error: "Error description if any" - 500: - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 500 - message: "" - error: "Error description if any" - 
/user/{userID}: - get: - tags: - - users - summary: Find user by ID - description: For security reasons password and public key are empty in the response - operationId: get_user_by_id - parameters: - - name: userID - in: path - description: ID of the user to retrieve - required: true - schema: - type: integer - format: int32 - responses: - 200: - description: successful operation - content: - application/json: - schema: - $ref : '#/components/schemas/User' - 400: - description: Bad request - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 400 - message: "" - error: "Error description if any" - 403: - description: Forbidden - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 403 - message: "" - error: "Error description if any" - 404: - description: Not Found - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 404 - message: "" - error: "Error description if any" - 500: - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 500 - message: "" - error: "Error description if any" - put: - tags: - - users - summary: Update an existing user - operationId: update_user - parameters: - - name: userID - in: path - description: ID of the user to update - required: true - schema: - type: integer - format: int32 - requestBody: - required: true - content: - application/json: - schema: - $ref : '#/components/schemas/User' - responses: - 200: - description: successful operation - content: - application/json: - schema: - $ref : '#/components/schemas/ApiResponse' - example: - status: 200 - message: "User updated" - error: "" - 400: - description: Bad request - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 400 - message: "" - error: "Error description if any" - 403: - description: 
Forbidden - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 403 - message: "" - error: "Error description if any" - 404: - description: Not Found - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 404 - message: "" - error: "Error description if any" - 500: - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 500 - message: "" - error: "Error description if any" - delete: - tags: - - users - summary: Delete an existing user - operationId: delete_user - parameters: - - name: userID - in: path - description: ID of the user to delete - required: true - schema: - type: integer - format: int32 - responses: - 200: - description: successful operation - content: - application/json: - schema: - $ref : '#/components/schemas/ApiResponse' - example: - status: 200 - message: "User deleted" - error: "" - 400: - description: Bad request - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 400 - message: "" - error: "Error description if any" - 403: - description: Forbidden - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 403 - message: "" - error: "Error description if any" - 404: - description: Not Found - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 404 - message: "" - error: "Error description if any" - 500: - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/ApiResponse' - example: - status: 500 - message: "" - error: "Error description if any" -components: - schemas: - Permission: - type: string - enum: - - '*' - - list - - download - - upload - - delete - - rename - - create_dirs - - create_symlinks - description: > - Permissions: - * `*` - all permission are 
granted - * `list` - list items is allowed - * `download` - download files is allowed - * `upload` - upload files is allowed - * `delete` - delete files or directories is allowed - * `rename` - rename files or directories is allowed - * `create_dirs` - create directories is allowed - * `create_symlinks` - create links is allowed - User: - type: object - properties: - id: - type: integer - format: int32 - minimum: 1 - username: - type: string - password: - type: string - nullable: true - description: password or public key are mandatory. If the password has no known hashing algo prefix it will be stored using argon2id. You can send a password hashed as bcrypt or pbkdf2 and it will be stored as is. For security reasons this field is omitted when you search/get users - public_keys: - type: array - items: - type: string - nullable: true - description: a password or at least one public key are mandatory. For security reasons this field is omitted when you search/get users. - home_dir: - type: string - description: path to the user home directory. The user cannot upload or download files outside this directory. SFTPGo tries to automatically create this folder if missing. Must be an absolute path - uid: - type: integer - format: int32 - minimum: 0 - maximum: 65535 - description: if you run sftpgo as root user the created files and directories will be assigned to this uid. 0 means no change, the owner will be the user that runs sftpgo. Ignored on windows - gid: - type: integer - format: int32 - minimum: 0 - maximum: 65535 - description: if you run sftpgo as root user the created files and directories will be assigned to this gid. 0 means no change, the group will be the one of the user that runs sftpgo. Ignored on windows - max_sessions: - type: integer - format: int32 - description: limit the sessions that an user can open. 0 means unlimited - quota_size: - type: integer - format: int64 - description: quota as size. 0 menas unlimited. 
Please note that quota is updated if files are added/removed via SFTP/SCP otherwise a quota scan is needed - quota_files: - type: integer - format: int32 - description: quota as number of files. 0 menas unlimited. Please note that quota is updated if files are added/removed via SFTP/SCP otherwise a quota scan is needed - permissions: - type: array - items: - $ref: '#/components/schemas/Permission' - minItems: 1 - used_quota_size: - type: integer - format: int64 - used_quota_file: - type: integer - format: int32 - last_quota_update: - type: integer - format: int64 - description: last quota update as unix timestamp in milliseconds - upload_bandwidth: - type: integer - format: int32 - description: Maximum upload bandwidth as KB/s, 0 means unlimited - download_bandwidth: - type: integer - format: int32 - description: Maximum download bandwidth as KB/s, 0 means unlimited - Transfer: - type: object - properties: - operation_type: - type: string - enum: - - upload - - download - path: - type: string - description: SFTP/SCP file path for the upload/download - start_time: - type: integer - format: int64 - description: start time as unix timestamp in milliseconds - size: - type: integer - format: int64 - description: bytes transferred - last_activity: - type: integer - format: int64 - description: last transfer activity as unix timestamp in milliseconds - ConnectionStatus: - type: object - properties: - username: - type: string - description: connected username - connection_id: - type: string - description: unique connection identifier - client_version: - type: string - description: SFTP/SCP client version - remote_address: - type: string - description: Remote address for the connected SFTP/SCP client - connection_time: - type: integer - format: int64 - description: connection time as unix timestamp in milliseconds - last_activity: - type: integer - format: int64 - description: last client activity as unix timestamp in milliseconds - protocol: - type: string - enum: - - SFTP 
- - SCP - active_transfers: - type: array - items: - $ref : '#/components/schemas/Transfer' - QuotaScan: - type: object - properties: - username: - type: string - description: username with an active scan - start_time: - type: integer - format: int64 - description: scan start time as unix timestamp in milliseconds - ApiResponse: - type: object - properties: - status: - type: integer - format: int32 - minimum: 200 - maximum: 500 - example: 200 - description: HTTP Status code, for example 200 OK, 400 Bad request and so on - message: - type: string - nullable: true - description: additional message if any - error: - type: string - nullable: true - description: error description if any - VersionInfo: - type: object - properties: - version: - type: string - build_date: - type: string - commit_hash: - type: string - \ No newline at end of file diff --git a/api/user.go b/api/user.go deleted file mode 100644 index 719aa63d..00000000 --- a/api/user.go +++ /dev/null @@ -1,151 +0,0 @@ -package api - -import ( - "errors" - "net/http" - "strconv" - - "github.com/drakkan/sftpgo/dataprovider" - "github.com/go-chi/chi" - "github.com/go-chi/render" -) - -func getUsers(w http.ResponseWriter, r *http.Request) { - limit := 100 - offset := 0 - order := "ASC" - username := "" - var err error - if _, ok := r.URL.Query()["limit"]; ok { - limit, err = strconv.Atoi(r.URL.Query().Get("limit")) - if err != nil { - err = errors.New("Invalid limit") - sendAPIResponse(w, r, err, "", http.StatusBadRequest) - return - } - if limit > 500 { - limit = 500 - } - } - if _, ok := r.URL.Query()["offset"]; ok { - offset, err = strconv.Atoi(r.URL.Query().Get("offset")) - if err != nil { - err = errors.New("Invalid offset") - sendAPIResponse(w, r, err, "", http.StatusBadRequest) - return - } - } - if _, ok := r.URL.Query()["order"]; ok { - order = r.URL.Query().Get("order") - if order != "ASC" && order != "DESC" { - err = errors.New("Invalid order") - sendAPIResponse(w, r, err, "", http.StatusBadRequest) - 
return - } - } - if _, ok := r.URL.Query()["username"]; ok { - username = r.URL.Query().Get("username") - } - users, err := dataprovider.GetUsers(dataProvider, limit, offset, order, username) - if err == nil { - render.JSON(w, r, users) - } else { - sendAPIResponse(w, r, err, "", http.StatusInternalServerError) - } -} - -func getUserByID(w http.ResponseWriter, r *http.Request) { - userID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64) - if err != nil { - err = errors.New("Invalid userID") - sendAPIResponse(w, r, err, "", http.StatusBadRequest) - return - } - user, err := dataprovider.GetUserByID(dataProvider, userID) - if err == nil { - user.Password = "" - user.PublicKeys = []string{} - render.JSON(w, r, user) - } else if _, ok := err.(*dataprovider.RecordNotFoundError); ok { - sendAPIResponse(w, r, err, "", http.StatusNotFound) - } else { - sendAPIResponse(w, r, err, "", http.StatusInternalServerError) - } -} - -func addUser(w http.ResponseWriter, r *http.Request) { - var user dataprovider.User - err := render.DecodeJSON(r.Body, &user) - if err != nil { - sendAPIResponse(w, r, err, "", http.StatusBadRequest) - return - } - err = dataprovider.AddUser(dataProvider, user) - if err == nil { - user, err = dataprovider.UserExists(dataProvider, user.Username) - if err == nil { - user.Password = "" - user.PublicKeys = []string{} - render.JSON(w, r, user) - } else { - sendAPIResponse(w, r, err, "", http.StatusInternalServerError) - } - } else { - sendAPIResponse(w, r, err, "", getRespStatus(err)) - } -} - -func updateUser(w http.ResponseWriter, r *http.Request) { - userID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64) - if err != nil { - err = errors.New("Invalid userID") - sendAPIResponse(w, r, err, "", http.StatusBadRequest) - return - } - user, err := dataprovider.GetUserByID(dataProvider, userID) - if _, ok := err.(*dataprovider.RecordNotFoundError); ok { - sendAPIResponse(w, r, err, "", http.StatusNotFound) - return - } else if err != nil { - 
sendAPIResponse(w, r, err, "", http.StatusInternalServerError) - return - } - err = render.DecodeJSON(r.Body, &user) - if err != nil { - sendAPIResponse(w, r, err, "", http.StatusBadRequest) - return - } - if user.ID != userID { - sendAPIResponse(w, r, err, "user ID in request body does not match user ID in path parameter", http.StatusBadRequest) - return - } - err = dataprovider.UpdateUser(dataProvider, user) - if err != nil { - sendAPIResponse(w, r, err, "", getRespStatus(err)) - } else { - sendAPIResponse(w, r, err, "User updated", http.StatusOK) - } -} - -func deleteUser(w http.ResponseWriter, r *http.Request) { - userID, err := strconv.ParseInt(chi.URLParam(r, "userID"), 10, 64) - if err != nil { - err = errors.New("Invalid userID") - sendAPIResponse(w, r, err, "", http.StatusBadRequest) - return - } - user, err := dataprovider.GetUserByID(dataProvider, userID) - if _, ok := err.(*dataprovider.RecordNotFoundError); ok { - sendAPIResponse(w, r, err, "", http.StatusNotFound) - return - } else if err != nil { - sendAPIResponse(w, r, err, "", http.StatusInternalServerError) - return - } - err = dataprovider.DeleteUser(dataProvider, user) - if err != nil { - sendAPIResponse(w, r, err, "", http.StatusInternalServerError) - } else { - sendAPIResponse(w, r, err, "User deleted", http.StatusOK) - } -} diff --git a/cmd/root.go b/cmd/root.go deleted file mode 100644 index 92414f05..00000000 --- a/cmd/root.go +++ /dev/null @@ -1,37 +0,0 @@ -package cmd - -import ( - "fmt" - "os" - - "github.com/drakkan/sftpgo/utils" - "github.com/spf13/cobra" -) - -const ( - logSender = "cmd" -) - -var ( - rootCmd = &cobra.Command{ - Use: "sftpgo", - Short: "Full featured and highly configurable SFTP server", - } -) - -func init() { - version := utils.GetAppVersion() - rootCmd.Flags().BoolP("version", "v", false, "") - rootCmd.Version = version.GetVersionAsString() - rootCmd.SetVersionTemplate(`{{printf "SFTPGo version: "}}{{printf "%s" .Version}} -`) -} - -// Execute adds all child 
commands to the root command and sets flags appropriately. -// This is called by main.main(). It only needs to happen once to the rootCmd. -func Execute() { - if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) - } -} diff --git a/cmd/serve.go b/cmd/serve.go deleted file mode 100644 index e3ed0d42..00000000 --- a/cmd/serve.go +++ /dev/null @@ -1,181 +0,0 @@ -package cmd - -import ( - "fmt" - "net/http" - "os" - "time" - - "github.com/drakkan/sftpgo/api" - "github.com/drakkan/sftpgo/config" - "github.com/drakkan/sftpgo/dataprovider" - "github.com/drakkan/sftpgo/logger" - "github.com/drakkan/sftpgo/sftpd" - "github.com/rs/zerolog" - "github.com/spf13/cobra" - "github.com/spf13/viper" -) - -const ( - configDirFlag = "config-dir" - configDirKey = "config_dir" - configFileFlag = "config-file" - configFileKey = "config_file" - logFilePathFlag = "log-file-path" - logFilePathKey = "log_file_path" - logMaxSizeFlag = "log-max-size" - logMaxSizeKey = "log_max_size" - logMaxBackupFlag = "log-max-backups" - logMaxBackupKey = "log_max_backups" - logMaxAgeFlag = "log-max-age" - logMaxAgeKey = "log_max_age" - logCompressFlag = "log-compress" - logCompressKey = "log_compress" - logVerboseFlag = "log-verbose" - logVerboseKey = "log_verbose" -) - -var ( - configDir string - configFile string - logFilePath string - logMaxSize int - logMaxBackups int - logMaxAge int - logCompress bool - logVerbose bool - testVar string - serveCmd = &cobra.Command{ - Use: "serve", - Short: "Start the SFTP Server", - Long: `To start the SFTP Server with the default values for the command line flags simply use: - -sftpgo serve - -Please take a look at the usage below to customize the startup options`, - Run: func(cmd *cobra.Command, args []string) { - startServe() - }, - } -) - -func init() { - rootCmd.AddCommand(serveCmd) - - viper.SetDefault(configDirKey, ".") - viper.BindEnv(configDirKey, "SFTPGO_CONFIG_DIR") - serveCmd.Flags().StringVarP(&configDir, configDirFlag, "c", 
viper.GetString(configDirKey), - "Location for SFTPGo config dir. This directory should contain the \"sftpgo\" configuration file or the configured "+ - "config-file and it is used as the base for files with a relative path (eg. the private keys for the SFTP server, "+ - "the SQLite database if you use SQLite as data provider). This flag can be set using SFTPGO_CONFIG_DIR env var too.") - viper.BindPFlag(configDirKey, serveCmd.Flags().Lookup(configDirFlag)) - - viper.SetDefault(configFileKey, config.DefaultConfigName) - viper.BindEnv(configFileKey, "SFTPGO_CONFIG_FILE") - serveCmd.Flags().StringVarP(&configFile, configFileFlag, "f", viper.GetString(configFileKey), - "Name for SFTPGo configuration file. It must be the name of a file stored in config-dir not the absolute path to the "+ - "configuration file. The specified file name must have no extension we automatically load JSON, YAML, TOML, HCL and "+ - "Java properties. Therefore if you set \"sftpgo\" then \"sftpgo.json\", \"sftpgo.yaml\" and so on are searched. "+ - "This flag can be set using SFTPGO_CONFIG_FILE env var too.") - viper.BindPFlag(configFileKey, serveCmd.Flags().Lookup(configFileFlag)) - - viper.SetDefault(logFilePathKey, "sftpgo.log") - viper.BindEnv(logFilePathKey, "SFTPGO_LOG_FILE_PATH") - serveCmd.Flags().StringVarP(&logFilePath, logFilePathFlag, "l", viper.GetString(logFilePathKey), - "Location for the log file. This flag can be set using SFTPGO_LOG_FILE_PATH env var too.") - viper.BindPFlag(logFilePathKey, serveCmd.Flags().Lookup(logFilePathFlag)) - - viper.SetDefault(logMaxSizeKey, 10) - viper.BindEnv(logMaxSizeKey, "SFTPGO_LOG_MAX_SIZE") - serveCmd.Flags().IntVarP(&logMaxSize, logMaxSizeFlag, "s", viper.GetInt(logMaxSizeKey), - "Maximum size in megabytes of the log file before it gets rotated. 
This flag can be set using SFTPGO_LOG_MAX_SIZE "+ - "env var too.") - viper.BindPFlag(logMaxSizeKey, serveCmd.Flags().Lookup(logMaxSizeFlag)) - - viper.SetDefault(logMaxBackupKey, 5) - viper.BindEnv(logMaxBackupKey, "SFTPGO_LOG_MAX_BACKUPS") - serveCmd.Flags().IntVarP(&logMaxBackups, "log-max-backups", "b", viper.GetInt(logMaxBackupKey), - "Maximum number of old log files to retain. This flag can be set using SFTPGO_LOG_MAX_BACKUPS env var too.") - viper.BindPFlag(logMaxBackupKey, serveCmd.Flags().Lookup(logMaxBackupFlag)) - - viper.SetDefault(logMaxAgeKey, 28) - viper.BindEnv(logMaxAgeKey, "SFTPGO_LOG_MAX_AGE") - serveCmd.Flags().IntVarP(&logMaxAge, "log-max-age", "a", viper.GetInt(logMaxAgeKey), - "Maximum number of days to retain old log files. This flag can be set using SFTPGO_LOG_MAX_AGE env var too.") - viper.BindPFlag(logMaxAgeKey, serveCmd.Flags().Lookup(logMaxAgeFlag)) - - viper.SetDefault(logCompressKey, false) - viper.BindEnv(logCompressKey, "SFTPGO_LOG_COMPRESS") - serveCmd.Flags().BoolVarP(&logCompress, logCompressFlag, "z", viper.GetBool(logCompressKey), "Determine if the rotated "+ - "log files should be compressed using gzip. This flag can be set using SFTPGO_LOG_COMPRESS env var too.") - viper.BindPFlag(logCompressKey, serveCmd.Flags().Lookup(logCompressFlag)) - - viper.SetDefault(logVerboseKey, true) - viper.BindEnv(logVerboseKey, "SFTPGO_LOG_VERBOSE") - serveCmd.Flags().BoolVarP(&logVerbose, logVerboseFlag, "v", viper.GetBool(logVerboseKey), "Enable verbose logs. 
"+ - "This flag can be set using SFTPGO_LOG_VERBOSE env var too.") - viper.BindPFlag(logVerboseKey, serveCmd.Flags().Lookup(logVerboseFlag)) -} - -func startServe() { - logLevel := zerolog.DebugLevel - if !logVerbose { - logLevel = zerolog.InfoLevel - } - logger.InitLogger(logFilePath, logMaxSize, logMaxBackups, logMaxAge, logCompress, logLevel) - logger.Info(logSender, "starting SFTPGo, config dir: %v, config file: %v, log max size: %v log max backups: %v "+ - "log max age: %v log verbose: %v, log compress: %v", configDir, configFile, logMaxSize, logMaxBackups, logMaxAge, - logVerbose, logCompress) - config.LoadConfig(configDir, configFile) - providerConf := config.GetProviderConf() - - err := dataprovider.Initialize(providerConf, configDir) - if err != nil { - logger.Error(logSender, "error initializing data provider: %v", err) - logger.ErrorToConsole("error initializing data provider: %v", err) - os.Exit(1) - } - - dataProvider := dataprovider.GetProvider() - sftpdConf := config.GetSFTPDConfig() - httpdConf := config.GetHTTPDConfig() - - sftpd.SetDataProvider(dataProvider) - - shutdown := make(chan bool) - - go func() { - logger.Debug(logSender, "initializing SFTP server with config %+v", sftpdConf) - if err := sftpdConf.Initialize(configDir); err != nil { - logger.Error(logSender, "could not start SFTP server: %v", err) - logger.ErrorToConsole("could not start SFTP server: %v", err) - } - shutdown <- true - }() - - if httpdConf.BindPort > 0 { - router := api.GetHTTPRouter() - api.SetDataProvider(dataProvider) - - go func() { - logger.Debug(logSender, "initializing HTTP server with config %+v", httpdConf) - s := &http.Server{ - Addr: fmt.Sprintf("%s:%d", httpdConf.BindAddress, httpdConf.BindPort), - Handler: router, - ReadTimeout: 300 * time.Second, - WriteTimeout: 300 * time.Second, - MaxHeaderBytes: 1 << 20, // 1MB - } - if err := s.ListenAndServe(); err != nil { - logger.Error(logSender, "could not start HTTP server: %v", err) - logger.ErrorToConsole("could 
not start HTTP server: %v", err) - } - shutdown <- true - }() - } else { - logger.Debug(logSender, "HTTP server not started, disabled in config file") - logger.DebugToConsole("HTTP server not started, disabled in config file") - } - - <-shutdown -} diff --git a/config/config.go b/config/config.go deleted file mode 100644 index 86fba320..00000000 --- a/config/config.go +++ /dev/null @@ -1,134 +0,0 @@ -// Package config manages the configuration. -// Configuration is loaded from sftpgo.conf file. -// If sftpgo.conf is not found or cannot be readed or decoded as json the default configuration is used. -// The default configuration an be found inside the source tree: -// https://github.com/drakkan/sftpgo/blob/master/sftpgo.conf -package config - -import ( - "fmt" - "strings" - - "github.com/drakkan/sftpgo/api" - "github.com/drakkan/sftpgo/dataprovider" - "github.com/drakkan/sftpgo/logger" - "github.com/drakkan/sftpgo/sftpd" - "github.com/spf13/viper" -) - -const ( - logSender = "config" - defaultBanner = "SFTPGo" - // DefaultConfigName defines the name for the default config file. 
- // This is the file name without extension, we use viper and so we - // support all the config files format supported by viper - DefaultConfigName = "sftpgo" - // ConfigEnvPrefix defines a prefix that ENVIRONMENT variables will use - configEnvPrefix = "sftpgo" -) - -var ( - globalConf globalConfig -) - -type globalConfig struct { - SFTPD sftpd.Configuration `json:"sftpd" mapstructure:"sftpd"` - ProviderConf dataprovider.Config `json:"data_provider" mapstructure:"data_provider"` - HTTPDConfig api.HTTPDConf `json:"httpd" mapstructure:"httpd"` -} - -func init() { - // create a default configuration to use if no config file is provided - globalConf = globalConfig{ - SFTPD: sftpd.Configuration{ - Banner: defaultBanner, - BindPort: 2022, - BindAddress: "", - IdleTimeout: 15, - MaxAuthTries: 0, - Umask: "0022", - UploadMode: 0, - Actions: sftpd.Actions{ - ExecuteOn: []string{}, - Command: "", - HTTPNotificationURL: "", - }, - Keys: []sftpd.Key{}, - IsSCPEnabled: false, - }, - ProviderConf: dataprovider.Config{ - Driver: "sqlite", - Name: "sftpgo.db", - Host: "", - Port: 5432, - Username: "", - Password: "", - ConnectionString: "", - UsersTable: "users", - ManageUsers: 1, - SSLMode: 0, - TrackQuota: 1, - }, - HTTPDConfig: api.HTTPDConf{ - BindPort: 8080, - BindAddress: "127.0.0.1", - }, - } - - viper.SetEnvPrefix(configEnvPrefix) - replacer := strings.NewReplacer(".", "__") - viper.SetEnvKeyReplacer(replacer) - viper.SetConfigName(DefaultConfigName) - setViperAdditionalConfigPaths() - viper.AddConfigPath(".") - viper.AutomaticEnv() -} - -// GetSFTPDConfig returns the configuration for the SFTP server -func GetSFTPDConfig() sftpd.Configuration { - return globalConf.SFTPD -} - -// GetHTTPDConfig returns the configuration for the HTTP server -func GetHTTPDConfig() api.HTTPDConf { - return globalConf.HTTPDConfig -} - -//GetProviderConf returns the configuration for the data provider -func GetProviderConf() dataprovider.Config { - return globalConf.ProviderConf -} - -// 
LoadConfig loads the configuration -// configDir will be added to the configuration search paths. -// The search path contains by default the current directory and on linux it contains -// $HOME/.config/sftpgo and /etc/sftpgo too. -// configName is the name of the configuration to search without extension -func LoadConfig(configDir, configName string) error { - var err error - viper.AddConfigPath(configDir) - viper.SetConfigName(configName) - if err = viper.ReadInConfig(); err != nil { - logger.Warn(logSender, "error loading configuration file: %v. Default configuration will be used: %+v", err, globalConf) - logger.WarnToConsole("error loading configuration file: %v. Default configuration will be used.", err) - return err - } - err = viper.Unmarshal(&globalConf) - if err != nil { - logger.Warn(logSender, "error parsing configuration file: %v. Default configuration will be used: %+v", err, globalConf) - logger.WarnToConsole("error parsing configuration file: %v. Default configuration will be used.", err) - return err - } - if strings.TrimSpace(globalConf.SFTPD.Banner) == "" { - globalConf.SFTPD.Banner = defaultBanner - } - if globalConf.SFTPD.UploadMode < 0 || globalConf.SFTPD.UploadMode > 1 { - err = fmt.Errorf("Invalid upload_mode 0 and 1 are supported, configured: %v reset upload_mode to 0", - globalConf.SFTPD.UploadMode) - globalConf.SFTPD.UploadMode = 0 - logger.Warn(logSender, "Configuration error: %v", err) - logger.WarnToConsole("Configuration error: %v", err) - } - logger.Debug(logSender, "config file used: '%v', config loaded: %+v", viper.ConfigFileUsed(), globalConf) - return err -} diff --git a/config/config_linux.go b/config/config_linux.go deleted file mode 100644 index 967c2122..00000000 --- a/config/config_linux.go +++ /dev/null @@ -1,11 +0,0 @@ -// +build linux - -package config - -import "github.com/spf13/viper" - -// linux specific config search path -func setViperAdditionalConfigPaths() { - viper.AddConfigPath("$HOME/.config/sftpgo") - 
viper.AddConfigPath("/etc/sftpgo") -} diff --git a/config/config_nolinux.go b/config/config_nolinux.go deleted file mode 100644 index fe5d6aeb..00000000 --- a/config/config_nolinux.go +++ /dev/null @@ -1,7 +0,0 @@ -// +build !linux - -package config - -func setViperAdditionalConfigPaths() { - -} diff --git a/config/config_test.go b/config/config_test.go deleted file mode 100644 index d35f8186..00000000 --- a/config/config_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package config_test - -import ( - "encoding/json" - "io/ioutil" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/drakkan/sftpgo/api" - "github.com/drakkan/sftpgo/config" - "github.com/drakkan/sftpgo/dataprovider" - "github.com/drakkan/sftpgo/sftpd" -) - -const ( - tempConfigName = "temp" -) - -func TestLoadConfigTest(t *testing.T) { - configDir := ".." - err := config.LoadConfig(configDir, "") - if err != nil { - t.Errorf("error loading config") - } - emptyHTTPDConf := api.HTTPDConf{} - if config.GetHTTPDConfig() == emptyHTTPDConf { - t.Errorf("error loading httpd conf") - } - emptyProviderConf := dataprovider.Config{} - if config.GetProviderConf() == emptyProviderConf { - t.Errorf("error loading provider conf") - } - emptySFTPDConf := sftpd.Configuration{} - if config.GetSFTPDConfig().BindPort == emptySFTPDConf.BindPort { - t.Errorf("error loading SFTPD conf") - } - confName := tempConfigName + ".json" - configFilePath := filepath.Join(configDir, confName) - err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("loading a non existent config file must fail") - } - ioutil.WriteFile(configFilePath, []byte("{invalid json}"), 0666) - err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("loading an invalid config file must fail") - } - ioutil.WriteFile(configFilePath, []byte("{\"sftpd\": {\"bind_port\": \"a\"}}"), 0666) - err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("loading a config with an invalid bond_port 
must fail") - } - os.Remove(configFilePath) -} - -func TestEmptyBanner(t *testing.T) { - configDir := ".." - confName := tempConfigName + ".json" - configFilePath := filepath.Join(configDir, confName) - config.LoadConfig(configDir, "") - sftpdConf := config.GetSFTPDConfig() - sftpdConf.Banner = " " - c := make(map[string]sftpd.Configuration) - c["sftpd"] = sftpdConf - jsonConf, _ := json.Marshal(c) - err := ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } - config.LoadConfig(configDir, tempConfigName) - sftpdConf = config.GetSFTPDConfig() - if strings.TrimSpace(sftpdConf.Banner) == "" { - t.Errorf("SFTPD banner cannot be empty") - } - os.Remove(configFilePath) -} - -func TestInvalidUploadMode(t *testing.T) { - configDir := ".." - confName := tempConfigName + ".json" - configFilePath := filepath.Join(configDir, confName) - config.LoadConfig(configDir, "") - sftpdConf := config.GetSFTPDConfig() - sftpdConf.UploadMode = 10 - c := make(map[string]sftpd.Configuration) - c["sftpd"] = sftpdConf - jsonConf, _ := json.Marshal(c) - err := ioutil.WriteFile(configFilePath, jsonConf, 0666) - if err != nil { - t.Errorf("error saving temporary configuration") - } - err = config.LoadConfig(configDir, tempConfigName) - if err == nil { - t.Errorf("Loading configuration with invalid upload_mode must fail") - } - os.Remove(configFilePath) -} diff --git a/crowdin.yml b/crowdin.yml new file mode 100644 index 00000000..d8e884c3 --- /dev/null +++ b/crowdin.yml @@ -0,0 +1,6 @@ +project_id_env: CROWDIN_PROJECT_ID +api_token_env: CROWDIN_PERSONAL_TOKEN +files: + - source: /static/locales/en/translation.json + translation: /static/locales/%two_letters_code%/%original_file_name% + type: i18next_json diff --git a/dataprovider/bolt.go b/dataprovider/bolt.go deleted file mode 100644 index b2ba61be..00000000 --- a/dataprovider/bolt.go +++ /dev/null @@ -1,315 +0,0 @@ -package dataprovider - -import ( - "encoding/binary" - 
"encoding/json" - "errors" - "fmt" - "path/filepath" - "time" - - "github.com/drakkan/sftpgo/logger" - "github.com/drakkan/sftpgo/utils" - bolt "go.etcd.io/bbolt" -) - -var ( - usersBucket = []byte("users") - usersIDIdxBucket = []byte("users_id_idx") -) - -// BoltProvider auth provider for bolt key/value store -type BoltProvider struct { - dbHandle *bolt.DB -} - -func initializeBoltProvider(basePath string) error { - var err error - dbPath := config.Name - if !filepath.IsAbs(dbPath) { - dbPath = filepath.Join(basePath, dbPath) - } - dbHandle, err := bolt.Open(dbPath, 0600, &bolt.Options{ - NoGrowSync: false, - FreelistType: bolt.FreelistArrayType, - Timeout: 5 * time.Second}) - if err == nil { - logger.Debug(logSender, "bolt key store handle created") - err = dbHandle.Update(func(tx *bolt.Tx) error { - _, e := tx.CreateBucketIfNotExists(usersBucket) - return e - }) - if err != nil { - logger.Warn(logSender, "error creating users bucket: %v", err) - return err - } - err = dbHandle.Update(func(tx *bolt.Tx) error { - _, e := tx.CreateBucketIfNotExists(usersIDIdxBucket) - return e - }) - if err != nil { - logger.Warn(logSender, "error creating username idx bucket: %v", err) - return err - } - provider = BoltProvider{dbHandle: dbHandle} - } else { - logger.Warn(logSender, "error creating bolt key/value store handler: %v", err) - } - return err -} - -func (p BoltProvider) validateUserAndPass(username string, password string) (User, error) { - var user User - if len(password) == 0 { - return user, errors.New("Credentials cannot be null or empty") - } - user, err := p.userExists(username) - if err != nil { - logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) - return user, err - } - return checkUserAndPass(user, password) -} - -func (p BoltProvider) validateUserAndPubKey(username string, pubKey string) (User, error) { - var user User - if len(pubKey) == 0 { - return user, errors.New("Credentials cannot be null or empty") - } - user, err := 
p.userExists(username) - if err != nil { - logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) - return user, err - } - return checkUserAndPubKey(user, pubKey) -} - -func (p BoltProvider) getUserByID(ID int64) (User, error) { - var user User - err := p.dbHandle.View(func(tx *bolt.Tx) error { - bucket, idxBucket, err := getBuckets(tx) - if err != nil { - return err - } - userIDAsBytes := itob(ID) - username := idxBucket.Get(userIDAsBytes) - if username == nil { - return &RecordNotFoundError{err: fmt.Sprintf("user with ID %v does not exist", ID)} - } - u := bucket.Get(username) - if u == nil { - return &RecordNotFoundError{err: fmt.Sprintf("username %v and ID: %v does not exist", string(username), ID)} - } - return json.Unmarshal(u, &user) - }) - - return user, err -} - -func (p BoltProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { - return p.dbHandle.Update(func(tx *bolt.Tx) error { - bucket, _, err := getBuckets(tx) - if err != nil { - return err - } - var u []byte - if u = bucket.Get([]byte(username)); u == nil { - return &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist, unable to update quota", username)} - } - var user User - err = json.Unmarshal(u, &user) - if err != nil { - return err - } - if reset { - user.UsedQuotaSize = sizeAdd - user.UsedQuotaFiles = filesAdd - } else { - user.UsedQuotaSize += sizeAdd - user.UsedQuotaFiles += filesAdd - } - user.LastQuotaUpdate = utils.GetTimeAsMsSinceEpoch(time.Now()) - buf, err := json.Marshal(user) - if err != nil { - return err - } - return bucket.Put([]byte(username), buf) - }) -} - -func (p BoltProvider) getUsedQuota(username string) (int, int64, error) { - user, err := p.userExists(username) - if err != nil { - logger.Warn(logSender, "unable to get quota for user '%v' error: %v", username, err) - return 0, 0, err - } - return user.UsedQuotaFiles, user.UsedQuotaSize, err -} - -func (p BoltProvider) userExists(username string) (User, 
error) { - var user User - err := p.dbHandle.View(func(tx *bolt.Tx) error { - bucket, _, err := getBuckets(tx) - if err != nil { - return err - } - u := bucket.Get([]byte(username)) - if u == nil { - return &RecordNotFoundError{err: fmt.Sprintf("username %v does not exist", user.Username)} - } - return json.Unmarshal(u, &user) - }) - return user, err -} - -func (p BoltProvider) addUser(user User) error { - err := validateUser(&user) - if err != nil { - return err - } - return p.dbHandle.Update(func(tx *bolt.Tx) error { - bucket, idxBucket, err := getBuckets(tx) - if err != nil { - return err - } - if u := bucket.Get([]byte(user.Username)); u != nil { - return fmt.Errorf("username '%v' already exists", user.Username) - } - id, err := bucket.NextSequence() - if err != nil { - return err - } - user.ID = int64(id) - buf, err := json.Marshal(user) - if err != nil { - return err - } - userIDAsBytes := itob(user.ID) - err = bucket.Put([]byte(user.Username), buf) - if err != nil { - return err - } - return idxBucket.Put(userIDAsBytes, []byte(user.Username)) - }) -} - -func (p BoltProvider) updateUser(user User) error { - err := validateUser(&user) - if err != nil { - return err - } - return p.dbHandle.Update(func(tx *bolt.Tx) error { - bucket, _, err := getBuckets(tx) - if err != nil { - return err - } - if u := bucket.Get([]byte(user.Username)); u == nil { - return &RecordNotFoundError{err: fmt.Sprintf("username '%v' does not exist", user.Username)} - } - buf, err := json.Marshal(user) - if err != nil { - return err - } - return bucket.Put([]byte(user.Username), buf) - }) -} - -func (p BoltProvider) deleteUser(user User) error { - return p.dbHandle.Update(func(tx *bolt.Tx) error { - bucket, idxBucket, err := getBuckets(tx) - if err != nil { - return err - } - userIDAsBytes := itob(user.ID) - userName := idxBucket.Get(userIDAsBytes) - if userName == nil { - return &RecordNotFoundError{err: fmt.Sprintf("user with id %v does not exist", user.ID)} - } - err = 
bucket.Delete(userName) - if err != nil { - return err - } - return idxBucket.Delete(userIDAsBytes) - }) -} - -func (p BoltProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { - users := []User{} - var err error - if len(username) > 0 { - if offset == 0 { - user, err := p.userExists(username) - if err == nil { - users = append(users, getUserNoCredentials(&user)) - } - } - return users, err - } - err = p.dbHandle.View(func(tx *bolt.Tx) error { - if limit <= 0 { - return nil - } - bucket, _, err := getBuckets(tx) - if err != nil { - return err - } - cursor := bucket.Cursor() - itNum := 0 - if order == "ASC" { - for k, v := cursor.First(); k != nil; k, v = cursor.Next() { - itNum++ - if itNum <= offset { - continue - } - var user User - err = json.Unmarshal(v, &user) - if err == nil { - users = append(users, getUserNoCredentials(&user)) - } - if len(users) >= limit { - break - } - } - } else { - for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { - itNum++ - if itNum <= offset { - continue - } - var user User - err = json.Unmarshal(v, &user) - if err == nil { - users = append(users, getUserNoCredentials(&user)) - } - if len(users) >= limit { - break - } - } - } - return err - }) - return users, err -} - -func getUserNoCredentials(user *User) User { - user.Password = "" - user.PublicKeys = []string{} - return *user -} - -// itob returns an 8-byte big endian representation of v. 
-func itob(v int64) []byte { - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(v)) - return b -} - -func getBuckets(tx *bolt.Tx) (*bolt.Bucket, *bolt.Bucket, error) { - var err error - bucket := tx.Bucket(usersBucket) - idxBucket := tx.Bucket(usersIDIdxBucket) - if bucket == nil || idxBucket == nil { - err = fmt.Errorf("Unable to find required buckets, bolt database structure not correcly defined") - } - return bucket, idxBucket, err -} diff --git a/dataprovider/dataprovider.go b/dataprovider/dataprovider.go deleted file mode 100644 index 17f9695d..00000000 --- a/dataprovider/dataprovider.go +++ /dev/null @@ -1,371 +0,0 @@ -// Package dataprovider provides data access. -// It abstract different data providers and exposes a common API. -// Currently the supported data providers are: PostreSQL (9+), MySQL (4.1+) and SQLite 3.x -package dataprovider - -import ( - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "crypto/subtle" - "encoding/base64" - "errors" - "fmt" - "hash" - "path/filepath" - "strconv" - "strings" - - "github.com/alexedwards/argon2id" - "golang.org/x/crypto/bcrypt" - "golang.org/x/crypto/pbkdf2" - "golang.org/x/crypto/ssh" - - "github.com/drakkan/sftpgo/logger" - "github.com/drakkan/sftpgo/utils" -) - -const ( - // SQLiteDataProviderName name for SQLite database provider - SQLiteDataProviderName = "sqlite" - // PGSSQLDataProviderName name for PostgreSQL database provider - PGSSQLDataProviderName = "postgresql" - // MySQLDataProviderName name for MySQL database provider - MySQLDataProviderName = "mysql" - // BoltDataProviderName name for bbolt key/value store provider - BoltDataProviderName = "bolt" - - logSender = "dataProvider" - argonPwdPrefix = "$argon2id$" - bcryptPwdPrefix = "$2a$" - pbkdf2SHA1Prefix = "$pbkdf2-sha1$" - pbkdf2SHA256Prefix = "$pbkdf2-sha256$" - pbkdf2SHA512Prefix = "$pbkdf2-sha512$" - manageUsersDisabledError = "please set manage_users to 1 in sftpgo.conf to enable this method" - trackQuotaDisabledError = "please 
enable track_quota in sftpgo.conf to use this method" -) - -var ( - // SupportedProviders data provider configured in the sftpgo.conf file must match of these strings - SupportedProviders = []string{SQLiteDataProviderName, PGSSQLDataProviderName, MySQLDataProviderName, BoltDataProviderName} - config Config - provider Provider - sqlPlaceholders []string - validPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermDelete, PermRename, - PermCreateDirs, PermCreateSymlinks} - hashPwdPrefixes = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix} - pbkdfPwdPrefixes = []string{pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix} -) - -// Config provider configuration -type Config struct { - // Driver name, must be one of the SupportedProviders - Driver string `json:"driver" mapstructure:"driver"` - // Database name - Name string `json:"name" mapstructure:"name"` - // Database host - Host string `json:"host" mapstructure:"host"` - // Database port - Port int `json:"port" mapstructure:"port"` - // Database username - Username string `json:"username" mapstructure:"username"` - // Database password - Password string `json:"password" mapstructure:"password"` - // Used for drivers mysql and postgresql. - // 0 disable SSL/TLS connections. - // 1 require ssl. - // 2 set ssl mode to verify-ca for driver postgresql and skip-verify for driver mysql. - // 3 set ssl mode to verify-full for driver postgresql and preferred for driver mysql. - SSLMode int `json:"sslmode" mapstructure:"sslmode"` - // Custom database connection string. 
- // If not empty this connection string will be used instead of build one using the previous parameters - ConnectionString string `json:"connection_string" mapstructure:"connection_string"` - // Database table for SFTP users - UsersTable string `json:"users_table" mapstructure:"users_table"` - // Set to 0 to disable users management, 1 to enable - ManageUsers int `json:"manage_users" mapstructure:"manage_users"` - // Set the preferred way to track users quota between the following choices: - // 0, disable quota tracking. REST API to scan user dir and update quota will do nothing - // 1, quota is updated each time a user upload or delete a file even if the user has no quota restrictions - // 2, quota is updated each time a user upload or delete a file but only for users with quota restrictions. - // With this configuration the "quota scan" REST API can still be used to periodically update space usage - // for users without quota restrictions - TrackQuota int `json:"track_quota" mapstructure:"track_quota"` -} - -// ValidationError raised if input data is not valid -type ValidationError struct { - err string -} - -// Validation error details -func (e *ValidationError) Error() string { - return fmt.Sprintf("Validation error: %s", e.err) -} - -// MethodDisabledError raised if a method is disabled in config file. 
-// For example, if user management is disabled, this error is raised -// every time an user operation is done using the REST API -type MethodDisabledError struct { - err string -} - -// Method disabled error details -func (e *MethodDisabledError) Error() string { - return fmt.Sprintf("Method disabled error: %s", e.err) -} - -// RecordNotFoundError raised if a requested user is not found -type RecordNotFoundError struct { - err string -} - -func (e *RecordNotFoundError) Error() string { - return fmt.Sprintf("Not found: %s", e.err) -} - -// GetProvider returns the configured provider -func GetProvider() Provider { - return provider -} - -// Provider interface that data providers must implement. -type Provider interface { - validateUserAndPass(username string, password string) (User, error) - validateUserAndPubKey(username string, pubKey string) (User, error) - updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error - getUsedQuota(username string) (int, int64, error) - userExists(username string) (User, error) - addUser(user User) error - updateUser(user User) error - deleteUser(user User) error - getUsers(limit int, offset int, order string, username string) ([]User, error) - getUserByID(ID int64) (User, error) -} - -// Initialize the data provider. 
-// An error is returned if the configured driver is invalid or if the data provider cannot be initialized -func Initialize(cnf Config, basePath string) error { - config = cnf - sqlPlaceholders = getSQLPlaceholders() - if config.Driver == SQLiteDataProviderName { - return initializeSQLiteProvider(basePath) - } else if config.Driver == PGSSQLDataProviderName { - return initializePGSQLProvider() - } else if config.Driver == MySQLDataProviderName { - return initializeMySQLProvider() - } else if config.Driver == BoltDataProviderName { - return initializeBoltProvider(basePath) - } - return fmt.Errorf("Unsupported data provider: %v", config.Driver) -} - -// CheckUserAndPass retrieves the SFTP user with the given username and password if a match is found or an error -func CheckUserAndPass(p Provider, username string, password string) (User, error) { - return p.validateUserAndPass(username, password) -} - -// CheckUserAndPubKey retrieves the SFTP user with the given username and public key if a match is found or an error -func CheckUserAndPubKey(p Provider, username string, pubKey string) (User, error) { - return p.validateUserAndPubKey(username, pubKey) -} - -// UpdateUserQuota updates the quota for the given SFTP user adding filesAdd and sizeAdd. -// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. -func UpdateUserQuota(p Provider, user User, filesAdd int, sizeAdd int64, reset bool) error { - if config.TrackQuota == 0 { - return &MethodDisabledError{err: trackQuotaDisabledError} - } else if config.TrackQuota == 2 && !reset && !user.HasQuotaRestrictions() { - return nil - } - return p.updateQuota(user.Username, filesAdd, sizeAdd, reset) -} - -// GetUsedQuota returns the used quota for the given SFTP user. 
-// TrackQuota must be >=1 to enable this method -func GetUsedQuota(p Provider, username string) (int, int64, error) { - if config.TrackQuota == 0 { - return 0, 0, &MethodDisabledError{err: trackQuotaDisabledError} - } - return p.getUsedQuota(username) -} - -// UserExists checks if the given SFTP username exists, returns an error if no match is found -func UserExists(p Provider, username string) (User, error) { - return p.userExists(username) -} - -// AddUser adds a new SFTP user. -// ManageUsers configuration must be set to 1 to enable this method -func AddUser(p Provider, user User) error { - if config.ManageUsers == 0 { - return &MethodDisabledError{err: manageUsersDisabledError} - } - return p.addUser(user) -} - -// UpdateUser updates an existing SFTP user. -// ManageUsers configuration must be set to 1 to enable this method -func UpdateUser(p Provider, user User) error { - if config.ManageUsers == 0 { - return &MethodDisabledError{err: manageUsersDisabledError} - } - return p.updateUser(user) -} - -// DeleteUser deletes an existing SFTP user. 
-// ManageUsers configuration must be set to 1 to enable this method -func DeleteUser(p Provider, user User) error { - if config.ManageUsers == 0 { - return &MethodDisabledError{err: manageUsersDisabledError} - } - return p.deleteUser(user) -} - -// GetUsers returns an array of users respecting limit and offset and filtered by username exact match if not empty -func GetUsers(p Provider, limit int, offset int, order string, username string) ([]User, error) { - return p.getUsers(limit, offset, order, username) -} - -// GetUserByID returns the user with the given database ID if a match is found or an error -func GetUserByID(p Provider, ID int64) (User, error) { - return p.getUserByID(ID) -} - -func validateUser(user *User) error { - if len(user.Username) == 0 || len(user.HomeDir) == 0 { - return &ValidationError{err: "Mandatory parameters missing"} - } - if len(user.Password) == 0 && len(user.PublicKeys) == 0 { - return &ValidationError{err: "Please set password or at least a public_key"} - } - if len(user.Permissions) == 0 { - return &ValidationError{err: "Please grant some permissions to this user"} - } - if !filepath.IsAbs(user.HomeDir) { - return &ValidationError{err: fmt.Sprintf("home_dir must be an absolute path, actual value: %v", user.HomeDir)} - } - for _, p := range user.Permissions { - if !utils.IsStringInSlice(p, validPerms) { - return &ValidationError{err: fmt.Sprintf("Invalid permission: %v", p)} - } - } - if len(user.Password) > 0 && !utils.IsStringPrefixInSlice(user.Password, hashPwdPrefixes) { - pwd, err := argon2id.CreateHash(user.Password, argon2id.DefaultParams) - if err != nil { - return err - } - user.Password = pwd - } - for i, k := range user.PublicKeys { - _, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k)) - if err != nil { - return &ValidationError{err: fmt.Sprintf("Could not parse key nr. 
%d: %s", i, err)} - } - } - return nil -} - -func checkUserAndPass(user User, password string) (User, error) { - var err error - if len(user.Password) == 0 { - return user, errors.New("Credentials cannot be null or empty") - } - var match bool - if strings.HasPrefix(user.Password, argonPwdPrefix) { - match, err = argon2id.ComparePasswordAndHash(password, user.Password) - if err != nil { - logger.Warn(logSender, "error comparing password with argon hash: %v", err) - return user, err - } - } else if strings.HasPrefix(user.Password, bcryptPwdPrefix) { - if err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { - logger.Warn(logSender, "error comparing password with bcrypt hash: %v", err) - return user, err - } - match = true - } else if utils.IsStringPrefixInSlice(user.Password, pbkdfPwdPrefixes) { - match, err = comparePbkdf2PasswordAndHash(password, user.Password) - if err != nil { - logger.Warn(logSender, "error comparing password with pbkdf2 sha256 hash: %v", err) - return user, err - } - } - if !match { - err = errors.New("Invalid credentials") - } - return user, err -} - -func checkUserAndPubKey(user User, pubKey string) (User, error) { - if len(user.PublicKeys) == 0 { - return user, errors.New("Invalid credentials") - } - for i, k := range user.PublicKeys { - storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k)) - if err != nil { - logger.Warn(logSender, "error parsing stored public key %d for user %v: %v", i, user.Username, err) - return user, err - } - if string(storedPubKey.Marshal()) == pubKey { - return user, nil - } - } - return user, errors.New("Invalid credentials") -} - -func comparePbkdf2PasswordAndHash(password, hashedPassword string) (bool, error) { - vals := strings.Split(hashedPassword, "$") - if len(vals) != 5 { - return false, fmt.Errorf("pbkdf2: hash is not in the correct format") - } - var hashFunc func() hash.Hash - var hashSize int - if strings.HasPrefix(hashedPassword, pbkdf2SHA256Prefix) { - 
hashSize = sha256.Size - hashFunc = sha256.New - } else if strings.HasPrefix(hashedPassword, pbkdf2SHA512Prefix) { - hashSize = sha512.Size - hashFunc = sha512.New - } else if strings.HasPrefix(hashedPassword, pbkdf2SHA1Prefix) { - hashSize = sha1.Size - hashFunc = sha1.New - } else { - return false, fmt.Errorf("pbkdf2: invalid or unsupported hash format %v", vals[1]) - } - iterations, err := strconv.Atoi(vals[2]) - if err != nil { - return false, err - } - salt := vals[3] - expected := vals[4] - df := pbkdf2.Key([]byte(password), []byte(salt), iterations, hashSize, hashFunc) - buf := make([]byte, base64.StdEncoding.EncodedLen(len(df))) - base64.StdEncoding.Encode(buf, df) - return subtle.ConstantTimeCompare(buf, []byte(expected)) == 1, nil -} - -func getSSLMode() string { - if config.Driver == PGSSQLDataProviderName { - if config.SSLMode == 0 { - return "disable" - } else if config.SSLMode == 1 { - return "require" - } else if config.SSLMode == 2 { - return "verify-ca" - } else if config.SSLMode == 3 { - return "verify-full" - } - } else if config.Driver == MySQLDataProviderName { - if config.SSLMode == 0 { - return "false" - } else if config.SSLMode == 1 { - return "true" - } else if config.SSLMode == 2 { - return "skip-verify" - } else if config.SSLMode == 3 { - return "preferred" - } - } - return "" -} diff --git a/dataprovider/mysql.go b/dataprovider/mysql.go deleted file mode 100644 index 92ba0992..00000000 --- a/dataprovider/mysql.go +++ /dev/null @@ -1,92 +0,0 @@ -package dataprovider - -import ( - "database/sql" - "fmt" - "runtime" - "time" - - "github.com/drakkan/sftpgo/logger" -) - -// MySQLProvider auth provider for MySQL/MariaDB database -type MySQLProvider struct { - dbHandle *sql.DB -} - -func initializeMySQLProvider() error { - var err error - var connectionString string - if len(config.ConnectionString) == 0 { - connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8&interpolateParams=true&timeout=10s&tls=%v", - config.Username, 
config.Password, config.Host, config.Port, config.Name, getSSLMode()) - } else { - connectionString = config.ConnectionString - } - dbHandle, err := sql.Open("mysql", connectionString) - if err == nil { - numCPU := runtime.NumCPU() - logger.Debug(logSender, "mysql database handle created, connection string: '%v', pool size: %v", connectionString, numCPU) - dbHandle.SetMaxIdleConns(numCPU) - dbHandle.SetMaxOpenConns(numCPU) - dbHandle.SetConnMaxLifetime(1800 * time.Second) - provider = MySQLProvider{dbHandle: dbHandle} - } else { - logger.Warn(logSender, "error creating mysql database handler, connection string: '%v', error: %v", connectionString, err) - } - return err -} - -func (p MySQLProvider) validateUserAndPass(username string, password string) (User, error) { - return sqlCommonValidateUserAndPass(username, password, p.dbHandle) -} - -func (p MySQLProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { - return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle) -} - -func (p MySQLProvider) getUserByID(ID int64) (User, error) { - return sqlCommonGetUserByID(ID, p.dbHandle) -} - -func (p MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { - tx, err := p.dbHandle.Begin() - if err != nil { - logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err) - return err - } - err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) - if err == nil { - err = tx.Commit() - } else { - err = tx.Rollback() - } - if err != nil { - logger.Warn(logSender, "error closing transaction to update quota for user %v: %v", username, err) - } - return err -} - -func (p MySQLProvider) getUsedQuota(username string) (int, int64, error) { - return sqlCommonGetUsedQuota(username, p.dbHandle) -} - -func (p MySQLProvider) userExists(username string) (User, error) { - return sqlCommonCheckUserExists(username, p.dbHandle) -} - -func (p MySQLProvider) 
addUser(user User) error { - return sqlCommonAddUser(user, p.dbHandle) -} - -func (p MySQLProvider) updateUser(user User) error { - return sqlCommonUpdateUser(user, p.dbHandle) -} - -func (p MySQLProvider) deleteUser(user User) error { - return sqlCommonDeleteUser(user, p.dbHandle) -} - -func (p MySQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { - return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle) -} diff --git a/dataprovider/pgsql.go b/dataprovider/pgsql.go deleted file mode 100644 index 20a01125..00000000 --- a/dataprovider/pgsql.go +++ /dev/null @@ -1,90 +0,0 @@ -package dataprovider - -import ( - "database/sql" - "fmt" - "runtime" - - "github.com/drakkan/sftpgo/logger" -) - -// PGSQLProvider auth provider for PostgreSQL database -type PGSQLProvider struct { - dbHandle *sql.DB -} - -func initializePGSQLProvider() error { - var err error - var connectionString string - if len(config.ConnectionString) == 0 { - connectionString = fmt.Sprintf("host='%v' port=%v dbname='%v' user='%v' password='%v' sslmode=%v connect_timeout=10", - config.Host, config.Port, config.Name, config.Username, config.Password, getSSLMode()) - } else { - connectionString = config.ConnectionString - } - dbHandle, err := sql.Open("postgres", connectionString) - if err == nil { - numCPU := runtime.NumCPU() - logger.Debug(logSender, "postgres database handle created, connection string: '%v', pool size: %v", connectionString, numCPU) - dbHandle.SetMaxIdleConns(numCPU) - dbHandle.SetMaxOpenConns(numCPU) - provider = PGSQLProvider{dbHandle: dbHandle} - } else { - logger.Warn(logSender, "error creating postgres database handler, connection string: '%v', error: %v", connectionString, err) - } - return err -} - -func (p PGSQLProvider) validateUserAndPass(username string, password string) (User, error) { - return sqlCommonValidateUserAndPass(username, password, p.dbHandle) -} - -func (p PGSQLProvider) validateUserAndPubKey(username string, 
publicKey string) (User, error) { - return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle) -} - -func (p PGSQLProvider) getUserByID(ID int64) (User, error) { - return sqlCommonGetUserByID(ID, p.dbHandle) -} - -func (p PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { - tx, err := p.dbHandle.Begin() - if err != nil { - logger.Warn(logSender, "error starting transaction to update quota for user %v: %v", username, err) - return err - } - err = sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) - if err == nil { - err = tx.Commit() - } else { - err = tx.Rollback() - } - if err != nil { - logger.Warn(logSender, "error closing transaction to update quota for user %v: %v", username, err) - } - return err -} - -func (p PGSQLProvider) getUsedQuota(username string) (int, int64, error) { - return sqlCommonGetUsedQuota(username, p.dbHandle) -} - -func (p PGSQLProvider) userExists(username string) (User, error) { - return sqlCommonCheckUserExists(username, p.dbHandle) -} - -func (p PGSQLProvider) addUser(user User) error { - return sqlCommonAddUser(user, p.dbHandle) -} - -func (p PGSQLProvider) updateUser(user User) error { - return sqlCommonUpdateUser(user, p.dbHandle) -} - -func (p PGSQLProvider) deleteUser(user User) error { - return sqlCommonDeleteUser(user, p.dbHandle) -} - -func (p PGSQLProvider) getUsers(limit int, offset int, order string, username string) ([]User, error) { - return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle) -} diff --git a/dataprovider/sqlcommon.go b/dataprovider/sqlcommon.go deleted file mode 100644 index 51f7e19b..00000000 --- a/dataprovider/sqlcommon.go +++ /dev/null @@ -1,252 +0,0 @@ -package dataprovider - -import ( - "database/sql" - "encoding/json" - "errors" - "time" - - "github.com/drakkan/sftpgo/logger" - "github.com/drakkan/sftpgo/utils" -) - -func getUserByUsername(username string, dbHandle *sql.DB) (User, error) { - var user User - q := 
getUserByUsernameQuery() - stmt, err := dbHandle.Prepare(q) - if err != nil { - logger.Debug(logSender, "error preparing database query %v: %v", q, err) - return user, err - } - defer stmt.Close() - - row := stmt.QueryRow(username) - return getUserFromDbRow(row, nil) -} - -func sqlCommonValidateUserAndPass(username string, password string, dbHandle *sql.DB) (User, error) { - var user User - if len(password) == 0 { - return user, errors.New("Credentials cannot be null or empty") - } - user, err := getUserByUsername(username, dbHandle) - if err != nil { - logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) - return user, err - } - return checkUserAndPass(user, password) -} - -func sqlCommonValidateUserAndPubKey(username string, pubKey string, dbHandle *sql.DB) (User, error) { - var user User - if len(pubKey) == 0 { - return user, errors.New("Credentials cannot be null or empty") - } - user, err := getUserByUsername(username, dbHandle) - if err != nil { - logger.Warn(logSender, "error authenticating user: %v, error: %v", username, err) - return user, err - } - return checkUserAndPubKey(user, pubKey) -} - -func sqlCommonGetUserByID(ID int64, dbHandle *sql.DB) (User, error) { - var user User - q := getUserByIDQuery() - stmt, err := dbHandle.Prepare(q) - if err != nil { - logger.Debug(logSender, "error preparing database query %v: %v", q, err) - return user, err - } - defer stmt.Close() - - row := stmt.QueryRow(ID) - return getUserFromDbRow(row, nil) -} - -func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error { - q := getUpdateQuotaQuery(reset) - stmt, err := dbHandle.Prepare(q) - if err != nil { - logger.Debug(logSender, "error preparing database query %v: %v", q, err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(sizeAdd, filesAdd, utils.GetTimeAsMsSinceEpoch(time.Now()), username) - if err == nil { - logger.Debug(logSender, "quota updated for user %v, files increment: %v 
size increment: %v is reset? %v", - username, filesAdd, sizeAdd, reset) - } else { - logger.Warn(logSender, "error updating quota for username %v: %v", username, err) - } - return err -} - -func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, error) { - q := getQuotaQuery() - stmt, err := dbHandle.Prepare(q) - if err != nil { - logger.Warn(logSender, "error preparing database query %v: %v", q, err) - return 0, 0, err - } - defer stmt.Close() - - var usedFiles int - var usedSize int64 - err = stmt.QueryRow(username).Scan(&usedSize, &usedFiles) - if err != nil { - logger.Warn(logSender, "error getting user quota: %v, error: %v", username, err) - return 0, 0, err - } - return usedFiles, usedSize, err -} - -func sqlCommonCheckUserExists(username string, dbHandle *sql.DB) (User, error) { - var user User - q := getUserByUsernameQuery() - stmt, err := dbHandle.Prepare(q) - if err != nil { - logger.Warn(logSender, "error preparing database query %v: %v", q, err) - return user, err - } - defer stmt.Close() - row := stmt.QueryRow(username) - return getUserFromDbRow(row, nil) -} - -func sqlCommonAddUser(user User, dbHandle *sql.DB) error { - err := validateUser(&user) - if err != nil { - return err - } - q := getAddUserQuery() - stmt, err := dbHandle.Prepare(q) - if err != nil { - logger.Warn(logSender, "error preparing database query %v: %v", q, err) - return err - } - defer stmt.Close() - permissions, err := user.GetPermissionsAsJSON() - if err != nil { - return err - } - publicKeys, err := user.GetPublicKeysAsJSON() - if err != nil { - return err - } - _, err = stmt.Exec(user.Username, user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, - user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth) - return err -} - -func sqlCommonUpdateUser(user User, dbHandle *sql.DB) error { - err := validateUser(&user) - if err != nil { - return err - } - q := getUpdateUserQuery() - stmt, 
err := dbHandle.Prepare(q) - if err != nil { - logger.Warn(logSender, "error preparing database query %v: %v", q, err) - return err - } - defer stmt.Close() - permissions, err := user.GetPermissionsAsJSON() - if err != nil { - return err - } - publicKeys, err := user.GetPublicKeysAsJSON() - if err != nil { - return err - } - _, err = stmt.Exec(user.Password, string(publicKeys), user.HomeDir, user.UID, user.GID, user.MaxSessions, user.QuotaSize, - user.QuotaFiles, string(permissions), user.UploadBandwidth, user.DownloadBandwidth, user.ID) - return err -} - -func sqlCommonDeleteUser(user User, dbHandle *sql.DB) error { - q := getDeleteUserQuery() - stmt, err := dbHandle.Prepare(q) - if err != nil { - logger.Warn(logSender, "error preparing database query %v: %v", q, err) - return err - } - defer stmt.Close() - _, err = stmt.Exec(user.ID) - return err -} - -func sqlCommonGetUsers(limit int, offset int, order string, username string, dbHandle *sql.DB) ([]User, error) { - users := []User{} - q := getUsersQuery(order, username) - stmt, err := dbHandle.Prepare(q) - if err != nil { - logger.Warn(logSender, "error preparing database query %v: %v", q, err) - return nil, err - } - defer stmt.Close() - var rows *sql.Rows - if len(username) > 0 { - rows, err = stmt.Query(username, limit, offset) - } else { - rows, err = stmt.Query(limit, offset) - } - if err == nil { - defer rows.Close() - for rows.Next() { - u, err := getUserFromDbRow(nil, rows) - // hide password and public key - if err == nil { - u.Password = "" - u.PublicKeys = []string{} - users = append(users, u) - } else { - break - } - } - } - - return users, err -} - -func getUserFromDbRow(row *sql.Row, rows *sql.Rows) (User, error) { - var user User - var permissions sql.NullString - var password sql.NullString - var publicKey sql.NullString - var err error - if row != nil { - err = row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions, - &user.QuotaSize, 
&user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate, - &user.UploadBandwidth, &user.DownloadBandwidth) - - } else { - err = rows.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions, - &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate, - &user.UploadBandwidth, &user.DownloadBandwidth) - } - if err != nil { - if err == sql.ErrNoRows { - return user, &RecordNotFoundError{err: err.Error()} - } - return user, err - } - if password.Valid { - user.Password = password.String - } - if publicKey.Valid { - var list []string - err = json.Unmarshal([]byte(publicKey.String), &list) - if err == nil { - user.PublicKeys = list - } - } - if permissions.Valid { - var list []string - err = json.Unmarshal([]byte(permissions.String), &list) - if err == nil { - user.Permissions = list - } - } - return user, err -} diff --git a/dataprovider/sqlite.go b/dataprovider/sqlite.go deleted file mode 100644 index b032b59d..00000000 --- a/dataprovider/sqlite.go +++ /dev/null @@ -1,91 +0,0 @@ -package dataprovider - -import ( - "database/sql" - "errors" - "fmt" - "os" - "path/filepath" - - "github.com/drakkan/sftpgo/logger" -) - -// SQLiteProvider auth provider for SQLite database -type SQLiteProvider struct { - dbHandle *sql.DB -} - -func initializeSQLiteProvider(basePath string) error { - var err error - var connectionString string - if len(config.ConnectionString) == 0 { - dbPath := config.Name - if !filepath.IsAbs(dbPath) { - dbPath = filepath.Join(basePath, dbPath) - } - fi, err := os.Stat(dbPath) - if err != nil { - logger.Warn(logSender, "sqlite database file does not exists, please be sure to create and initialize"+ - " a database before starting sftpgo") - return err - } - if fi.Size() == 0 { - return errors.New("sqlite database file is invalid, please be sure to create and initialize" + - " a database before starting sftpgo") 
- } - connectionString = fmt.Sprintf("file:%v?cache=shared", dbPath) - } else { - connectionString = config.ConnectionString - } - dbHandle, err := sql.Open("sqlite3", connectionString) - if err == nil { - logger.Debug(logSender, "sqlite database handle created, connection string: '%v'", connectionString) - dbHandle.SetMaxOpenConns(1) - provider = SQLiteProvider{dbHandle: dbHandle} - } else { - logger.Warn(logSender, "error creating sqlite database handler, connection string: '%v', error: %v", connectionString, err) - } - return err -} - -func (p SQLiteProvider) validateUserAndPass(username string, password string) (User, error) { - return sqlCommonValidateUserAndPass(username, password, p.dbHandle) -} - -func (p SQLiteProvider) validateUserAndPubKey(username string, publicKey string) (User, error) { - return sqlCommonValidateUserAndPubKey(username, publicKey, p.dbHandle) -} - -func (p SQLiteProvider) getUserByID(ID int64) (User, error) { - return sqlCommonGetUserByID(ID, p.dbHandle) -} - -func (p SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { - // we keep only 1 open connection (SetMaxOpenConns(1)) so a transaction is not needed and it could block - // the database access since it will try to open a new connection - return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) -} - -func (p SQLiteProvider) getUsedQuota(username string) (int, int64, error) { - return sqlCommonGetUsedQuota(username, p.dbHandle) -} - -func (p SQLiteProvider) userExists(username string) (User, error) { - return sqlCommonCheckUserExists(username, p.dbHandle) -} - -func (p SQLiteProvider) addUser(user User) error { - return sqlCommonAddUser(user, p.dbHandle) -} - -func (p SQLiteProvider) updateUser(user User) error { - return sqlCommonUpdateUser(user, p.dbHandle) -} - -func (p SQLiteProvider) deleteUser(user User) error { - return sqlCommonDeleteUser(user, p.dbHandle) -} - -func (p SQLiteProvider) getUsers(limit int, offset 
int, order string, username string) ([]User, error) { - return sqlCommonGetUsers(limit, offset, order, username, p.dbHandle) -} diff --git a/dataprovider/sqlqueries.go b/dataprovider/sqlqueries.go deleted file mode 100644 index aceaf018..00000000 --- a/dataprovider/sqlqueries.go +++ /dev/null @@ -1,70 +0,0 @@ -package dataprovider - -import "fmt" - -const ( - selectUserFields = "id,username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions," + - "used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth" -) - -func getSQLPlaceholders() []string { - var placeholders []string - for i := 1; i <= 20; i++ { - if config.Driver == PGSSQLDataProviderName { - placeholders = append(placeholders, fmt.Sprintf("$%v", i)) - } else { - placeholders = append(placeholders, "?") - } - } - return placeholders -} - -func getUserByUsernameQuery() string { - return fmt.Sprintf(`SELECT %v FROM %v WHERE username = %v`, selectUserFields, config.UsersTable, sqlPlaceholders[0]) -} - -func getUserByIDQuery() string { - return fmt.Sprintf(`SELECT %v FROM %v WHERE id = %v`, selectUserFields, config.UsersTable, sqlPlaceholders[0]) -} - -func getUsersQuery(order string, username string) string { - if len(username) > 0 { - return fmt.Sprintf(`SELECT %v FROM %v WHERE username = %v ORDER BY username %v LIMIT %v OFFSET %v`, - selectUserFields, config.UsersTable, sqlPlaceholders[0], order, sqlPlaceholders[1], sqlPlaceholders[2]) - } - return fmt.Sprintf(`SELECT %v FROM %v ORDER BY username %v LIMIT %v OFFSET %v`, selectUserFields, config.UsersTable, - order, sqlPlaceholders[0], sqlPlaceholders[1]) -} - -func getUpdateQuotaQuery(reset bool) string { - if reset { - return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_update = %v - WHERE username = %v`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) - } - return fmt.Sprintf(`UPDATE %v SET used_quota_size 
= used_quota_size + %v,used_quota_files = used_quota_files + %v,last_quota_update = %v - WHERE username = %v`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) -} - -func getQuotaQuery() string { - return fmt.Sprintf(`SELECT used_quota_size,used_quota_files FROM %v WHERE username = %v`, config.UsersTable, - sqlPlaceholders[0]) -} - -func getAddUserQuery() string { - return fmt.Sprintf(`INSERT INTO %v (username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions, - used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth) - VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v)`, config.UsersTable, sqlPlaceholders[0], sqlPlaceholders[1], - sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], - sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11]) -} - -func getUpdateUserQuery() string { - return fmt.Sprintf(`UPDATE %v SET password=%v,public_keys=%v,home_dir=%v,uid=%v,gid=%v,max_sessions=%v,quota_size=%v, - quota_files=%v,permissions=%v,upload_bandwidth=%v,download_bandwidth=%v WHERE id = %v`, config.UsersTable, - sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], - sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11]) -} - -func getDeleteUserQuery() string { - return fmt.Sprintf(`DELETE FROM %v WHERE id = %v`, config.UsersTable, sqlPlaceholders[0]) -} diff --git a/dataprovider/user.go b/dataprovider/user.go deleted file mode 100644 index 4722abe2..00000000 --- a/dataprovider/user.go +++ /dev/null @@ -1,122 +0,0 @@ -package dataprovider - -import ( - "encoding/json" - "path/filepath" - - "github.com/drakkan/sftpgo/utils" -) - -// Available permissions for SFTP users -const ( - // All permissions are granted - PermAny = "*" - // 
List items such as files and directories is allowed - PermListItems = "list" - // download files is allowed - PermDownload = "download" - // upload files is allowed - PermUpload = "upload" - // delete files or directories is allowed - PermDelete = "delete" - // rename files or directories is allowed - PermRename = "rename" - // create directories is allowed - PermCreateDirs = "create_dirs" - // create symbolic links is allowed - PermCreateSymlinks = "create_symlinks" -) - -// User defines an SFTP user -type User struct { - // Database unique identifier - ID int64 `json:"id"` - // Username - Username string `json:"username"` - // Password used for password authentication. - // For users created using SFTPGo REST API the password is be stored using argon2id hashing algo. - // Checking passwords stored with bcrypt is supported too. - // Currently, as fallback, there is a clear text password checking but you should not store passwords - // as clear text and this support could be removed at any time, so please don't depend on it. - Password string `json:"password,omitempty"` - // PublicKeys used for public key authentication. At least one between password and a public key is mandatory - PublicKeys []string `json:"public_keys,omitempty"` - // The user cannot upload or download files outside this directory. Must be an absolute path - HomeDir string `json:"home_dir"` - // If sftpgo runs as root system user then the created files and directories will be assigned to this system UID - UID int `json:"uid"` - // If sftpgo runs as root system user then the created files and directories will be assigned to this system GID - GID int `json:"gid"` - // Maximum concurrent sessions. 0 means unlimited - MaxSessions int `json:"max_sessions"` - // Maximum size allowed as bytes. 0 means unlimited - QuotaSize int64 `json:"quota_size"` - // Maximum number of files allowed. 
0 means unlimited - QuotaFiles int `json:"quota_files"` - // List of the granted permissions - Permissions []string `json:"permissions"` - // Used quota as bytes - UsedQuotaSize int64 `json:"used_quota_size"` - // Used quota as number of files - UsedQuotaFiles int `json:"used_quota_files"` - // Last quota update as unix timestamp in milliseconds - LastQuotaUpdate int64 `json:"last_quota_update"` - // Maximum upload bandwidth as KB/s, 0 means unlimited - UploadBandwidth int64 `json:"upload_bandwidth"` - // Maximum download bandwidth as KB/s, 0 means unlimited - DownloadBandwidth int64 `json:"download_bandwidth"` -} - -// HasPerm returns true if the user has the given permission or any permission -func (u *User) HasPerm(permission string) bool { - if utils.IsStringInSlice(PermAny, u.Permissions) { - return true - } - return utils.IsStringInSlice(permission, u.Permissions) -} - -// GetPermissionsAsJSON returns the permissions as json byte array -func (u *User) GetPermissionsAsJSON() ([]byte, error) { - return json.Marshal(u.Permissions) -} - -// GetPublicKeysAsJSON returns the public keys as json byte array -func (u *User) GetPublicKeysAsJSON() ([]byte, error) { - return json.Marshal(u.PublicKeys) -} - -// GetUID returns a validate uid, suitable for use with os.Chown -func (u *User) GetUID() int { - if u.UID <= 0 || u.UID > 65535 { - return -1 - } - return u.UID -} - -// GetGID returns a validate gid, suitable for use with os.Chown -func (u *User) GetGID() int { - if u.GID <= 0 || u.GID > 65535 { - return -1 - } - return u.GID -} - -// GetHomeDir returns the shortest path name equivalent to the user's home directory -func (u *User) GetHomeDir() string { - return filepath.Clean(u.HomeDir) -} - -// HasQuotaRestrictions returns true if there is a quota restriction on number of files or size or both -func (u *User) HasQuotaRestrictions() bool { - return u.QuotaFiles > 0 || u.QuotaSize > 0 -} - -// GetRelativePath returns the path for a file relative to the user's home 
dir. -// This is the path as seen by SFTP users -func (u *User) GetRelativePath(path string) string { - rel, err := filepath.Rel(u.GetHomeDir(), path) - if err != nil { - return "" - } - return "/" + filepath.ToSlash(rel) -} diff --git a/docker/scripts/download-plugins.sh b/docker/scripts/download-plugins.sh new file mode 100755 index 00000000..7638cd56 --- /dev/null +++ b/docker/scripts/download-plugins.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +set -euo pipefail + +ARCH=$(uname -m) + +case ${ARCH} in + x86_64) + SUFFIX=amd64 + ;; + aarch64) + SUFFIX=arm64 + ;; + *) + SUFFIX=ppc64le + ;; +esac + +echo "Downloading plugins for arch ${SUFFIX}" + +PLUGINS=(geoipfilter kms pubsub eventstore eventsearch auth) + +for PLUGIN in "${PLUGINS[@]}"; do + URL="https://github.com/sftpgo/sftpgo-plugin-${PLUGIN}/releases/latest/download/sftpgo-plugin-${PLUGIN}-linux-${SUFFIX}" + DEST="/usr/local/bin/sftpgo-plugin-${PLUGIN}" + + echo "Downloading ${PLUGIN}..." + if curl --fail --silent --show-error -L "${URL}" --output "${DEST}"; then + chmod 755 "${DEST}" + else + echo "Error: Failed to download ${PLUGIN}" >&2 + exit 1 + fi +done + +echo "All plugins downloaded successfully" diff --git a/examples/OTP/authy/README.md b/examples/OTP/authy/README.md new file mode 100644 index 00000000..af582956 --- /dev/null +++ b/examples/OTP/authy/README.md @@ -0,0 +1,60 @@ +# Authy + +These example show how-to integrate [Twilio Authy API](https://www.twilio.com/docs/authy/api) for One-Time-Password logins. + +The examples assume that the user has the free [Authy app](https://authy.com/) installed and uses it to generate offline [TOTP](https://en.wikipedia.org/wiki/Time-based_One-time_Password_algorithm) codes (soft tokens). + +You first need to [create an Authy Application in the Twilio Console](https://twilio.com/console/authy/applications?_ga=2.205553366.451688189.1597667213-1526360003.1597667213), then you can create a new Authy user and store a reference to the matching SFTPGo account. 
+ +Verify that your Authy application is successfully registered: + +```bash +export AUTHY_API_KEY= +curl 'https://api.authy.com/protected/json/app/details' -H "X-Authy-API-Key: $AUTHY_API_KEY" +``` + +now create an Authy user: + +```bash +curl -XPOST "https://api.authy.com/protected/json/users/new" \ +-H "X-Authy-API-Key: $AUTHY_API_KEY" \ +--data-urlencode user[email]="user@domain.com" \ +--data-urlencode user[cellphone]="317-338-9302" \ +--data-urlencode user[country_code]="54" +``` + +The response is something like this: + +```json +{"message":"User created successfully.","user":{"id":xxxxxxxx},"success":true} +``` + +Save the user id somewhere and add a reference to the matching SFTPGo account. You could also store this ID in the `additional_info` SFTPGo user field. + +After this step you can use the Authy app installed on your phone to generate TOTP codes. + +Now you can verify the token using an HTTP GET request: + +```bash +export TOKEN= +export AUTHY_ID= +curl -i "https://api.authy.com/protected/json/verify/${TOKEN}/${AUTHY_ID}" \ + -H "X-Authy-API-Key: $AUTHY_API_KEY" +``` + +So inside your hook you need to check: + +- the HTTP response code for the verify request, it must be `200` +- the JSON response body, it must contains the key `success` with the value `true` (as string) + +If these conditions are met the token is valid and you allow the user to login. + +We provide the following examples: + +- [Keyboard interactive authentication](./keyint/README.md) for 2FA using password + Authy one time token. +- [External authentication](./extauth/README.md) using Authy one time tokens as passwords. +- [Check password hook](./checkpwd/README.md) for 2FA using a password consisting of a fixed string and a One Time Token. + +Please note that these are sample programs not intended for production use, you should write your own hook based on them and you should prefer HTTP based hooks if performance is a concern. + +:warning: SFTPGo has also built-in 2FA support. 
diff --git a/examples/OTP/authy/checkpwd/README.md b/examples/OTP/authy/checkpwd/README.md new file mode 100644 index 00000000..ddf5f5c2 --- /dev/null +++ b/examples/OTP/authy/checkpwd/README.md @@ -0,0 +1,3 @@ +# Authy 2FA via check password hook + +This example shows how to use 2FA via the check password hook using a password consisting of a fixed part and an Authy TOTP token. The hook will check the TOTP token using the Authy API and SFTPGo will check the fixed part. Please read the [sample code](./main.go), it should be self explanatory. diff --git a/examples/OTP/authy/checkpwd/go.mod b/examples/OTP/authy/checkpwd/go.mod new file mode 100644 index 00000000..24dcba55 --- /dev/null +++ b/examples/OTP/authy/checkpwd/go.mod @@ -0,0 +1,3 @@ +module github.com/drakkan/sftpgo/authy/checkpwd + +go 1.22.2 diff --git a/examples/OTP/authy/checkpwd/main.go b/examples/OTP/authy/checkpwd/main.go new file mode 100644 index 00000000..83669026 --- /dev/null +++ b/examples/OTP/authy/checkpwd/main.go @@ -0,0 +1,106 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "time" +) + +type userMapping struct { + SFTPGoUsername string + AuthyID int64 + AuthyAPIKey string +} + +type checkPasswordResponse struct { + // 0 KO, 1 OK, 2 partial success + Status int `json:"status"` + // for status == 2 this is the password that SFTPGo will check against the one stored + // inside the data provider + ToVerify string `json:"to_verify"` +} + +var ( + mapping []userMapping +) + +func init() { + // this is for demo only, you probably want to get this mapping dynamically, for example using a database query + mapping = append(mapping, userMapping{ + SFTPGoUsername: "", + AuthyID: 1234567, + AuthyAPIKey: "", + }) +} + +func printResponse(status int, toVerify string) { + r := checkPasswordResponse{ + Status: status, + ToVerify: toVerify, + } + resp, _ := json.Marshal(r) + fmt.Printf("%v\n", string(resp)) + if status > 0 { + os.Exit(0) + } else { + os.Exit(1) + } 
+} + +func main() { + // get credentials from env vars + username := os.Getenv("SFTPGO_AUTHD_USERNAME") + password := os.Getenv("SFTPGO_AUTHD_PASSWORD") + + for _, m := range mapping { + if m.SFTPGoUsername == username { + // Authy token len is 7, we assume that we have the password followed by the token + pwdLen := len(password) + if pwdLen <= 7 { + printResponse(0, "") + } + pwd := password[:pwdLen-7] + authyToken := password[pwdLen-7:] + // now verify the authy token and instruct SFTPGo to check the password if the token is OK + url := fmt.Sprintf("https://api.authy.com/protected/json/verify/%v/%v", authyToken, m.AuthyID) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + log.Fatal(err) + } + req.Header.Set("X-Authy-API-Key", m.AuthyAPIKey) + httpClient := &http.Client{ + Timeout: 10 * time.Second, + } + resp, err := httpClient.Do(req) + if err != nil { + printResponse(0, "") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + // status code 200 is expected + printResponse(0, "") + } + var authyResponse map[string]interface{} + respBody, err := io.ReadAll(resp.Body) + if err != nil { + printResponse(0, "") + } + err = json.Unmarshal(respBody, &authyResponse) + if err != nil { + printResponse(0, "") + } + if authyResponse["success"].(string) == "true" { + printResponse(2, pwd) + } + printResponse(0, "") + break + } + } + + // no mapping found + printResponse(0, "") +} diff --git a/examples/OTP/authy/extauth/README.md b/examples/OTP/authy/extauth/README.md new file mode 100644 index 00000000..d6f6683e --- /dev/null +++ b/examples/OTP/authy/extauth/README.md @@ -0,0 +1,3 @@ +# Authy external authentication + +This example shows how to use Authy TOTP token as password for SFTPGo users. Please read the [sample code](./main.go), it should be self explanatory. 
diff --git a/examples/OTP/authy/extauth/go.mod b/examples/OTP/authy/extauth/go.mod new file mode 100644 index 00000000..7b4dbbc2 --- /dev/null +++ b/examples/OTP/authy/extauth/go.mod @@ -0,0 +1,3 @@ +module github.com/drakkan/sftpgo/authy/extauth + +go 1.22.2 diff --git a/examples/OTP/authy/extauth/main.go b/examples/OTP/authy/extauth/main.go new file mode 100644 index 00000000..60b8704c --- /dev/null +++ b/examples/OTP/authy/extauth/main.go @@ -0,0 +1,109 @@ +package main + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "path/filepath" + "time" +) + +type userMapping struct { + SFTPGoUsername string + AuthyID int64 + AuthyAPIKey string +} + +// we assume that the SFTPGo already exists, we only check the one time token. +// If you need to create the SFTPGo user more fields are needed here +type minimalSFTPGoUser struct { + Status int `json:"status,omitempty"` + Username string `json:"username"` + HomeDir string `json:"home_dir,omitempty"` + Permissions map[string][]string `json:"permissions"` +} + +var ( + mapping []userMapping +) + +func init() { + // this is for demo only, you probably want to get this mapping dynamically, for example using a database query + mapping = append(mapping, userMapping{ + SFTPGoUsername: "", + AuthyID: 1234567, + AuthyAPIKey: "", + }) +} + +func printResponse(username string) { + u := minimalSFTPGoUser{ + Username: username, + Status: 1, + HomeDir: filepath.Join(os.TempDir(), username), + } + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{"*"} + resp, _ := json.Marshal(u) + fmt.Printf("%v\n", string(resp)) + if len(username) > 0 { + os.Exit(0) + } else { + os.Exit(1) + } +} + +func main() { + // get credentials from env vars + username := os.Getenv("SFTPGO_AUTHD_USERNAME") + password := os.Getenv("SFTPGO_AUTHD_PASSWORD") + if len(password) == 0 { + // login method is not password + printResponse("") + return + } + + for _, m := range mapping { + if m.SFTPGoUsername == username { + // 
mapping found we can now verify the token + url := fmt.Sprintf("https://api.authy.com/protected/json/verify/%v/%v", password, m.AuthyID) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + log.Fatal(err) + } + req.Header.Set("X-Authy-API-Key", m.AuthyAPIKey) + httpClient := &http.Client{ + Timeout: 10 * time.Second, + } + resp, err := httpClient.Do(req) + if err != nil { + printResponse("") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + // status code 200 is expected + printResponse("") + } + var authyResponse map[string]interface{} + respBody, err := io.ReadAll(resp.Body) + if err != nil { + printResponse("") + } + err = json.Unmarshal(respBody, &authyResponse) + if err != nil { + printResponse("") + } + if authyResponse["success"].(string) == "true" { + printResponse(username) + } + printResponse("") + break + } + } + + // no mapping found + printResponse("") +} diff --git a/examples/OTP/authy/keyint/README.md b/examples/OTP/authy/keyint/README.md new file mode 100644 index 00000000..f240bad9 --- /dev/null +++ b/examples/OTP/authy/keyint/README.md @@ -0,0 +1,3 @@ +# Authy 2FA using keyboard interactive authentication + +This example shows how to authenticate SFTP users using 2FA (password + Authy token). Please read the [sample code](./main.go), it should be self explanatory. 
diff --git a/examples/OTP/authy/keyint/go.mod b/examples/OTP/authy/keyint/go.mod new file mode 100644 index 00000000..583d606a --- /dev/null +++ b/examples/OTP/authy/keyint/go.mod @@ -0,0 +1,3 @@ +module github.com/drakkan/sftpgo/authy/keyint + +go 1.22.2 diff --git a/examples/OTP/authy/keyint/main.go b/examples/OTP/authy/keyint/main.go new file mode 100644 index 00000000..ab04c913 --- /dev/null +++ b/examples/OTP/authy/keyint/main.go @@ -0,0 +1,137 @@ +package main + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" +) + +type userMapping struct { + SFTPGoUsername string + AuthyID int64 + AuthyAPIKey string +} + +type keyboardAuthHookResponse struct { + Instruction string `json:"instruction,omitempty"` + Questions []string `json:"questions,omitempty"` + Echos []bool `json:"echos,omitempty"` + AuthResult int `json:"auth_result"` + CheckPwd int `json:"check_password,omitempty"` +} + +var ( + mapping []userMapping +) + +func init() { + // this is for demo only, you probably want to get this mapping dynamically, for example using a database query + mapping = append(mapping, userMapping{ + SFTPGoUsername: "", + AuthyID: 1234567, + AuthyAPIKey: "", + }) +} + +func printAuthResponse(result int) { + resp, _ := json.Marshal(keyboardAuthHookResponse{ + AuthResult: result, + }) + fmt.Printf("%v\n", string(resp)) + if result == 1 { + os.Exit(0) + } else { + os.Exit(1) + } +} + +func main() { + // get credentials from env vars + username := os.Getenv("SFTPGO_AUTHD_USERNAME") + var userMap userMapping + for _, m := range mapping { + if m.SFTPGoUsername == username { + userMap = m + break + } + } + + if userMap.SFTPGoUsername != username { + // no mapping found + os.Exit(1) + } + + checkPwdQuestion := keyboardAuthHookResponse{ + Instruction: "This is a sample keyboard authentication program that ask for your password + Authy token", + Questions: []string{"Your password: "}, + Echos: []bool{false}, + CheckPwd: 1, + AuthResult: 0, + } + + q, _ := 
json.Marshal(checkPwdQuestion) + fmt.Printf("%v\n", string(q)) + + // in a real world app you probably want to use a read timeout + scanner := bufio.NewScanner(os.Stdin) + scanner.Scan() + if scanner.Err() != nil { + printAuthResponse(-1) + } + response := scanner.Text() + if response != "OK" { + printAuthResponse(-1) + } + + checkTokenQuestion := keyboardAuthHookResponse{ + Instruction: "", + Questions: []string{"Authy token: "}, + Echos: []bool{false}, + CheckPwd: 0, + AuthResult: 0, + } + + q, _ = json.Marshal(checkTokenQuestion) + fmt.Printf("%v\n", string(q)) + scanner.Scan() + if scanner.Err() != nil { + printAuthResponse(-1) + } + authyToken := scanner.Text() + + url := fmt.Sprintf("https://api.authy.com/protected/json/verify/%v/%v", authyToken, userMap.AuthyID) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + printAuthResponse(-1) + } + req.Header.Set("X-Authy-API-Key", userMap.AuthyAPIKey) + httpClient := &http.Client{ + Timeout: 10 * time.Second, + } + resp, err := httpClient.Do(req) + if err != nil { + printAuthResponse(-1) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + // status code 200 is expected + printAuthResponse(-1) + } + var authyResponse map[string]interface{} + respBody, err := io.ReadAll(resp.Body) + if err != nil { + printAuthResponse(-1) + } + err = json.Unmarshal(respBody, &authyResponse) + if err != nil { + printAuthResponse(-1) + } + if authyResponse["success"].(string) == "true" { + printAuthResponse(1) + } + printAuthResponse(-1) +} diff --git a/examples/backup/README.md b/examples/backup/README.md new file mode 100644 index 00000000..3190390c --- /dev/null +++ b/examples/backup/README.md @@ -0,0 +1,19 @@ +# Data Backup + +:warning: Since v2.4.0 you can use the [EventManager](https://docs.sftpgo.com/latest/eventmanager/) to schedule backups. + +The `backup` example script shows how to use the SFTPGo REST API to backup your data. 
+ +The script is written in Python and has the following requirements: + +- python3 or python2 +- python [Requests](https://requests.readthedocs.io/en/master/) module + +The provided example tries to connect to an SFTPGo instance running on `127.0.0.1:8080` using the following credentials: + +- username: `admin` +- password: `password` + +and, if you execute it daily, it saves a different backup file for each day of the week. The backups will be saved within the configured `backups_path`. + +Please edit the script according to your needs. diff --git a/examples/backup/backup b/examples/backup/backup new file mode 100755 index 00000000..4bdea55d --- /dev/null +++ b/examples/backup/backup @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +from datetime import datetime +import sys + +import requests + +try: + import urllib.parse as urlparse +except ImportError: + import urlparse + +# change base_url to point to your SFTPGo installation +base_url = "http://127.0.0.1:8080" +# set to False if you want to skip TLS certificate validation +verify_tls_cert = True +# set the credentials for a valid admin here +admin_user = "admin" +admin_password = "password" + +# get a JWT token +auth = requests.auth.HTTPBasicAuth(admin_user, admin_password) +r = requests.get(urlparse.urljoin(base_url, "api/v2/token"), auth=auth, verify=verify_tls_cert, timeout=10) +if r.status_code != 200: + print("error getting access token: {}".format(r.text)) + sys.exit(1) +access_token = r.json()["access_token"] +auth_header = {"Authorization": "Bearer " + access_token} + +r = requests.get(urlparse.urljoin(base_url, "api/v2/dumpdata"), + params={"output-file":"backup_{}.json".format(datetime.today().strftime('%w'))}, + headers=auth_header, verify=verify_tls_cert, timeout=10) +if r.status_code == 200: + print("backup OK") +else: + print("backup error, status {}, response: {}".format(r.status_code, r.text)) diff --git a/examples/bulkupdate/README.md b/examples/bulkupdate/README.md new file mode 100644 index 
00000000..5fffb81a --- /dev/null +++ b/examples/bulkupdate/README.md @@ -0,0 +1,17 @@ +# Bulk user update + +The `bulkuserupdate` example script shows how to use the SFTPGo REST API to easily update some common parameters for multiple users while preserving the others. + +The script is written in Python and has the following requirements: + +- python3 or python2 +- python [Requests](https://requests.readthedocs.io/en/master/) module + +The provided example tries to connect to an SFTPGo instance running on `127.0.0.1:8080` using the following credentials: + +- username: `admin` +- password: `password` + +and it updates some fields for `user1`, `user2` and `user3`. + +Please edit the script according to your needs. diff --git a/examples/bulkupdate/bulkuserupdate b/examples/bulkupdate/bulkuserupdate new file mode 100755 index 00000000..656ad57e --- /dev/null +++ b/examples/bulkupdate/bulkuserupdate @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +import posixpath +import sys + +import requests + +try: + import urllib.parse as urlparse +except ImportError: + import urlparse + +# change base_url to point to your SFTPGo installation +base_url = "http://127.0.0.1:8080" +# set to False if you want to skip TLS certificate validation +verify_tls_cert = True +# set the credentials for a valid admin here +admin_user = "admin" +admin_password = "password" +# insert here the users you want to update +users_to_update = ["user1", "user2", "user3"] +# set here the fields you need to update +fields_to_update = {"status":0, "quota_files": 1000, "additional_info":"updated using the bulkuserupdate example script"} + +# get a JWT token +auth = requests.auth.HTTPBasicAuth(admin_user, admin_password) +r = requests.get(urlparse.urljoin(base_url, "api/v2/token"), auth=auth, verify=verify_tls_cert, timeout=10) +if r.status_code != 200: + print("error getting access token: {}".format(r.text)) + sys.exit(1) +access_token = r.json()["access_token"] +auth_header = {"Authorization": "Bearer " + 
access_token} + +for username in users_to_update: + r = requests.get(urlparse.urljoin(base_url, posixpath.join("api/v2/users", username)), + headers=auth_header, verify=verify_tls_cert, timeout=10) + if r.status_code != 200: + print("error getting user {}: {}".format(username, r.text)) + continue + user = r.json() + user.update(fields_to_update) + r = requests.put(urlparse.urljoin(base_url, posixpath.join("api/v2/users", username)), + headers=auth_header, verify=verify_tls_cert, json=user, timeout=10) + if r.status_code == 200: + print("user {} updated".format(username)) + else: + print("error updating user {}, response code: {} response text: {}".format(username, + r.status_code, + r.text)) diff --git a/examples/convertusers/README.md b/examples/convertusers/README.md new file mode 100644 index 00000000..565e2c2f --- /dev/null +++ b/examples/convertusers/README.md @@ -0,0 +1,51 @@ +# Import users from other stores + +`convertusers` is a very simple command line client, written in python, to import users from other stores. It requires `python3` or `python2`. + +Here is the usage: + +```console +usage: convertusers [-h] [--min-uid MIN_UID] [--max-uid MAX_UID] [--usernames USERNAMES [USERNAMES ...]] + [--force-uid FORCE_UID] [--force-gid FORCE_GID] + input_file {unix-passwd,pure-ftpd,proftpd} output_file + +Convert users to a JSON format suitable to use with loadddata + +positional arguments: + input_file + {unix-passwd,pure-ftpd,proftpd} + To import from unix-passwd format you need the permission to read /etc/shadow that is typically + granted to the root user only + output_file + +optional arguments: + -h, --help show this help message and exit + --min-uid MIN_UID if >= 0 only import users with UID greater or equal to this value. Default: -1 + --max-uid MAX_UID if >= 0 only import users with UID lesser or equal to this value. Default: -1 + --usernames USERNAMES [USERNAMES ...] + Only import users with these usernames. 
Default: [] + --force-uid FORCE_UID + if >= 0 the imported users will have this UID in SFTPGo. Default: -1 + --force-gid FORCE_GID + if >= 0 the imported users will have this GID in SFTPGo. Default: -1 +``` + +Let's see some examples: + +```console +python convertusers "" unix-passwd unix_users.json --min-uid 500 --force-uid 1000 --force-gid 1000 +``` + +```console +python convertusers pureftpd.passwd pure-ftpd pure_users.json --usernames "user1" "user2" +``` + +```console +python convertusers proftpd.passwd proftpd pro_users.json +``` + +The generated json file can be used as input for the `loaddata` REST API. + +Please note that when importing Linux/Unix users the input file is not required: `/etc/passwd` and `/etc/shadow` are automatically parsed. `/etc/shadow` read permission is typically granted to the `root` user only, so you need to execute `convertusers` as `root`. + +:warning: SFTPGo does not currently support `yescrypt` hashed passwords. diff --git a/examples/convertusers/convertusers b/examples/convertusers/convertusers new file mode 100755 index 00000000..edf3d797 --- /dev/null +++ b/examples/convertusers/convertusers @@ -0,0 +1,208 @@ +#!/usr/bin/env python + +import argparse +import json +import sys +import time + +try: + import pwd + import spwd +except ImportError: + pwd = None + + +class ConvertUsers: + + def __init__(self, input_file, users_format, output_file, min_uid, max_uid, usernames, force_uid, force_gid): + self.input_file = input_file + self.users_format = users_format + self.output_file = output_file + self.min_uid = min_uid + self.max_uid = max_uid + self.usernames = usernames + self.force_uid = force_uid + self.force_gid = force_gid + self.SFTPGoUsers = [] + + def buildUserObject(self, username, password, home_dir, uid, gid, max_sessions, quota_size, quota_files, upload_bandwidth, + download_bandwidth, status, expiration_date, allowed_ip=[], denied_ip=[]): + return {'id':0, 'username':username, 'password':password, 'home_dir':home_dir, 
'uid':uid, 'gid':gid, + 'max_sessions':max_sessions, 'quota_size':quota_size, 'quota_files':quota_files, 'permissions':{'/':["*"]}, + 'upload_bandwidth':upload_bandwidth, 'download_bandwidth':download_bandwidth, + 'status':status, 'expiration_date':expiration_date, + 'filters':{'allowed_ip':allowed_ip, 'denied_ip':denied_ip}} + + def addUser(self, user): + user['id'] = len(self.SFTPGoUsers) + 1 + print('') + print('New user imported: {}'.format(user)) + print('') + self.SFTPGoUsers.append(user) + + def saveUsers(self): + if self.SFTPGoUsers: + data = {'users':self.SFTPGoUsers} + jsonData = json.dumps(data) + with open(self.output_file, 'w') as f: + f.write(jsonData) + print() + print('Number of users saved to "{}": {}. You can import them using loaddata'.format(self.output_file, + len(self.SFTPGoUsers))) + print() + sys.exit(0) + else: + print('No user imported') + sys.exit(1) + + def convert(self): + if self.users_format == 'unix-passwd': + self.convertFromUnixPasswd() + elif self.users_format == 'pure-ftpd': + self.convertFromPureFTPD() + else: + self.convertFromProFTPD() + self.saveUsers() + + def isUserValid(self, username, uid): + if self.usernames and not username in self.usernames: + return False + if self.min_uid >= 0 and uid < self.min_uid: + return False + if self.max_uid >= 0 and uid > self.max_uid: + return False + return True + + def convertFromUnixPasswd(self): + days_from_epoch_time = time.time() / 86400 + for user in pwd.getpwall(): + username = user.pw_name + password = user.pw_passwd + uid = user.pw_uid + gid = user.pw_gid + home_dir = user.pw_dir + status = 1 + expiration_date = 0 + if not self.isUserValid(username, uid): + continue + if self.force_uid >= 0: + uid = self.force_uid + if self.force_gid >= 0: + gid = self.force_gid + # FIXME: if the passwords aren't in /etc/shadow they are probably DES encrypted and we don't support them + if password == 'x' or password == '*': + user_info = spwd.getspnam(username) + password = user_info.sp_pwdp + 
if not password or password == '!!' or password == '!*': + print('cannot import user "{}" without a password'.format(username)) + continue + if user_info.sp_inact > 0: + last_pwd_change_diff = days_from_epoch_time - user_info.sp_lstchg + if last_pwd_change_diff > user_info.sp_inact: + status = 0 + if user_info.sp_expire > 0: + expiration_date = user_info.sp_expire * 86400 + self.addUser(self.buildUserObject(username, password, home_dir, uid, gid, 0, 0, 0, 0, 0, status, + expiration_date)) + + def convertFromProFTPD(self): + with open(self.input_file, 'r') as f: + for line in f: + fields = line.split(':') + if len(fields) > 6: + username = fields[0] + password = fields[1] + uid = int(fields[2]) + gid = int(fields[3]) + home_dir = fields[5] + if not self.isUserValid(username, uid): + continue + if self.force_uid >= 0: + uid = self.force_uid + if self.force_gid >= 0: + gid = self.force_gid + self.addUser(self.buildUserObject(username, password, home_dir, uid, gid, 0, 0, 0, 0, 0, 1, 0)) + + def convertPureFTPDIP(self, fields): + result = [] + if not fields: + return result + for v in fields.split(','): + ip_mask = v.strip() + if not ip_mask: + continue + if ip_mask.count('.') < 3 and ip_mask.count(':') < 3: + print('cannot import pure-ftpd IP: {}'.format(ip_mask)) + continue + if '/' not in ip_mask: + ip_mask += '/32' + result.append(ip_mask) + return result + + def convertFromPureFTPD(self): + with open(self.input_file, 'r') as f: + for line in f: + fields = line.split(':') + if len(fields) > 16: + username = fields[0] + password = fields[1] + uid = int(fields[2]) + gid = int(fields[3]) + home_dir = fields[5] + upload_bandwidth = 0 + if fields[6]: + upload_bandwidth = int(int(fields[6]) / 1024) + download_bandwidth = 0 + if fields[7]: + download_bandwidth = int(int(fields[7]) / 1024) + max_sessions = 0 + if fields[10]: + max_sessions = int(fields[10]) + quota_files = 0 + if fields[11]: + quota_files = int(fields[11]) + quota_size = 0 + if fields[12]: + quota_size = 
int(fields[12]) + allowed_ip = self.convertPureFTPDIP(fields[15]) + denied_ip = self.convertPureFTPDIP(fields[16]) + if not self.isUserValid(username, uid): + continue + if self.force_uid >= 0: + uid = self.force_uid + if self.force_gid >= 0: + gid = self.force_gid + self.addUser(self.buildUserObject(username, password, home_dir, uid, gid, max_sessions, quota_size, + quota_files, upload_bandwidth, download_bandwidth, 1, 0, allowed_ip, + denied_ip)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description= + 'Convert users to a JSON format suitable to use with loadddata') + supportedUsersFormats = [] + help_text = '' + if pwd is not None: + supportedUsersFormats.append('unix-passwd') + help_text = 'To import from unix-passwd format you need the permission to read /etc/shadow that is typically granted to the root user only' + supportedUsersFormats.append('pure-ftpd') + supportedUsersFormats.append('proftpd') + parser.add_argument('input_file', type=str) + parser.add_argument('users_format', type=str, choices=supportedUsersFormats, help=help_text) + parser.add_argument('output_file', type=str) + parser.add_argument('--min-uid', type=int, default=-1, help='if >= 0 only import users with UID greater or equal ' + + 'to this value. Default: %(default)s') + parser.add_argument('--max-uid', type=int, default=-1, help='if >= 0 only import users with UID lesser or equal ' + + 'to this value. Default: %(default)s') + parser.add_argument('--usernames', type=str, nargs='+', default=[], help='Only import users with these usernames. ' + + 'Default: %(default)s') + parser.add_argument('--force-uid', type=int, default=-1, help='if >= 0 the imported users will have this UID in ' + + 'SFTPGo. Default: %(default)s') + parser.add_argument('--force-gid', type=int, default=-1, help='if >= 0 the imported users will have this GID in ' + + 'SFTPGo. 
Default: %(default)s') + + args = parser.parse_args() + + convertUsers = ConvertUsers(args.input_file, args.users_format, args.output_file, args.min_uid, args.max_uid, + args.usernames, args.force_uid, args.force_gid) + convertUsers.convert() diff --git a/examples/ldapauth/README.md b/examples/ldapauth/README.md new file mode 100644 index 00000000..c55d4f96 --- /dev/null +++ b/examples/ldapauth/README.md @@ -0,0 +1,48 @@ +# LDAPAuth + +This is an example for an external authentication program. It performs authentication against an LDAP server. +It is tested against [389ds](https://directory.fedoraproject.org/) and can be used as starting point to authenticate using any LDAP server including Active Directory. + +You need to change the LDAP connection parameters and the user search query to match your environment. +You can build this example using the following command: + +```console +go build -ldflags "-s -w" -o ldapauth +``` + +This program assumes that the 389ds schema was extended to add support for public keys using the following ldif file placed in `/etc/dirsrv/schema/98openssh-ldap.ldif`: + +```console +dn: cn=schema +changetype: modify +add: attributetypes +attributetypes: ( 1.3.6.1.4.1.24552.500.1.1.1.13 NAME 'sshPublicKey' DESC 'MANDATORY: OpenSSH Public key' EQUALITY octetStringMatch SYNTAX 1.3.6.1.4.1.1466.115.121.1.40 ) +- +add: objectclasses +objectClasses: ( 1.3.6.1.4.1.24552.500.1.1.2.0 NAME 'ldapPublicKey' SUP top AUXILIARY DESC 'MANDATORY: OpenSSH LPK objectclass' MUST ( uid ) MAY ( sshPublicKey ) ) +- + +dn: cn=sshpublickey,cn=default indexes,cn=config,cn=ldbm database,cn=plugins,cn=config +changetype: add +cn: sshpublickey +nsIndexType: eq +nsIndexType: pres +nsSystemIndex: false +objectClass: top +objectClass: nsIndex + +dn: cn=sshpublickey_self_manage,ou=groups,dc=example,dc=com +changetype: add +objectClass: top +objectClass: groupofuniquenames +cn: sshpublickey_self_manage +description: Members of this group gain the ability to edit their own 
sshPublicKey field + +dn: dc=example,dc=com +changetype: modify +add: aci +aci: (targetattr = "sshPublicKey") (version 3.0; acl "Allow members of sshpublickey_self_manage to edit their keys"; allow(write) (groupdn = "ldap:///cn=sshpublickey_self_manage,ou=groups,dc=example,dc=com" and userdn="ldap:///self" ); ) +- +``` + +:warning: A plugin for LDAP/Active Directory authentication is also [available](https://github.com/sftpgo/sftpgo-plugin-auth). diff --git a/examples/ldapauth/go.mod b/examples/ldapauth/go.mod new file mode 100644 index 00000000..81df3ad5 --- /dev/null +++ b/examples/ldapauth/go.mod @@ -0,0 +1,15 @@ +module github.com/drakkan/ldapauth + +go 1.25.0 + +require ( + github.com/go-ldap/ldap/v3 v3.4.12 + golang.org/x/crypto v0.45.0 +) + +require ( + github.com/Azure/go-ntlmssp v0.1.0 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect + github.com/google/uuid v1.6.0 // indirect + golang.org/x/sys v0.38.0 // indirect +) diff --git a/examples/ldapauth/go.sum b/examples/ldapauth/go.sum new file mode 100644 index 00000000..67fe852a --- /dev/null +++ b/examples/ldapauth/go.sum @@ -0,0 +1,40 @@ +github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A= +github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= 
+github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= +github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.45.0 
h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/ldapauth/main.go b/examples/ldapauth/main.go new file mode 100644 index 00000000..31531043 --- /dev/null +++ b/examples/ldapauth/main.go @@ -0,0 +1,175 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "log/syslog" + "os" + "strconv" + "strings" + + "github.com/go-ldap/ldap/v3" + "golang.org/x/crypto/ssh" +) + +const ( + rootDN = "dc=example,dc=com" + bindUsername = "cn=sftpgo," + rootDN + bindURL = "ldap:///" // That is, the server on the default port of localhost. 
+ passwordFile = "/etc/sftpgo/admin-password.txt" // make this file readable only by the server + publicDir = "/var/www/webdav/public" +) + +type userFilters struct { + DeniedLoginMethods []string `json:"denied_login_methods,omitempty"` +} + +type minimalSFTPGoUser struct { + Status int `json:"status,omitempty"` + Username string `json:"username"` + HomeDir string `json:"home_dir,omitempty"` + UID int `json:"uid,omitempty"` + GID int `json:"gid,omitempty"` + Permissions map[string][]string `json:"permissions"` + Filters userFilters `json:"filters"` +} + +func exitError() { + log.Printf("exitError\n") + u := minimalSFTPGoUser{ + Username: "", + } + resp, _ := json.Marshal(u) + fmt.Printf("%v\n", string(resp)) + os.Exit(1) +} + +func printSuccessResponse(username, homeDir string, uid, gid int, permissions []string) { + u := minimalSFTPGoUser{ + Username: username, + HomeDir: homeDir, + UID: uid, + GID: gid, + Status: 1, + } + u.Permissions = make(map[string][]string) + u.Permissions["/"] = permissions + // uncomment the next line to require publickey+password authentication + //u.Filters.DeniedLoginMethods = []string{"publickey", "password", "keyboard-interactive", "publickey+keyboard-interactive"} + resp, _ := json.Marshal(u) + log.Printf("%v\n", string(resp)) + fmt.Printf("%v\n", string(resp)) + os.Exit(0) +} + +func main() { + logWriter, err := syslog.New(syslog.LOG_NOTICE, "sftpgo") + if err == nil { + log.SetOutput(logWriter) + } + // get credentials from env vars + username := os.Getenv("SFTPGO_AUTHD_USERNAME") + password := os.Getenv("SFTPGO_AUTHD_PASSWORD") + publickey := os.Getenv("SFTPGO_AUTHD_PUBLIC_KEY") + if strings.ToLower(username) == "anonymous" { + printSuccessResponse("anonymous", publicDir, 0, 0, []string{"list", "download"}) + return + } + l, err := ldap.DialURL(bindURL) + if err != nil { + log.Printf("DialURL: %s\n", err.Error()) + exitError() + } + defer l.Close() + // bind to the ldap server with an account that can read users + bindPassword, 
err := os.ReadFile(passwordFile) + if err != nil { + log.Printf("ReadFile(%s): %s\n", passwordFile, err.Error()) + exitError() + } + err = l.Bind(bindUsername, string(bindPassword)) + if err != nil { + log.Printf("Bind(%s): %s\n", bindUsername, err.Error()) + exitError() + } + + // search the user trying to login and fetch some attributes, this search string is tested against 389ds using the default configuration + log.Printf("username=%s\n", username) + searchFilter := fmt.Sprintf("(uid=%s)", ldap.EscapeFilter(username)) + searchRequest := ldap.NewSearchRequest( + "ou=people," + rootDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + searchFilter, + []string{"dn", "uid", "homeDirectory", "uidNumber", "gidNumber", "nsSshPublicKey"}, + nil, + ) + + sr, err := l.Search(searchRequest) + if err != nil { + log.Printf("Search(%s): %s\n", searchFilter, err.Error()) + exitError() + } + + // we expect exactly one user + if len(sr.Entries) != 1 { + log.Printf("Search(%s): %d entries\n", searchFilter, len(sr.Entries)) + exitError() + } + + if len(publickey) > 0 { + // check public key + userKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publickey)) + if err != nil { + log.Printf("ParseAuthorizedKey(%s): %s\n", publickey, err.Error()) + exitError() + } + authOk := false + for _, k := range sr.Entries[0].GetAttributeValues("nsSshPublicKey") { + key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k)) + // we skip an invalid public key stored inside the LDAP server + if err != nil { + continue + } + if bytes.Equal(key.Marshal(), userKey.Marshal()) { + authOk = true + break + } + } + if !authOk { + log.Printf("publickey %s !authOk\n", publickey) + exitError() + } + } else { + // bind to the LDAP server with the user dn and the given password to check the password + userdn := sr.Entries[0].DN + // log.Printf("password=%s\n", password) + err = l.Bind(userdn, password) + if err != nil { + log.Printf("Bind(%s): %s\n", userdn, err.Error()) + exitError() + } + } + + 
// People in the LDAP directory aren't necessarily Linux users; + // so they might not have a uidNumber or gidNumber. + uidNumber := sr.Entries[0].GetAttributeValue("uidNumber") + uid, err := strconv.Atoi(uidNumber) + if err != nil { + //log.Printf("uid Atoi(%s) = %s\n", uidNumber, err.Error()) + uid = 0 + } + gidNumber := sr.Entries[0].GetAttributeValue("gidNumber") + gid, err := strconv.Atoi(gidNumber) + if err != nil { + //log.Printf("gid Atoi(%s) = %s\n", gidNumber, err.Error()) + gid = 0 + } + homeDir := sr.Entries[0].GetAttributeValue("homeDirectory") + if (len(homeDir) <= 0) { + homeDir = publicDir // homeDir is a required attribute. + } + // return the authenticated user + printSuccessResponse(sr.Entries[0].GetAttributeValue("uid"), homeDir, uid, gid, []string{"*"}) +} diff --git a/examples/ldapauthserver/README.md b/examples/ldapauthserver/README.md new file mode 100644 index 00000000..a45c45a2 --- /dev/null +++ b/examples/ldapauthserver/README.md @@ -0,0 +1,13 @@ +# LDAPAuthServer + +This is an example for an HTTP server to use as external authentication HTTP hook. It performs authentication against an LDAP server. +It is tested against [389ds](https://directory.fedoraproject.org/) and can be used as starting point to authenticate using any LDAP server including Active Directory. + +You can configure the server using the [ldapauth.toml](./ldapauth.toml) configuration file. +You can build this example using the following command: + +```console +go build -ldflags "-s -w" -o ldapauthserver +``` + +:warning: A plugin for LDAP/Active Directory authentication is also [available](https://github.com/sftpgo/sftpgo-plugin-auth). 
diff --git a/examples/ldapauthserver/cmd/root.go b/examples/ldapauthserver/cmd/root.go new file mode 100644 index 00000000..4a9bc37f --- /dev/null +++ b/examples/ldapauthserver/cmd/root.go @@ -0,0 +1,158 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/drakkan/sftpgo/ldapauthserver/config" + "github.com/drakkan/sftpgo/ldapauthserver/utils" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +const ( + logSender = "cmd" + configDirFlag = "config-dir" + configDirKey = "config_dir" + configFileFlag = "config-file" + configFileKey = "config_file" + logFilePathFlag = "log-file-path" + logFilePathKey = "log_file_path" + logMaxSizeFlag = "log-max-size" + logMaxSizeKey = "log_max_size" + logMaxBackupFlag = "log-max-backups" + logMaxBackupKey = "log_max_backups" + logMaxAgeFlag = "log-max-age" + logMaxAgeKey = "log_max_age" + logCompressFlag = "log-compress" + logCompressKey = "log_compress" + logVerboseFlag = "log-verbose" + logVerboseKey = "log_verbose" + profilerFlag = "profiler" + profilerKey = "profiler" + defaultConfigDir = "." + defaultConfigName = config.DefaultConfigName + defaultLogFile = "ldapauth.log" + defaultLogMaxSize = 10 + defaultLogMaxBackup = 5 + defaultLogMaxAge = 28 + defaultLogCompress = false + defaultLogVerbose = true +) + +var ( + configDir string + configFile string + logFilePath string + logMaxSize int + logMaxBackups int + logMaxAge int + logCompress bool + logVerbose bool + + rootCmd = &cobra.Command{ + Use: "ldapauthserver", + Short: "LDAP Authentication Server for SFTPGo", + } +) + +func init() { + version := utils.GetAppVersion() + rootCmd.Flags().BoolP("version", "v", false, "") + rootCmd.Version = version.GetVersionAsString() + rootCmd.SetVersionTemplate(`{{printf "LDAP Authentication Server version: "}}{{printf "%s" .Version}} +`) +} + +// Execute adds all child commands to the root command and sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. 
+func Execute() { + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func addConfigFlags(cmd *cobra.Command) { + viper.SetDefault(configDirKey, defaultConfigDir) + viper.BindEnv(configDirKey, "LDAPAUTH_CONFIG_DIR") + cmd.Flags().StringVarP(&configDir, configDirFlag, "c", viper.GetString(configDirKey), + `Location for the config dir. This directory +should contain the "ldapauth" configuration +file or the configured config-file. This flag +can be set using LDAPAUTH_CONFIG_DIR env var too. +`) + viper.BindPFlag(configDirKey, cmd.Flags().Lookup(configDirFlag)) + + viper.SetDefault(configFileKey, defaultConfigName) + viper.BindEnv(configFileKey, "LDAPAUTH_CONFIG_FILE") + cmd.Flags().StringVarP(&configFile, configFileFlag, "f", viper.GetString(configFileKey), + `Name for the configuration file. It must be +the name of a file stored in config-dir not +the absolute path to the configuration file. +The specified file name must have no extension +we automatically load JSON, YAML, TOML, HCL and +Java properties. Therefore if you set "ldapauth" +then "ldapauth.toml", "ldapauth.yaml" and +so on are searched. This flag can be set using +LDAPAUTH_CONFIG_FILE env var too. +`) + viper.BindPFlag(configFileKey, cmd.Flags().Lookup(configFileFlag)) +} + +func addServeFlags(cmd *cobra.Command) { + addConfigFlags(cmd) + + viper.SetDefault(logFilePathKey, defaultLogFile) + viper.BindEnv(logFilePathKey, "LDAPAUTH_LOG_FILE_PATH") + cmd.Flags().StringVarP(&logFilePath, logFilePathFlag, "l", viper.GetString(logFilePathKey), + `Location for the log file. Leave empty to write +logs to the standard output. This flag can be +set using LDAPAUTH_LOG_FILE_PATH env var too.
+`) + viper.BindPFlag(logFilePathKey, cmd.Flags().Lookup(logFilePathFlag)) + + viper.SetDefault(logMaxSizeKey, defaultLogMaxSize) + viper.BindEnv(logMaxSizeKey, "LDAPAUTH_LOG_MAX_SIZE") + cmd.Flags().IntVarP(&logMaxSize, logMaxSizeFlag, "s", viper.GetInt(logMaxSizeKey), + `Maximum size in megabytes of the log file +before it gets rotated. This flag can be set +using LDAPAUTH_LOG_MAX_SIZE env var too. It +is unused if log-file-path is empty.`) + viper.BindPFlag(logMaxSizeKey, cmd.Flags().Lookup(logMaxSizeFlag)) + + viper.SetDefault(logMaxBackupKey, defaultLogMaxBackup) + viper.BindEnv(logMaxBackupKey, "LDAPAUTH_LOG_MAX_BACKUPS") + cmd.Flags().IntVarP(&logMaxBackups, "log-max-backups", "b", viper.GetInt(logMaxBackupKey), + `Maximum number of old log files to retain. +This flag can be set using LDAPAUTH_LOG_MAX_BACKUPS +env var too. It is unused if log-file-path is +empty.`) + viper.BindPFlag(logMaxBackupKey, cmd.Flags().Lookup(logMaxBackupFlag)) + + viper.SetDefault(logMaxAgeKey, defaultLogMaxAge) + viper.BindEnv(logMaxAgeKey, "LDAPAUTH_LOG_MAX_AGE") + cmd.Flags().IntVarP(&logMaxAge, "log-max-age", "a", viper.GetInt(logMaxAgeKey), + `Maximum number of days to retain old log files. +This flag can be set using LDAPAUTH_LOG_MAX_AGE +env var too. It is unused if log-file-path is +empty.`) + viper.BindPFlag(logMaxAgeKey, cmd.Flags().Lookup(logMaxAgeFlag)) + + viper.SetDefault(logCompressKey, defaultLogCompress) + viper.BindEnv(logCompressKey, "LDAPAUTH_LOG_COMPRESS") + cmd.Flags().BoolVarP(&logCompress, logCompressFlag, "z", viper.GetBool(logCompressKey), + `Determine if the rotated log files +should be compressed using gzip. This flag can +be set using LDAPAUTH_LOG_COMPRESS env var too. 
+It is unused if log-file-path is empty.`) + viper.BindPFlag(logCompressKey, cmd.Flags().Lookup(logCompressFlag)) + + viper.SetDefault(logVerboseKey, defaultLogVerbose) + viper.BindEnv(logVerboseKey, "LDAPAUTH_LOG_VERBOSE") + cmd.Flags().BoolVarP(&logVerbose, logVerboseFlag, "v", viper.GetBool(logVerboseKey), + `Enable verbose logs. This flag can be set +using LDAPAUTH_LOG_VERBOSE env var too. +`) + viper.BindPFlag(logVerboseKey, cmd.Flags().Lookup(logVerboseFlag)) +} diff --git a/examples/ldapauthserver/cmd/serve.go b/examples/ldapauthserver/cmd/serve.go new file mode 100644 index 00000000..b9faf202 --- /dev/null +++ b/examples/ldapauthserver/cmd/serve.go @@ -0,0 +1,49 @@ +package cmd + +import ( + "path/filepath" + + "github.com/drakkan/sftpgo/ldapauthserver/config" + "github.com/drakkan/sftpgo/ldapauthserver/httpd" + "github.com/drakkan/sftpgo/ldapauthserver/logger" + "github.com/drakkan/sftpgo/ldapauthserver/utils" + "github.com/rs/zerolog" + "github.com/spf13/cobra" +) + +var ( + serveCmd = &cobra.Command{ + Use: "serve", + Short: "Start the LDAP Authentication Server", + Long: `To start the server with the default values for the command line flags simply use: + +ldapauthserver serve + +Please take a look at the usage below to customize the startup options`, + Run: func(cmd *cobra.Command, args []string) { + startServer() + }, + } +) + +func init() { + rootCmd.AddCommand(serveCmd) + addServeFlags(serveCmd) +} + +func startServer() error { + logLevel := zerolog.DebugLevel + if !logVerbose { + logLevel = zerolog.InfoLevel + } + if !filepath.IsAbs(logFilePath) && utils.IsFileInputValid(logFilePath) { + logFilePath = filepath.Join(configDir, logFilePath) + } + logger.InitLogger(logFilePath, logMaxSize, logMaxBackups, logMaxAge, logCompress, logLevel) + version := utils.GetAppVersion() + logger.Info(logSender, "", "starting LDAP Auth Server %v, config dir: %v, config file: %v, log max size: %v log max backups: %v "+ + "log max age: %v log verbose: %v, log compress: 
%v", version.GetVersionAsString(), configDir, configFile, logMaxSize, + logMaxBackups, logMaxAge, logVerbose, logCompress) + config.LoadConfig(configDir, configFile) + return httpd.StartHTTPServer(configDir, config.GetHTTPDConfig()) +} diff --git a/examples/ldapauthserver/config/config.go b/examples/ldapauthserver/config/config.go new file mode 100644 index 00000000..402a95cc --- /dev/null +++ b/examples/ldapauthserver/config/config.go @@ -0,0 +1,158 @@ +package config + +import ( + "strings" + + "github.com/drakkan/sftpgo/ldapauthserver/logger" + "github.com/spf13/viper" +) + +const ( + logSender = "config" + // DefaultConfigName defines the name for the default config file. + // This is the file name without extension, we use viper and so we + // support all the config files format supported by viper + DefaultConfigName = "ldapauth" + // ConfigEnvPrefix defines a prefix that ENVIRONMENT variables will use + configEnvPrefix = "ldapauth" +) + +// HTTPDConfig defines configuration for the HTTPD server +type HTTPDConfig struct { + BindAddress string `mapstructure:"bind_address"` + BindPort int `mapstructure:"bind_port"` + AuthUserFile string `mapstructure:"auth_user_file"` + CertificateFile string `mapstructure:"certificate_file"` + CertificateKeyFile string `mapstructure:"certificate_key_file"` +} + +// LDAPConfig defines the configuration parameters for LDAP connections and searches +type LDAPConfig struct { + BaseDN string `mapstructure:"basedn"` + BindURL string `mapstructure:"bind_url"` + BindUsername string `mapstructure:"bind_username"` + BindPassword string `mapstructure:"bind_password"` + SearchFilter string `mapstructure:"search_filter"` + SearchBaseAttrs []string `mapstructure:"search_base_attrs"` + DefaultUID int `mapstructure:"default_uid"` + DefaultGID int `mapstructure:"default_gid"` + ForceDefaultUID bool `mapstructure:"force_default_uid"` + ForceDefaultGID bool `mapstructure:"force_default_gid"` + InsecureSkipVerify bool 
`mapstructure:"insecure_skip_verify"` + CACertificates []string `mapstructure:"ca_certificates"` +} + +type appConfig struct { + HTTPD HTTPDConfig `mapstructure:"httpd"` + LDAP LDAPConfig `mapstructure:"ldap"` +} + +var conf appConfig + +func init() { + conf = appConfig{ + HTTPD: HTTPDConfig{ + BindAddress: "", + BindPort: 9000, + AuthUserFile: "", + CertificateFile: "", + CertificateKeyFile: "", + }, + LDAP: LDAPConfig{ + BaseDN: "dc=example,dc=com", + BindURL: "ldap://192.168.1.103:389", + BindUsername: "cn=Directory Manager", + BindPassword: "YOUR_ADMIN_PASSWORD_HERE", + SearchFilter: "(&(objectClass=nsPerson)(uid=%s))", + SearchBaseAttrs: []string{ + "dn", + "homeDirectory", + "uidNumber", + "gidNumber", + "nsSshPublicKey", + }, + DefaultUID: 0, + DefaultGID: 0, + ForceDefaultUID: true, + ForceDefaultGID: true, + InsecureSkipVerify: false, + CACertificates: nil, + }, + } + viper.SetEnvPrefix(configEnvPrefix) + replacer := strings.NewReplacer(".", "__") + viper.SetEnvKeyReplacer(replacer) + viper.SetConfigName(DefaultConfigName) + viper.AutomaticEnv() + viper.AllowEmptyEnv(true) +} + +// GetHomeDirectory returns the configured name for the LDAP field to use as home directory +func (l *LDAPConfig) GetHomeDirectory() string { + if len(l.SearchBaseAttrs) > 1 { + return l.SearchBaseAttrs[1] + } + return "homeDirectory" +} + +// GetUIDNumber returns the configured name for the LDAP field to use as UID +func (l *LDAPConfig) GetUIDNumber() string { + if len(l.SearchBaseAttrs) > 2 { + return l.SearchBaseAttrs[2] + } + return "uidNumber" +} + +// GetGIDNumber returns the configured name for the LDAP field to use as GID +func (l *LDAPConfig) GetGIDNumber() string { + if len(l.SearchBaseAttrs) > 3 { + return l.SearchBaseAttrs[3] + } + return "gidNumber" +} + +// GetPublicKey returns the configured name for the LDAP field to use as public keys +func (l *LDAPConfig) GetPublicKey() string { + if len(l.SearchBaseAttrs) > 4 { + return l.SearchBaseAttrs[4] + } + return 
"nsSshPublicKey" +} + +// GetHTTPDConfig returns the configuration for the HTTP server +func GetHTTPDConfig() HTTPDConfig { + return conf.HTTPD +} + +// GetLDAPConfig returns LDAP related settings +func GetLDAPConfig() LDAPConfig { + return conf.LDAP +} + +// getRedactedConf returns a copy of the configuration safe for logging +func getRedactedConf() appConfig { + c := conf; c.LDAP.BindPassword = "[redacted]" // mask the secret: this value is logged via %+v below + return c +} + +// LoadConfig loads the configuration +func LoadConfig(configDir, configName string) error { + var err error + viper.AddConfigPath(configDir) + viper.AddConfigPath(".") + viper.SetConfigName(configName) + if err = viper.ReadInConfig(); err != nil { + logger.Warn(logSender, "", "error loading configuration file: %v. Default configuration will be used: %+v", + err, getRedactedConf()) + logger.WarnToConsole("error loading configuration file: %v. Default configuration will be used.", err) + return err + } + err = viper.Unmarshal(&conf) + if err != nil { + logger.Warn(logSender, "", "error parsing configuration file: %v. Default configuration will be used: %+v", + err, getRedactedConf()) + logger.WarnToConsole("error parsing configuration file: %v.
Default configuration will be used.", err) + return err + } + logger.Debug(logSender, "", "config file used: '%q', config loaded: %+v", viper.ConfigFileUsed(), getRedactedConf()) + return err +} diff --git a/examples/ldapauthserver/go.mod b/examples/ldapauthserver/go.mod new file mode 100644 index 00000000..f65828cc --- /dev/null +++ b/examples/ldapauthserver/go.mod @@ -0,0 +1,37 @@ +module github.com/drakkan/sftpgo/ldapauthserver + +go 1.25.0 + +require ( + github.com/go-chi/chi/v5 v5.2.3 + github.com/go-chi/render v1.0.3 + github.com/go-ldap/ldap/v3 v3.4.12 + github.com/nathanaelle/password/v2 v2.0.1 + github.com/rs/zerolog v1.34.0 + github.com/spf13/cobra v1.10.1 + github.com/spf13/viper v1.21.0 + golang.org/x/crypto v0.45.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 +) + +require ( + github.com/Azure/go-ntlmssp v0.1.0 // indirect + github.com/ajg/form v1.5.1 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/sagikazarmark/locafero v0.12.0 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect + gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect +) diff --git a/examples/ldapauthserver/go.sum b/examples/ldapauthserver/go.sum new file mode 100644 index 00000000..99ea3a37 --- /dev/null +++ b/examples/ldapauthserver/go.sum @@ -0,0 +1,114 @@ +github.com/Azure/go-ntlmssp v0.1.0 
h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A= +github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk= +github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= +github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI= +github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= +github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= +github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= +github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= +github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4= +github.com/go-ldap/ldap/v3 v3.4.12/go.mod 
h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3 
h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/nathanaelle/password/v2 v2.0.1 h1:ItoCTdsuIWzilYmllQPa3DR3YoCXcpfxScWLqr8Ii2s= +github.com/nathanaelle/password/v2 v2.0.1/go.mod h1:eaoT+ICQEPNtikBRIAatN8ThWwMhVG+r1jTw60BvPJk= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= 
+github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= +github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= +github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 
+golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= 
+gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/ldapauthserver/httpd/auth.go b/examples/ldapauthserver/httpd/auth.go new file mode 100644 index 00000000..dd76d52c --- /dev/null +++ b/examples/ldapauthserver/httpd/auth.go @@ -0,0 +1,145 @@ +package httpd + +import ( + "encoding/csv" + "errors" + "fmt" + "net/http" + "os" + "sync" + + unixcrypt "github.com/nathanaelle/password/v2" + + "github.com/drakkan/sftpgo/ldapauthserver/logger" + "github.com/drakkan/sftpgo/ldapauthserver/utils" + "golang.org/x/crypto/bcrypt" +) + +const ( + authenticationHeader = "WWW-Authenticate" + authenticationRealm = "LDAP Auth Server" + unauthResponse = "Unauthorized" +) + +var ( + md5CryptPwdPrefixes = []string{"$1$", "$apr1$"} + bcryptPwdPrefixes = []string{"$2a$", "$2$", "$2x$", "$2y$", "$2b$"} +) + +type httpAuthProvider interface { + getHashedPassword(username string) (string, bool) + isEnabled() bool +} + +type basicAuthProvider struct { + Path string + sync.RWMutex + Info os.FileInfo + Users map[string]string +} + +func newBasicAuthProvider(authUserFile string) (httpAuthProvider, error) { + basicAuthProvider := basicAuthProvider{ + Path: authUserFile, + Info: nil, + Users: make(map[string]string), + } + return &basicAuthProvider, basicAuthProvider.loadUsers() +} + +func (p *basicAuthProvider) isEnabled() bool { + return len(p.Path) > 0 +} + +func (p *basicAuthProvider) isReloadNeeded(info os.FileInfo) bool { + p.RLock() + defer p.RUnlock() + return p.Info == nil || p.Info.ModTime() != info.ModTime() || p.Info.Size() != info.Size() +} + +func (p *basicAuthProvider) loadUsers() error { + if !p.isEnabled() { + return nil + } + info, err := os.Stat(p.Path) + if err != nil { + logger.Debug(logSender, "", "unable to stat basic auth users file: %v", err) + return err + 
} + if p.isReloadNeeded(info) { + r, err := os.Open(p.Path) + if err != nil { + logger.Debug(logSender, "", "unable to open basic auth users file: %v", err) + return err + } + defer r.Close() + reader := csv.NewReader(r) + reader.Comma = ':' + reader.Comment = '#' + reader.TrimLeadingSpace = true + records, err := reader.ReadAll() + if err != nil { + logger.Debug(logSender, "", "unable to parse basic auth users file: %v", err) + return err + } + p.Lock() + defer p.Unlock() + p.Users = make(map[string]string) + for _, record := range records { + if len(record) == 2 { + p.Users[record[0]] = record[1] + } + } + logger.Debug(logSender, "", "number of users loaded for httpd basic auth: %v", len(p.Users)) + p.Info = info + } + return nil +} + +func (p *basicAuthProvider) getHashedPassword(username string) (string, bool) { + err := p.loadUsers() + if err != nil { + return "", false + } + p.RLock() + defer p.RUnlock() + pwd, ok := p.Users[username] + return pwd, ok +} + +func checkAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !validateCredentials(r) { + w.Header().Set(authenticationHeader, fmt.Sprintf("Basic realm=\"%v\"", authenticationRealm)) + sendAPIResponse(w, r, errors.New(unauthResponse), "", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +func validateCredentials(r *http.Request) bool { + if !httpAuth.isEnabled() { + return true + } + username, password, ok := r.BasicAuth() + if !ok { + return false + } + if hashedPwd, ok := httpAuth.getHashedPassword(username); ok { + if utils.IsStringPrefixInSlice(hashedPwd, bcryptPwdPrefixes) { + err := bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(password)) + return err == nil + } + if utils.IsStringPrefixInSlice(hashedPwd, md5CryptPwdPrefixes) { + crypter, ok := unixcrypt.MD5.CrypterFound(hashedPwd) + if !ok { + err := errors.New("cannot found matching MD5 crypter") + logger.Debug(logSender, "", "error comparing password 
with MD5 crypt hash: %v", err) + return false + } + return crypter.Verify([]byte(password)) + } + } + return false +} diff --git a/examples/ldapauthserver/httpd/httpd.go b/examples/ldapauthserver/httpd/httpd.go new file mode 100644 index 00000000..08483d3a --- /dev/null +++ b/examples/ldapauthserver/httpd/httpd.go @@ -0,0 +1,149 @@ +package httpd + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/drakkan/sftpgo/ldapauthserver/config" + "github.com/drakkan/sftpgo/ldapauthserver/logger" + "github.com/drakkan/sftpgo/ldapauthserver/utils" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/render" +) + +const ( + logSender = "httpd" + versionPath = "/api/v1/version" + checkAuthPath = "/api/v1/check_auth" + maxRequestSize = 1 << 18 // 256KB +) + +var ( + ldapConfig config.LDAPConfig + httpAuth httpAuthProvider + certMgr *certManager + rootCAs *x509.CertPool +) + +// StartHTTPServer initializes and starts the HTTP Server +func StartHTTPServer(configDir string, httpConfig config.HTTPDConfig) error { + var err error + authUserFile := getConfigPath(httpConfig.AuthUserFile, configDir) + httpAuth, err = newBasicAuthProvider(authUserFile) + if err != nil { + return err + } + + router := chi.NewRouter() + router.Use(middleware.RequestID) + router.Use(middleware.RealIP) + router.Use(logger.NewStructuredLogger(logger.GetLogger())) + router.Use(middleware.Recoverer) + + router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) + })) + + router.MethodNotAllowed(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sendAPIResponse(w, r, nil, "Method not allowed", http.StatusMethodNotAllowed) + })) + + router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) { + render.JSON(w, r, utils.GetAppVersion()) + }) + + router.Group(func(router chi.Router) { + 
router.Use(checkAuth) + + router.Post(checkAuthPath, checkSFTPGoUserAuth) + }) + + ldapConfig = config.GetLDAPConfig() + loadCACerts(configDir) + + certificateFile := getConfigPath(httpConfig.CertificateFile, configDir) + certificateKeyFile := getConfigPath(httpConfig.CertificateKeyFile, configDir) + + httpServer := &http.Server{ + Addr: fmt.Sprintf("%s:%d", httpConfig.BindAddress, httpConfig.BindPort), + Handler: router, + ReadTimeout: 70 * time.Second, + WriteTimeout: 70 * time.Second, + IdleTimeout: 120 * time.Second, + MaxHeaderBytes: 1 << 16, // 64KB + } + if len(certificateFile) > 0 && len(certificateKeyFile) > 0 { + certMgr, err = newCertManager(certificateFile, certificateKeyFile) + if err != nil { + return err + } + config := &tls.Config{ + GetCertificate: certMgr.GetCertificateFunc(), + MinVersion: tls.VersionTLS12, + } + httpServer.TLSConfig = config + return httpServer.ListenAndServeTLS("", "") + } + return httpServer.ListenAndServe() +} + +func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) { + var errorString string + if err != nil { + errorString = err.Error() + } + resp := apiResponse{ + Error: errorString, + Message: message, + HTTPStatus: code, + } + ctx := context.WithValue(r.Context(), render.StatusCtxKey, code) + render.JSON(w, r.WithContext(ctx), resp) +} + +func loadCACerts(configDir string) error { + var err error + rootCAs, err = x509.SystemCertPool() + if err != nil { + rootCAs = x509.NewCertPool() + } + for _, ca := range ldapConfig.CACertificates { + caPath := getConfigPath(ca, configDir) + certs, err := os.ReadFile(caPath) + if err != nil { + logger.Warn(logSender, "", "error loading ca cert %q: %v", caPath, err) + return err + } + if !rootCAs.AppendCertsFromPEM(certs) { + logger.Warn(logSender, "", "unable to add ca cert %q", caPath) + } else { + logger.Debug(logSender, "", "ca cert %q added to the trusted certificates", caPath) + } + } + + return nil +} + +// ReloadTLSCertificate reloads 
the TLS certificate and key from the configured paths +func ReloadTLSCertificate() { + if certMgr != nil { + certMgr.loadCertificate() + } +} + +func getConfigPath(name, configDir string) string { + if !utils.IsFileInputValid(name) { + return "" + } + if len(name) > 0 && !filepath.IsAbs(name) { + return filepath.Join(configDir, name) + } + return name +} diff --git a/examples/ldapauthserver/httpd/ldapauth.go b/examples/ldapauthserver/httpd/ldapauth.go new file mode 100644 index 00000000..94aafb10 --- /dev/null +++ b/examples/ldapauthserver/httpd/ldapauth.go @@ -0,0 +1,143 @@ +package httpd + +import ( + "bytes" + "crypto/tls" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/drakkan/sftpgo/ldapauthserver/logger" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/render" + "github.com/go-ldap/ldap/v3" + "golang.org/x/crypto/ssh" +) + +func getSFTPGoUser(entry *ldap.Entry, username string) (SFTPGoUser, error) { + var err error + var user SFTPGoUser + uid := ldapConfig.DefaultUID + gid := ldapConfig.DefaultGID + status := 1 + + if !ldapConfig.ForceDefaultUID { + uid, err = strconv.Atoi(entry.GetAttributeValue(ldapConfig.GetUIDNumber())) + if err != nil { + return user, err + } + } + + if !ldapConfig.ForceDefaultGID { + gid, err = strconv.Atoi(entry.GetAttributeValue(ldapConfig.GetGIDNumber())) + if err != nil { + return user, err + } + } + + sftpgoUser := SFTPGoUser{ + Username: username, + HomeDir: entry.GetAttributeValue(ldapConfig.GetHomeDirectory()), + UID: uid, + GID: gid, + Status: status, + } + sftpgoUser.Permissions = make(map[string][]string) + sftpgoUser.Permissions["/"] = []string{"*"} + return sftpgoUser, nil +} + +func checkSFTPGoUserAuth(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + var authReq externalAuthRequest + err := render.DecodeJSON(r.Body, &authReq) + if err != nil { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "error decoding auth request: %v", err) + 
sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + l, err := ldap.DialURL(ldapConfig.BindURL, ldap.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: ldapConfig.InsecureSkipVerify, + RootCAs: rootCAs, + })) + if err != nil { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "error connecting to the LDAP server: %v", err) + sendAPIResponse(w, r, err, "Error connecting to the LDAP server", http.StatusInternalServerError) + return + } + defer l.Close() + + err = l.Bind(ldapConfig.BindUsername, ldapConfig.BindPassword) + if err != nil { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "error binding to the LDAP server: %v", err) + sendAPIResponse(w, r, err, "Error binding to the LDAP server", http.StatusInternalServerError) + return + } + + searchRequest := ldap.NewSearchRequest( + ldapConfig.BaseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + strings.Replace(ldapConfig.SearchFilter, "%s", ldap.EscapeFilter(authReq.Username), 1), + ldapConfig.SearchBaseAttrs, + nil, + ) + + sr, err := l.Search(searchRequest) + if err != nil { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "error searching LDAP user %q: %v", authReq.Username, err) + sendAPIResponse(w, r, err, "Error searching LDAP user", http.StatusInternalServerError) + return + } + + if len(sr.Entries) != 1 { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "expected one user, found: %v", len(sr.Entries)) + sendAPIResponse(w, r, nil, fmt.Sprintf("Expected one user, found: %v", len(sr.Entries)), http.StatusNotFound) + return + } + + if len(authReq.PublicKey) > 0 { + userKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(authReq.PublicKey)) + if err != nil { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "invalid public key for user %q: %v", authReq.Username, err) + sendAPIResponse(w, r, err, "Invalid public key", http.StatusBadRequest) + return + } + authOk := false + for _, k := range 
sr.Entries[0].GetAttributeValues(ldapConfig.GetPublicKey()) { + key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k)) + // we skip an invalid public key stored inside the LDAP server + if err != nil { + continue + } + if bytes.Equal(key.Marshal(), userKey.Marshal()) { + authOk = true + break + } + } + if !authOk { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "public key authentication failed for user: %q", authReq.Username) + sendAPIResponse(w, r, nil, "public key authentication failed", http.StatusForbidden) + return + } + } else { + // bind to the LDAP server with the user dn and the given password to check the password + userdn := sr.Entries[0].DN + err = l.Bind(userdn, authReq.Password) + if err != nil { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "password authentication failed for user: %q", authReq.Username) + sendAPIResponse(w, r, nil, "password authentication failed", http.StatusForbidden) + return + } + } + + user, err := getSFTPGoUser(sr.Entries[0], authReq.Username) + if err != nil { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "get user from LDAP entry failed for username %q: %v", + authReq.Username, err) + sendAPIResponse(w, r, err, "mapping LDAP user failed", http.StatusInternalServerError) + return + } + + render.JSON(w, r, user) +} diff --git a/examples/ldapauthserver/httpd/models.go b/examples/ldapauthserver/httpd/models.go new file mode 100644 index 00000000..42aef279 --- /dev/null +++ b/examples/ldapauthserver/httpd/models.go @@ -0,0 +1,109 @@ +package httpd + +type apiResponse struct { + Error string `json:"error"` + Message string `json:"message"` + HTTPStatus int `json:"status"` +} + +type externalAuthRequest struct { + Username string `json:"username"` + Password string `json:"password"` + PublicKey string `json:"public_key"` +} + +// SFTPGoExtensionsFilter defines filters based on file extensions +type SFTPGoExtensionsFilter struct { + Path string `json:"path"` + AllowedExtensions []string 
`json:"allowed_extensions,omitempty"` + DeniedExtensions []string `json:"denied_extensions,omitempty"` +} + +// SFTPGoUserFilters defines additional restrictions for an SFTPGo user +type SFTPGoUserFilters struct { + AllowedIP []string `json:"allowed_ip,omitempty"` + DeniedIP []string `json:"denied_ip,omitempty"` + DeniedLoginMethods []string `json:"denied_login_methods,omitempty"` + FileExtensions []SFTPGoExtensionsFilter `json:"file_extensions,omitempty"` +} + +// S3FsConfig defines the configuration for S3 based filesystem +type S3FsConfig struct { + Bucket string `json:"bucket,omitempty"` + KeyPrefix string `json:"key_prefix,omitempty"` + Region string `json:"region,omitempty"` + AccessKey string `json:"access_key,omitempty"` + AccessSecret string `json:"access_secret,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + StorageClass string `json:"storage_class,omitempty"` + UploadPartSize int64 `json:"upload_part_size,omitempty"` + UploadConcurrency int `json:"upload_concurrency,omitempty"` +} + +// GCSFsConfig defines the configuration for Google Cloud Storage based filesystem +type GCSFsConfig struct { + Bucket string `json:"bucket,omitempty"` + KeyPrefix string `json:"key_prefix,omitempty"` + Credentials string `json:"credentials,omitempty"` + AutomaticCredentials int `json:"automatic_credentials,omitempty"` + StorageClass string `json:"storage_class,omitempty"` +} + +// SFTPGoFilesystem defines cloud storage filesystem details +type SFTPGoFilesystem struct { + // 0 local filesystem, 1 AWS S3 compatible, 2 Google Cloud Storage + Provider int `json:"provider"` + S3Config S3FsConfig `json:"s3config,omitempty"` + GCSConfig GCSFsConfig `json:"gcsconfig,omitempty"` +} + +type virtualFolder struct { + VirtualPath string `json:"virtual_path"` + MappedPath string `json:"mapped_path"` +} + +// SFTPGoUser defines an SFTPGo user +type SFTPGoUser struct { + // Database unique identifier + ID int64 `json:"id"` + // 1 enabled, 0 disabled (login is not allowed) + 
Status int `json:"status"` + // Username + Username string `json:"username"` + // Account expiration date as unix timestamp in milliseconds. An expired account cannot login. + // 0 means no expiration + ExpirationDate int64 `json:"expiration_date"` + Password string `json:"password,omitempty"` + PublicKeys []string `json:"public_keys,omitempty"` + HomeDir string `json:"home_dir"` + // Mapping between virtual paths and filesystem paths outside the home directory. Supported for local filesystem only + VirtualFolders []virtualFolder `json:"virtual_folders,omitempty"` + // If sftpgo runs as root system user then the created files and directories will be assigned to this system UID + UID int `json:"uid"` + // If sftpgo runs as root system user then the created files and directories will be assigned to this system GID + GID int `json:"gid"` + // Maximum concurrent sessions. 0 means unlimited + MaxSessions int `json:"max_sessions"` + // Maximum size allowed as bytes. 0 means unlimited + QuotaSize int64 `json:"quota_size"` + // Maximum number of files allowed. 
0 means unlimited + QuotaFiles int `json:"quota_files"` + // List of the granted permissions + Permissions map[string][]string `json:"permissions"` + // Used quota as bytes + UsedQuotaSize int64 `json:"used_quota_size"` + // Used quota as number of files + UsedQuotaFiles int `json:"used_quota_files"` + // Last quota update as unix timestamp in milliseconds + LastQuotaUpdate int64 `json:"last_quota_update"` + // Maximum upload bandwidth as KB/s, 0 means unlimited + UploadBandwidth int64 `json:"upload_bandwidth"` + // Maximum download bandwidth as KB/s, 0 means unlimited + DownloadBandwidth int64 `json:"download_bandwidth"` + // Last login as unix timestamp in milliseconds + LastLogin int64 `json:"last_login"` + // Additional restrictions + Filters SFTPGoUserFilters `json:"filters"` + // Filesystem configuration details + FsConfig SFTPGoFilesystem `json:"filesystem"` +} diff --git a/examples/ldapauthserver/httpd/tlsutils.go b/examples/ldapauthserver/httpd/tlsutils.go new file mode 100644 index 00000000..67f4ead5 --- /dev/null +++ b/examples/ldapauthserver/httpd/tlsutils.go @@ -0,0 +1,49 @@ +package httpd + +import ( + "crypto/tls" + "sync" + + "github.com/drakkan/sftpgo/ldapauthserver/logger" +) + +type certManager struct { + certPath string + keyPath string + sync.RWMutex + cert *tls.Certificate +} + +func (m *certManager) loadCertificate() error { + newCert, err := tls.LoadX509KeyPair(m.certPath, m.keyPath) + if err != nil { + logger.Warn(logSender, "", "unable to load https certificate: %v", err) + return err + } + logger.Debug(logSender, "", "https certificate successfully loaded") + m.Lock() + defer m.Unlock() + m.cert = &newCert + return nil +} + +func (m *certManager) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + m.RLock() + defer m.RUnlock() + return m.cert, nil + } +} + +func newCertManager(certificateFile, certificateKeyFile string) (*certManager, 
error) { + manager := &certManager{ + cert: nil, + certPath: certificateFile, + keyPath: certificateKeyFile, + } + err := manager.loadCertificate() + if err != nil { + return nil, err + } + return manager, nil +} diff --git a/examples/ldapauthserver/ldapauth.toml b/examples/ldapauthserver/ldapauth.toml new file mode 100644 index 00000000..4f3e160c --- /dev/null +++ b/examples/ldapauthserver/ldapauth.toml @@ -0,0 +1,33 @@ +[httpd] +bind_address = "" +bind_port = 9000 +# Path to a file used to store usernames and passwords for basic authentication. It can be generated using the Apache htpasswd tool +auth_user_file = "" +# If both the certificate and the private key are provided, the server will expect HTTPS connections +certificate_file = "" +certificate_key_file = "" + +[ldap] +basedn = "dc=example,dc=com" +bind_url = "ldap://127.0.0.1:389" +bind_username = "cn=Directory Manager" +bind_password = "YOUR_ADMIN_PASSWORD_HERE" +search_filter = "(&(objectClass=nsPerson)(uid=%s))" +# you can change the name of the search base attributes to adapt them to your schema but the order must remain the same +search_base_attrs = [ + "dn", + "homeDirectory", + "uidNumber", + "gidNumber", + "nsSshPublicKey" +] +default_uid = 0 +default_gid = 0 +force_default_uid = true +force_default_gid = true +# if true, ldaps accepts any certificate presented by the LDAP server and any host name in that certificate. 
+# This should be used only for testing +insecure_skip_verify = false +# list of root CA to use for ldaps connections +# If you use a self signed certificate is better to add the root CA to this list than set insecure_skip_verify to true +ca_certificates = [] diff --git a/examples/ldapauthserver/logger/logger.go b/examples/ldapauthserver/logger/logger.go new file mode 100644 index 00000000..998a8e7c --- /dev/null +++ b/examples/ldapauthserver/logger/logger.go @@ -0,0 +1,126 @@ +package logger + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + + "github.com/rs/zerolog" + lumberjack "gopkg.in/natefinch/lumberjack.v2" +) + +const ( + dateFormat = "2006-01-02T15:04:05.000" // YYYY-MM-DDTHH:MM:SS.ZZZ +) + +var ( + logger zerolog.Logger + consoleLogger zerolog.Logger +) + +// GetLogger get the configured logger instance +func GetLogger() *zerolog.Logger { + return &logger +} + +// InitLogger initialize loggers +func InitLogger(logFilePath string, logMaxSize, logMaxBackups, logMaxAge int, logCompress bool, level zerolog.Level) { + zerolog.TimeFieldFormat = dateFormat + if isLogFilePathValid(logFilePath) { + logger = zerolog.New(&lumberjack.Logger{ + Filename: logFilePath, + MaxSize: logMaxSize, + MaxBackups: logMaxBackups, + MaxAge: logMaxAge, + Compress: logCompress, + }) + EnableConsoleLogger(level) + } else { + logger = zerolog.New(&logSyncWrapper{ + output: os.Stdout, + }) + consoleLogger = zerolog.Nop() + } + logger = logger.Level(level) +} + +// DisableLogger disable the main logger. 
+// ConsoleLogger will not be affected +func DisableLogger() { + logger = zerolog.Nop() +} + +// EnableConsoleLogger enables the console logger +func EnableConsoleLogger(level zerolog.Level) { + consoleOutput := zerolog.ConsoleWriter{ + Out: os.Stdout, + TimeFormat: dateFormat, + NoColor: runtime.GOOS == "windows", + } + consoleLogger = zerolog.New(consoleOutput).With().Timestamp().Logger().Level(level) +} + +// Debug logs at debug level for the specified sender +func Debug(prefix, requestID string, format string, v ...interface{}) { + logger.Debug(). + Timestamp(). + Str("sender", prefix). + Str("request_id", requestID). + Msg(fmt.Sprintf(format, v...)) +} + +// Info logs at info level for the specified sender +func Info(prefix, requestID string, format string, v ...interface{}) { + logger.Info(). + Timestamp(). + Str("sender", prefix). + Str("request_id", requestID). + Msg(fmt.Sprintf(format, v...)) +} + +// Warn logs at warn level for the specified sender +func Warn(prefix, requestID string, format string, v ...interface{}) { + logger.Warn(). + Timestamp(). + Str("sender", prefix). + Str("request_id", requestID). + Msg(fmt.Sprintf(format, v...)) +} + +// Error logs at error level for the specified sender +func Error(prefix, requestID string, format string, v ...interface{}) { + logger.Error(). + Timestamp(). + Str("sender", prefix). + Str("request_id", requestID). 
+ Msg(fmt.Sprintf(format, v...)) +} + +// DebugToConsole logs at debug level to stdout +func DebugToConsole(format string, v ...interface{}) { + consoleLogger.Debug().Msg(fmt.Sprintf(format, v...)) +} + +// InfoToConsole logs at info level to stdout +func InfoToConsole(format string, v ...interface{}) { + consoleLogger.Info().Msg(fmt.Sprintf(format, v...)) +} + +// WarnToConsole logs at info level to stdout +func WarnToConsole(format string, v ...interface{}) { + consoleLogger.Warn().Msg(fmt.Sprintf(format, v...)) +} + +// ErrorToConsole logs at error level to stdout +func ErrorToConsole(format string, v ...interface{}) { + consoleLogger.Error().Msg(fmt.Sprintf(format, v...)) +} + +func isLogFilePathValid(logFilePath string) bool { + cleanInput := filepath.Clean(logFilePath) + if cleanInput == "." || cleanInput == ".." { + return false + } + return true +} diff --git a/logger/request_logger.go b/examples/ldapauthserver/logger/request_logger.go similarity index 71% rename from logger/request_logger.go rename to examples/ldapauthserver/logger/request_logger.go index a21cdb3a..e406d2e4 100644 --- a/logger/request_logger.go +++ b/examples/ldapauthserver/logger/request_logger.go @@ -5,7 +5,7 @@ import ( "net/http" "time" - "github.com/go-chi/chi/middleware" + "github.com/go-chi/chi/v5/middleware" "github.com/rs/zerolog" ) @@ -15,12 +15,9 @@ type StructuredLogger struct { Logger *zerolog.Logger } -// StructuredLoggerEntry defines a log entry. -// It implements chi.middleware.LogEntry interface +// StructuredLoggerEntry ... 
type StructuredLoggerEntry struct { - // The zerolog logger Logger *zerolog.Logger - // fields to write in the log fields map[string]interface{} } @@ -53,18 +50,24 @@ func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry { } // Write logs a new entry at the end of the HTTP request -func (l *StructuredLoggerEntry) Write(status, bytes int, elapsed time.Duration) { - l.Logger.Info().Fields(l.fields).Int( - "resp_status", status).Int( - "resp_size", bytes).Int64( - "elapsed_ms", elapsed.Nanoseconds()/1000000).Str( - "sender", "httpd").Msg( - "") +func (l *StructuredLoggerEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) { + l.Logger.Info(). + Timestamp(). + Str("sender", "httpd"). + Fields(l.fields). + Int("resp_status", status). + Int("resp_size", bytes). + Int64("elapsed_ms", elapsed.Nanoseconds()/1000000). + Send() } // Panic logs panics func (l *StructuredLoggerEntry) Panic(v interface{}, stack []byte) { - l.Logger.Error().Fields(l.fields).Str( - "stack", string(stack)).Str( - "panic", fmt.Sprintf("%+v", v)).Msg("") + l.Logger.Error(). + Timestamp(). + Str("sender", "httpd"). + Fields(l.fields). + Str("stack", string(stack)). + Str("panic", fmt.Sprintf("%+v", v)). 
+ Send() } diff --git a/examples/ldapauthserver/logger/sync_wrapper.go b/examples/ldapauthserver/logger/sync_wrapper.go new file mode 100644 index 00000000..c3737604 --- /dev/null +++ b/examples/ldapauthserver/logger/sync_wrapper.go @@ -0,0 +1,17 @@ +package logger + +import ( + "os" + "sync" +) + +type logSyncWrapper struct { + sync.Mutex + output *os.File +} + +func (l *logSyncWrapper) Write(b []byte) (n int, err error) { + l.Lock() + defer l.Unlock() + return l.output.Write(b) +} diff --git a/examples/ldapauthserver/main.go b/examples/ldapauthserver/main.go new file mode 100644 index 00000000..25521625 --- /dev/null +++ b/examples/ldapauthserver/main.go @@ -0,0 +1,7 @@ +package main + +import "github.com/drakkan/sftpgo/ldapauthserver/cmd" + +func main() { + cmd.Execute() +} diff --git a/examples/ldapauthserver/utils/utils.go b/examples/ldapauthserver/utils/utils.go new file mode 100644 index 00000000..62eb340b --- /dev/null +++ b/examples/ldapauthserver/utils/utils.go @@ -0,0 +1,28 @@ +package utils + +import ( + "path/filepath" + "strings" +) + +// IsFileInputValid returns true this is a valid file name. +// This method must be used before joining a file name, generally provided as +// user input, with a directory +func IsFileInputValid(fileInput string) bool { + cleanInput := filepath.Clean(fileInput) + if cleanInput == "." || cleanInput == ".." 
{ + return false + } + return true +} + +// IsStringPrefixInSlice searches a string prefix in a slice and returns true +// if a matching prefix is found +func IsStringPrefixInSlice(obj string, list []string) bool { + for _, v := range list { + if strings.HasPrefix(obj, v) { + return true + } + } + return false +} diff --git a/utils/version.go b/examples/ldapauthserver/utils/version.go similarity index 84% rename from utils/version.go rename to examples/ldapauthserver/utils/version.go index 4867300a..11de6f4f 100644 --- a/utils/version.go +++ b/examples/ldapauthserver/utils/version.go @@ -1,6 +1,6 @@ package utils -const version = "0.9.1" +const version = "0.1.0-dev" var ( commit = "" @@ -15,6 +15,14 @@ type VersionInfo struct { CommitHash string `json:"commit_hash"` } +func init() { + versionInfo = VersionInfo{ + Version: version, + CommitHash: commit, + BuildDate: date, + } +} + // GetVersionAsString returns the string representation of the VersionInfo struct func (v *VersionInfo) GetVersionAsString() string { versionString := v.Version @@ -27,10 +35,7 @@ func (v *VersionInfo) GetVersionAsString() string { return versionString } -func init() { - versionInfo = VersionInfo{ - Version: version, - CommitHash: commit, - BuildDate: date, - } +// GetAppVersion returns VersionInfo struct +func GetAppVersion() VersionInfo { + return versionInfo } diff --git a/examples/php-activedirectory-http-server/README.md b/examples/php-activedirectory-http-server/README.md new file mode 100644 index 00000000..48d8f1b1 --- /dev/null +++ b/examples/php-activedirectory-http-server/README.md @@ -0,0 +1,195 @@ +# SFTPGo on Windows with Active Directory Integration + Caddy Static File Server Example + +[![SFTPGo on Windows with Active Directory Integration + Caddy Static File Server Example](https://img.youtube.com/vi/M5UcJI8t4AI/0.jpg)](https://www.youtube.com/watch?v=M5UcJI8t4AI) + +This is similar to the ldapauthserver example, but is more specific to using Active Directory along with 
using SFTPGo on a Windows Server. + +The Youtube Walkthrough/Tutorial video above goes into considerably more detail, but in short, it walks through setting up SFTPGo on a new Windows Server, and enables the External Authentication feature within SFTPGo, along with my `sftpgo-ldap-http-server` project, to allow for user authentication into SFTPGo to occur through one or more Active Directory connections. + +Additionally, I go through using the Caddy web server, to help enable serving of static files, if this is something that would be of interest for you. + +To get started, you'll want to download the latest release ZIP package from the [sftpgo-ldap-http-server repository](https://github.com/orware/sftpgo-ldap-http-server). + +The ZIP itself contains the `sftpgo-ldap-http-server.exe` file, along with an `OpenLDAP` folder (mainly to help if you want to use TLS for your LDAP connections), and a `Data` folder which contains a logs folder, a configuration.example.php file, a functions.php file, and the LICENSE and README files. + +The video above goes through the whole process, but to get started you'll want to install SFTPGo on your server, and then extract the `sftpgo-ldap-http-server` ZIP file on the server as well into a separate folder. Then you'll want to copy the configuration.example.php file and name it `configuration.php` and begin customizing the settings (e.g. add in your own LDAP settings, along with how you may want to have your folders be created). At the very minimum you'll want to make sure that the home directories are set correctly to how you want the folders to be created for your environment (you don't have to use the virtual folders or really any of the other functionality if you don't need it). 
+ +Once configured, from a command prompt window, if you are already in the same folder as where you extracted the `sftpgo-ldap-http-server` ZIP, you may simply call the `sftpgo-ldap-http-server.exe` and it should start up a simple HTTP server on Port 9001 running on localhost (the port can be adjusted via the `configuration.php` file as well). Now all you have to do is point SFTPGo's `external_auth_hook` option to point to `http://localhost:9001/` and you should be able to run some authentication tests (assuming you have all of your settings correct and there are no intermediate issues). + +The video above definitely goes through some troubleshooting situations you might find yourself coming across, so while it is long (at about 1 hour, 42 minutes), it may be helpful to review and avoid some issues and just to learn a bit more about SFTPGo and the integration above. + +## Example Virtual Folders Configuration (Allowing for Both a Public and Private Folder) + +The following can be utilized if you'd like to assign your users both a Private Virtual Folder and Public Virtual Folder. + +By itself, the Public Virtual Folder isn't necessarily public, so keep that in mind. Only by combining things together with the Caddy web server (and Caddyfile example configuration down below) can you be successful in making the `F:\files\public` folder from the example public. 
+ +```php +$virtual_folders['example'] = [ + [ + //"id" => 0, + "name" => "private-#USERNAME#", + "mapped_path" => 'F:\files\private\#USERNAME#', + //"used_quota_size" => 0, + //"used_quota_files" => 0, + //"last_quota_update" => 0, + "virtual_path" => "/_private", + "quota_size" => -1, + "quota_files" => -1 + ], + [ + //"id" => 0, + "name" => "public-#USERNAME#", + "mapped_path" => 'F:\files\public\#USERNAME#', + //"used_quota_size" => 0, + //"used_quota_files" => 0, + //"last_quota_update" => 0, + "virtual_path" => "/_public", + "quota_size" => -1, + "quota_files" => -1 + ] +]; +``` + +## Example Connection "Output Object" Allowing For No Files in the User's Home Directory ("Root Directory") but Allowing for Files in the Public/Private Virtual Folders + +The magic here happens in the "permissions" value, by limiting the root/home directory to just the list/download permissions, and then allowing all permissions on the Public/Private virtual folders. + +```php +$connection_output_objects['example'] = [ + 'status' => 1, + 'username' => '', + 'expiration_date' => 0, + 'home_dir' => '', + 'uid' => 0, + 'gid' => 0, + 'max_sessions' => 0, + 'quota_size' => 0, + 'quota_files' => 100000, + 'permissions' => [ + "/" => ["list", "download"], + "/_private" => ["*"], + "/_public" => ["*"], + ], + 'upload_bandwidth' => 0, + 'download_bandwidth' => 0, + 'filters' => [ + 'allowed_ip' => [], + 'denied_ip' => [], + ], + 'public_keys' => [], +]; +``` + +## Recommended Usage of Automatic Groups Mode (Limiting by Group Prefix) + +The `sftpgo-ldap-http-server` project is able to automatically create virtual folders for any groups your user is a memberof if the automatic mode is turned on. However, by having a specific set of allowed prefixes defined, you can limit things to just those groups that begin with the prefixes you've listed, which can be helpful. The prefix itself will be removed from the group name when added as a virtual folder for the user. 
+ +```php +// If automatic groups mode is disabled, then you have to manually add the allowed groups into $allowed_groups down below: +// If enabled, then any groups you are a memberof will automatically be added in using the template below. +$auto_groups_mode = true; + +$auto_groups_mode_virtual_folder_template = [ + [ + //"id" => 0, + "name" => "groups-#GROUP#", + "mapped_path" => 'F:\files\groups\#GROUP#', + //"used_quota_size" => 0, + //"used_quota_files" => 0, + //"last_quota_update" => 0, + "virtual_path" => "/groups/#GROUP#", + "quota_size" => 0, + "quota_files" => 100000 + ] +]; + +// Used only when auto groups mode is enabled and will help prevent all your groups from being +// added into SFTPGo since only groups with the prefixes defined here will be automatically added +// with prefixes automatically removed when listed as a virtual folder (e.g. a group with name +// "sftpgo-example" would simply become "example"). +$allowed_group_prefixes = [ + 'sftpgo-' +]; +``` + +## Example Caddyfile Configuration You Can Adapt for Your Needs + +```shell +### Re-usable snippets: + +(add_static_file_serving_features) { + + # Allow accessing files without requiring .html: + try_files {path} {path}.html + + # Enable Static File Server and Directory Browsing: + file_server browse + + # Enable templating functionality: + templates + + # Enable Compression for Output: + encode zstd gzip + + handle_errors { + respond "
{http.error.status_code} {http.error.status_text}
" + } +} + +(add_hsts_headers) { + header { + # Enable HTTP Strict Transport Security (HSTS) to force clients to always + + # connect via HTTPS (do not use if only testing) + Strict-Transport-Security "max-age=31536000; includeSubDomains" + + # Enable cross-site filter (XSS) and tell browser to block detected attacks + X-XSS-Protection "1; mode=block" + + # Prevent some browsers from MIME-sniffing a response away from the declared Content-Type + X-Content-Type-Options "nosniff" + + # Disallow the site to be rendered within a frame (clickjacking protection) + X-Frame-Options "DENY" + + # keep referrer data off of HTTP connections + Referrer-Policy no-referrer-when-downgrade + } +} + +(add_logging_with_path) { + log { + output file "{args.0}" { + roll_size 100mb + roll_keep 5 + roll_keep_for 720h + } + + format json + #format console + #format single_field common_log + } +} + +### Site Definitions: + +public.example.com { + + # Site Root: + root * F:\files\public + + import add_logging_with_path "F:\caddy\logs\public_example_com_access.log" + import add_static_file_serving_features + import add_hsts_headers +} + + +### Reverse Proxy Definitions: + +webdav.example.com { + reverse_proxy localhost:9000 + + import add_logging_with_path "F:\caddy\logs\webdav_example_com_access.log" +} +``` diff --git a/examples/quotascan/README.md b/examples/quotascan/README.md new file mode 100644 index 00000000..0830a308 --- /dev/null +++ b/examples/quotascan/README.md @@ -0,0 +1,23 @@ +# Update user quota + +:warning: Since v2.4.0 you can use the [EventManager](https://docs.sftpgo.com/latest/eventmanager/) to schedule quota scans. + +The `scanuserquota` example script shows how to use the SFTPGo REST API to update the users' quota. + +The stored quota may be incorrect for several reasons, such as an unexpected shutdown while uploading files, temporary provider failures, files copied outside of SFTPGo, and so on. 
+ +A quota scan updates the number of files and their total size for the specified user and the virtual folders, if any, included in his quota. + +If you want to track quotas, a scheduled quota scan is recommended. You can use this example as a starting point. + +The script is written in Python and has the following requirements: + +- python3 or python2 +- python [Requests](https://requests.readthedocs.io/en/master/) module + +The provided example tries to connect to an SFTPGo instance running on `127.0.0.1:8080` using the following credentials: + +- username: `admin` +- password: `password` + +Please edit the script according to your needs. diff --git a/examples/quotascan/scanuserquota b/examples/quotascan/scanuserquota new file mode 100755 index 00000000..7648bdd5 --- /dev/null +++ b/examples/quotascan/scanuserquota @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +from datetime import datetime +import sys +import time + +import pytz +import requests + +try: + import urllib.parse as urlparse +except ImportError: + import urlparse + +# change base_url to point to your SFTPGo installation +base_url = "http://127.0.0.1:8080" +# set to False if you want to skip TLS certificate validation +verify_tls_cert = True +# set the credentials for a valid admin here +admin_user = "admin" +admin_password = "password" + + +# set your update conditions here +def needQuotaUpdate(user): + if user["status"] == 0: # inactive user + return False + if user["quota_size"] == 0 and user["quota_files"] == 0: # no quota restrictions + return False + return True + + +class UpdateQuota: + + def __init__(self): + self.limit = 100 + self.offset = 0 + self.access_token = "" + self.access_token_expiration = None + + def printLog(self, message): + print("{} - {}".format(datetime.now(), message)) + + def checkAccessToken(self): + if self.access_token != "" and self.access_token_expiration: + expire_diff = self.access_token_expiration - datetime.now(tz=pytz.UTC) + # we don't use total_seconds to be python 
2 compatible + seconds_to_expire = expire_diff.days * 86400 + expire_diff.seconds + if seconds_to_expire > 180: + return + + auth = requests.auth.HTTPBasicAuth(admin_user, admin_password) + r = requests.get(urlparse.urljoin(base_url, "api/v2/token"), auth=auth, verify=verify_tls_cert, timeout=10) + if r.status_code != 200: + self.printLog("error getting access token: {}".format(r.text)) + sys.exit(1) + self.access_token = r.json()["access_token"] + self.access_token_expiration = pytz.timezone("UTC").localize(datetime.strptime(r.json()["expires_at"], + "%Y-%m-%dT%H:%M:%SZ")) + + def getAuthHeader(self): + self.checkAccessToken() + return {"Authorization": "Bearer " + self.access_token} + + def waitForQuotaUpdate(self, username): + while True: + auth_header = self.getAuthHeader() + r = requests.get(urlparse.urljoin(base_url, "api/v2/quotas/users/scans"), headers=auth_header, verify=verify_tls_cert, + timeout=10) + if r.status_code != 200: + self.printLog("error getting quota scans while waiting for {}: {}".format(username, r.text)) + sys.exit(1) + + scanning = False + for scan in r.json(): + if scan["username"] == username: + scanning = True + if not scanning: + break + self.printLog("waiting for the quota scan to complete for user {}".format(username)) + time.sleep(2) + + self.printLog("quota update for user {} finished".format(username)) + + def updateUserQuota(self, username): + self.printLog("starting quota update for user {}".format(username)) + auth_header = self.getAuthHeader() + r = requests.post(urlparse.urljoin(base_url, "api/v2/quotas/users/" + username + "/scan"), headers=auth_header, + verify=verify_tls_cert, timeout=10) + if r.status_code != 202: + self.printLog("error starting quota scan for user {}: {}".format(username, r.text)) + sys.exit(1) + self.waitForQuotaUpdate(username) + + def updateUsersQuota(self): + while True: + self.printLog("get users, limit {} offset {}".format(self.limit, self.offset)) + auth_header = self.getAuthHeader() + payload = 
{"limit":self.limit, "offset":self.offset} + r = requests.get(urlparse.urljoin(base_url, "api/v2/users"), headers=auth_header, params=payload, + verify=verify_tls_cert, timeout=10) + if r.status_code != 200: + self.printLog("error getting users: {}".format(r.text)) + sys.exit(1) + users = r.json() + for user in users: + if needQuotaUpdate(user): + self.updateUserQuota(user["username"]) + else: + self.printLog("user {} does not need a quota update".format(user["username"])) + + self.offset += len(users) + if len(users) < self.limit: + break + + +if __name__ == '__main__': + q = UpdateQuota() + q.updateUsersQuota() diff --git a/go.mod b/go.mod index f3438371..d2fafb83 100644 --- a/go.mod +++ b/go.mod @@ -1,27 +1,185 @@ -module github.com/drakkan/sftpgo +module github.com/drakkan/sftpgo/v2 -go 1.12 +go 1.25.0 require ( - github.com/alexedwards/argon2id v0.0.0-20190612080829-01a59b2b8802 - github.com/go-chi/chi v4.0.2+incompatible - github.com/go-chi/render v1.0.1 - github.com/go-sql-driver/mysql v1.4.1 - github.com/lib/pq v1.2.0 - github.com/magiconair/properties v1.8.1 // indirect - github.com/mattn/go-sqlite3 v1.11.0 - github.com/pelletier/go-toml v1.4.0 // indirect - github.com/pkg/sftp v1.10.1 - github.com/rs/xid v1.2.1 - github.com/rs/zerolog v1.15.0 - github.com/spf13/afero v1.2.2 // indirect - github.com/spf13/cobra v0.0.5 - github.com/spf13/jwalterweatherman v1.1.0 // indirect - github.com/spf13/viper v1.4.0 - go.etcd.io/bbolt v1.3.3 - golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 - golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 // indirect - golang.org/x/sys v0.0.0-20190830142957-1e83adbbebd0 // indirect - google.golang.org/appengine v1.6.2 // indirect - gopkg.in/natefinch/lumberjack.v2 v2.0.0 + cloud.google.com/go/storage v1.60.0 + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 + github.com/GehirnInc/crypt 
v0.0.0-20230320061759-8cc1b52080c5 + github.com/alexedwards/argon2id v1.0.0 + github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964 + github.com/aws/aws-sdk-go-v2 v1.41.3 + github.com/aws/aws-sdk-go-v2/config v1.32.11 + github.com/aws/aws-sdk-go-v2/credentials v1.19.11 + github.com/aws/aws-sdk-go-v2/service/s3 v1.96.4 + github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 + github.com/bmatcuk/doublestar/v4 v4.10.0 + github.com/cockroachdb/cockroach-go/v2 v2.4.3 + github.com/coreos/go-oidc/v3 v3.17.0 + github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b + github.com/eikenb/pipeat v0.0.0-20251030185646-385cd3c3e07b + github.com/fclairamb/ftpserverlib v0.30.0 + github.com/go-acme/lego/v4 v4.32.0 + github.com/go-chi/chi/v5 v5.2.5 + github.com/go-chi/render v1.0.3 + github.com/go-jose/go-jose/v4 v4.1.3 + github.com/go-sql-driver/mysql v1.9.3 + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 + github.com/google/uuid v1.6.0 + github.com/hashicorp/go-hclog v1.6.3 + github.com/hashicorp/go-plugin v1.7.0 + github.com/hashicorp/go-retryablehttp v0.7.8 + github.com/jackc/pgx/v5 v5.8.0 + github.com/jlaffaye/ftp v0.2.0 + github.com/klauspost/compress v1.18.4 + github.com/lithammer/shortuuid/v4 v4.2.0 + github.com/mattn/go-sqlite3 v1.14.34 + github.com/mhale/smtpd v0.8.3 + github.com/minio/sio v0.4.3 + github.com/otiai10/copy v1.14.1 + github.com/pires/go-proxyproto v0.11.0 + github.com/pkg/sftp v1.13.10 + github.com/pquerna/otp v1.5.0 + github.com/prometheus/client_golang v1.23.2 + github.com/robfig/cron/v3 v3.0.1 + github.com/rs/cors v1.11.1 + github.com/rs/xid v1.6.0 + github.com/rs/zerolog v1.34.0 + github.com/sftpgo/sdk v0.1.9 + github.com/shirou/gopsutil/v3 v3.24.5 + github.com/spf13/afero v1.15.0 + github.com/spf13/cobra v1.10.2 + github.com/spf13/viper v1.21.0 + github.com/stretchr/testify v1.11.1 + github.com/studio-b12/gowebdav v0.12.0 + github.com/subosito/gotenv v1.6.0 + github.com/unrolled/secure v1.17.0 + 
github.com/wagslane/go-password-validator v0.3.0 + github.com/wneessen/go-mail v0.7.2 + github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a + go.etcd.io/bbolt v1.4.3 + gocloud.dev v0.45.0 + golang.org/x/crypto v0.49.0 + golang.org/x/net v0.52.0 + golang.org/x/oauth2 v0.36.0 + golang.org/x/sys v0.42.0 + golang.org/x/term v0.41.0 + golang.org/x/time v0.15.0 + google.golang.org/api v0.271.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 +) + +require ( + cel.dev/expr v0.25.1 // indirect + cloud.google.com/go v0.123.0 // indirect + cloud.google.com/go/auth v0.18.2 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + cloud.google.com/go/iam v1.5.3 // indirect + cloud.google.com/go/monitoring v1.24.3 // indirect + filippo.io/edwards25519 v1.2.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.7.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect + github.com/ajg/form v1.7.1 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.20 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.11 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 // indirect + 
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.19 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect + github.com/aws/smithy-go v1.24.2 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/boombuler/barcode v1.1.0 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/envoyproxy/go-control-plane/envoy v1.37.0 // indirect + github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect + github.com/fatih/color v1.18.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-ole/go-ole v1.3.0 // indirect + github.com/go-viper/mapstructure/v2 v2.5.0 // indirect + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect + github.com/googleapis/gax-go/v2 v2.18.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/hashicorp/yamux v0.1.2 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/kr/fs v0.1.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + 
github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/miekg/dns v1.1.72 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/oklog/run v1.2.0 // indirect + github.com/otiai10/mint v1.6.3 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.67.5 // indirect + github.com/prometheus/procfs v0.20.1 // indirect + github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/sagikazarmark/locafero v0.12.0 // indirect + github.com/shoenig/go-m1cpu v0.2.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect + github.com/tklauser/go-sysconf v0.3.16 // indirect + github.com/tklauser/numcpus v0.11.0 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/detectors/gcp v1.42.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect + go.opentelemetry.io/otel v1.42.0 // indirect + go.opentelemetry.io/otel/metric v1.42.0 // indirect + go.opentelemetry.io/otel/sdk v1.42.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.42.0 // indirect + go.opentelemetry.io/otel/trace v1.42.0 // indirect + go.yaml.in/yaml/v2 v2.4.4 // indirect + go.yaml.in/yaml/v3 v3.0.4 // 
indirect + golang.org/x/mod v0.34.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/text v0.35.0 // indirect + golang.org/x/tools v0.43.0 // indirect + golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect + google.golang.org/genproto v0.0.0-20260311181403-84a4fc48630c // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260311181403-84a4fc48630c // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c // indirect + google.golang.org/grpc v1.79.2 // indirect + google.golang.org/protobuf v1.36.11 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace ( + github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f + github.com/robfig/cron/v3 => github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 ) diff --git a/go.sum b/go.sum index 512b8d6c..d81c977d 100644 --- a/go.sum +++ b/go.sum @@ -1,211 +1,477 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= -github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= -github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/alexedwards/argon2id v0.0.0-20190612080829-01a59b2b8802 h1:RwMM1q/QSKYIGbHfOkf843hE8sSUJtf1dMwFPtEDmm0= -github.com/alexedwards/argon2id v0.0.0-20190612080829-01a59b2b8802/go.mod h1:4dsm7ufQm1Gwl8S2ss57u+2J7KlxIL2QUmFGlGtWogY= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= 
-github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= -github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= +cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= +cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= +cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM= +cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +cloud.google.com/go/iam v1.5.3 
h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= +cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= +cloud.google.com/go/kms v1.26.0 h1:cK9mN2cf+9V63D3H1f6koxTatWy39aTI/hCjz1I+adU= +cloud.google.com/go/kms v1.26.0/go.mod h1:pHKOdFJm63hxBsiPkYtowZPltu9dW0MWvBa6IA4HM58= +cloud.google.com/go/logging v1.13.2 h1:qqlHCBvieJT9Cdq4QqYx1KPadCQ2noD4FK02eNqHAjA= +cloud.google.com/go/logging v1.13.2/go.mod h1:zaybliM3yun1J8mU2dVQ1/qDzjbOqEijZCn6hSBtKak= +cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= +cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= +cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= +cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI= +cloud.google.com/go/storage v1.60.0 h1:oBfZrSOCimggVNz9Y/bXY35uUcts7OViubeddTTVzQ8= +cloud.google.com/go/storage v1.60.0/go.mod h1:q+5196hXfejkctrnx+VYU8RKQr/L3c0cBIlrjmiAKE0= +cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U= +cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= +github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 
h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4 h1:jWQK1GI+LeGGUKBADtcH2rRqPxYB1Ljwms5gFA2LqrM= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.4/go.mod h1:8mwH4klAm9DUgR2EEHyEEAQlRDvLPyg5fQry3y+cDew= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.7.0 h1:4iB+IesclUXdP0ICgAabvq2FYLXrJWKx1fJQ+GxSo3Y= +github.com/AzureAD/microsoft-authentication-library-for-go v1.7.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5 h1:IEjq88XO4PuBDcvmjQJcQGg+w+UaafSy8G5Kcb5tBhI= +github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5/go.mod h1:exZ0C/1emQJAw5tHOaUDyY1ycttqBAPcxuzf7QbY6ec= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 h1:DHa2U07rk8syqvCge0QIGMCE1WxGj9njT44GH7zNJLQ= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 h1:UnDZ/zFfG1JhH/DqxIZYU/1CUAlTUScoXD/LcM2Ykk8= 
+github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0/go.mod h1:IA1C1U7jO/ENqm/vhi7V9YYpBsp+IMyqNrEN94N7tVc= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0 h1:7t/qx5Ost0s0wbA/VDrByOooURhp+ikYwv20i9Y07TQ= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0/go.mod h1:vB2GH9GAYYJTO3mEn8oYwzEdhlayZIdQz6zdzgUIRvA= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 h1:0s6TxfCu2KHkkZPnBfsQ2y5qia0jl3MMrmBhu3nCOYk= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc= +github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/ajg/form v1.7.1 h1:OsnBDzTkrWdrxvEnO68I72ZVGJGNaMwPhoAm0V+llgc= +github.com/ajg/form v1.7.1/go.mod h1:HL757PzLyNkj5AIfptT6L+iGNeXTlnrr/oDePGc/y7Q= +github.com/alexedwards/argon2id v1.0.0 h1:wJzDx66hqWX7siL/SRUmgz3F8YMrd/nfX/xHHcQQP0w= +github.com/alexedwards/argon2id v1.0.0/go.mod h1:tYKkqIjzXvZdzPvADMWOEZ+l6+BD6CtBXMj5fnJppiw= +github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964 h1:I9YN9WMo3SUh7p/4wKeNvD/IQla3U3SUa61U7ul+xM4= +github.com/amoghe/go-crypt v0.0.0-20220222110647-20eada5f5964/go.mod h1:eFiR01PwTcpbzXtdMces7zxg6utvFM5puiWHpWB8D/k= +github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= +github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 h1:N4lRUXZpZ1KVEUn6hxtco/1d2lgYhNn1fHkkl8WhlyQ= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= +github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= +github.com/aws/aws-sdk-go-v2/config v1.32.11/go.mod h1:twF11+6ps9aNRKEDimksp923o44w/Thk9+8YIlzWMmo= 
+github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11/go.mod h1:30yY2zqkMPdrvxBqzI9xQCM+WrlrZKSOpSJEsylVU+8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 h1:INUvJxmhdEbVulJYHI061k4TVuS3jzzthNvjqvVvTKM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19/go.mod h1:FpZN2QISLdEBWkayloda+sZjVJL+e9Gl0k1SyTgcswU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 h1:/sECfyq2JTifMI2JPyZ4bdRN77zJmr6SrS1eL3augIA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19/go.mod h1:dMf8A5oAqr9/oxOfLkC/c2LU/uMcALP0Rgn2BD5LWn0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 h1:AWeJMk33GTBf6J20XJe6qZoRSJo0WfUhsMdUKhoODXE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19/go.mod h1:+GWrYoaAsV7/4pNHpwh1kiNLXkKaSoppxQq9lbH8Ejw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.20 h1:qi3e/dmpdONhj1RyIZdi6DKKpDXS5Lb8ftr3p7cyHJc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.20/go.mod h1:V1K+TeJVD5JOk3D9e5tsX2KUdL7BlB+FV6cBhdobN8c= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 h1:XAq62tBTJP/85lFD5oqOOe7YYgWxY9LvWq8plyDvDVg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.11 h1:BYf7XNsJMzl4mObARUBUib+j2tf0U//JAAtTnYqvqCw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.11/go.mod h1:aEUS4WrNk/+FxkBZZa7tVgp4pGH+kFGW40Y8rCPqt5g= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 h1:X1Tow7suZk9UCJHE1Iw9GMZJJl0dAnKXXP1NaSDHwmw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19/go.mod h1:/rARO8psX+4sfjUQXp5LLifjUt8DuATZ31WptNJTyQA= 
+github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.19 h1:JnQeStZvPHFHeyky/7LbMlyQjUa+jIBj36OlWm0pzIk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.19/go.mod h1:HGyasyHvYdFQeJhvDHfH7HXkHh57htcJGKDZ+7z+I24= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.4 h1:4ExZyubQ6LQQVuF2Qp9OsfEvsTdAWh5Gfwf6PgIdLdk= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.4/go.mod h1:NF3JcMGOiARAss1ld3WGORCw71+4ExDD2cbbdKS5PpA= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 h1:Y2cAXlClHsXkkOvWZFXATr34b0hxxloeQu/pAZz2row= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.7/go.mod h1:idzZ7gmDeqeNrSPkdbtMp9qWMgcBwykA7P7Rzh5DXVU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 h1:iSsvB9EtQ09YrsmIc44Heqlx5ByGErqhPK1ZQLppias= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12/go.mod h1:fEWYKTRGoZNl8tZ77i61/ccwOMJdGxwOhWCkp6TXAr0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 h1:EnUdUqRP1CNzt2DkV67tJx6XDN4xlfBFm+bzeNOQVb0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16/go.mod h1:Jic/xv0Rq/pFNCh3WwpH4BEqdbSAl+IyHro8LbibHD8= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 h1:XQTQTF75vnug2TXS8m7CVJfC2nniYPZnO1D4Np761Oo= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8/go.mod h1:Xgx+PR1NUOjNmQY+tRMnouRp83JRM8pRMw/vCaVhPkI= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bmatcuk/doublestar/v4 v4.10.0 h1:zU9WiOla1YA122oLM6i4EXvGW62DvKZVxIe6TYWexEs= +github.com/bmatcuk/doublestar/v4 v4.10.0/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo= 
+github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/FBatYVw= +github.com/bufbuild/protocompile v0.14.1/go.mod h1:ppVdAIhbr2H8asPk6k4pY7t9zB1OU5DoEw9xY/FUi1c= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik= +github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4= +github.com/cockroachdb/cockroach-go/v2 v2.4.3 h1:LJO3K3jC5WXvMePRQSJE1NsIGoFGcEx1LW83W6RAlhw= +github.com/cockroachdb/cockroach-go/v2 v2.4.3/go.mod h1:9U179XbCx4qFWtNhc7BiWLPfuyMVQ7qdAhfrwLz1vH0= +github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= +github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= +github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= 
-github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= -github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/go-chi/chi v4.0.2+incompatible h1:maB6vn6FqCxrpz4FqWdh4+lwpyZIQS7YEAUcHlgXVRs= -github.com/go-chi/chi v4.0.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= -github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= -github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns= -github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= -github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= -github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= -github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/btree v1.0.0/go.mod 
h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= -github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= -github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= -github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= -github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= -github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0 h1:EW9gIJRmt9lzk66Fhh4S8VEtURA6QHZqGeSRE9Nb2/U= +github.com/drakkan/cron/v3 v3.0.0-20230222140221-217a1e4d96c0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/drakkan/ftp v0.0.0-20240430173938-7ba8270c8e7f h1:S9JUlrOzjK58UKoLqqb40YLyVlt0bcIFtYrvnanV3zc= +github.com/drakkan/ftp 
v0.0.0-20240430173938-7ba8270c8e7f/go.mod h1:4p8lUl4vQ80L598CygL+3IFtm+3nggvvW/palOlViwE= +github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b h1:Y1tLiQ8fnxM5f3wiBjAXsHzHNwiY9BR+mXZA75nZwrs= +github.com/drakkan/webdav v0.0.0-20241026165615-b8b8f74ae71b/go.mod h1:zOVb1QDhwwqWn2L2qZ0U3swMSO4GTSNyIwXCGO/UGWE= +github.com/eikenb/pipeat v0.0.0-20251030185646-385cd3c3e07b h1:G2Mm3YhlyjkFrNnvu5E6LtNcPJtggXL1i5ekDV4hDD4= +github.com/eikenb/pipeat v0.0.0-20251030185646-385cd3c3e07b/go.mod h1:XccPiThW83W5pzeOCsJAylEUtWeH+3zQVwiO402FXXc= +github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA= +github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU= +github.com/envoyproxy/go-control-plane/envoy v1.37.0 h1:u3riX6BoYRfF4Dr7dwSOroNfdSbEPe9Yyl09/B6wBrQ= +github.com/envoyproxy/go-control-plane/envoy v1.37.0/go.mod h1:DReE9MMrmecPy+YvQOAOHNYMALuowAnbjjEMkkWOi6A= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= +github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMDjpqGAGacLe2T0ds= +github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/fclairamb/ftpserverlib v0.30.0 h1:caB9sDn1Au//q0j2ev/icPn388qPuk4k1ajSvglDcMQ= +github.com/fclairamb/ftpserverlib v0.30.0/go.mod h1:QmogtltTOgkihyKza0GNo37Mu4AEzbJ+sH6W9Y0MBIQ= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= 
+github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-acme/lego/v4 v4.32.0 h1:z7Ss7aa1noabhKj+DBzhNCO2SM96xhE3b0ucVW3x8Tc= +github.com/go-acme/lego/v4 v4.32.0/go.mod h1:lI2fZNdgeM/ymf9xQ9YKbgZm6MeDuf91UrohMQE4DhI= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= +github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= +github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/go-viper/mapstructure/v2 v2.5.0 
h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= +github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= +github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= +github.com/google/wire v0.7.0/go.mod h1:n6YbUQD9cPKTnHXEBN2DXlOp/mVADhVErcMFb0v3J18= +github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= +github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod 
h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= +github.com/googleapis/gax-go/v2 v2.18.0 h1:jxP5Uuo3bxm3M6gGtV94P4lliVetoCB4Wk2x8QA86LI= +github.com/googleapis/gax-go/v2 v2.18.0/go.mod h1:uSzZN4a356eRG985CzJ3WfbFSpqkLTjsnhWGJR6EwrE= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-plugin v1.7.0 h1:YghfQH/0QmPNc/AZMTFE3ac8fipZyZECHdDPshfk+mA= +github.com/hashicorp/go-plugin v1.7.0/go.mod h1:BExt6KEaIYx804z8k4gRzRLEvxKVb+kn0NMcihqOqb8= +github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= +github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= +github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94= +github.com/jhump/protoreflect v1.17.0/go.mod h1:h9+vUUL38jiBzck8ck+6G/aeMX8Z4QUY/NiJPwPNi+8= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= +github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= -github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/magiconair/properties v1.8.0 
h1:LLgXmsheXeRoUOBOjtwPQCWIYqM/LU1ayDtDePerRcY= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4= -github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= -github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= -github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= -github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml v1.4.0 h1:u3Z1r+oOXJIkxqw34zVhyPgjBsm6X2wn21NWs/HfSeg= -github.com/pelletier/go-toml v1.4.0/go.mod h1:PN7xzY2wHTK0K9p34ErDQMlFxa51Fk0OUruD3k1mMwo= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/sftp v1.10.1 h1:VasscCm72135zRysgrJDKsntdmPN+OuU3+nnHYA9wyc= -github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/kr/pretty v0.3.1 
h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lithammer/shortuuid/v4 v4.2.0 h1:LMFOzVB3996a7b8aBuEXxqOBflbfPQAiVzkIcHO0h8c= +github.com/lithammer/shortuuid/v4 v4.2.0/go.mod h1:D5noHZ2oFw/YaKCfGy0YxyE7M0wMbezmMjPdhyEFe6Y= +github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88 h1:PTw+yKnXcOFCR6+8hHTyWBeQ/P4Nb7dd4/0ohEcWQuM= +github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod 
h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= +github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mhale/smtpd v0.8.3 h1:8j8YNXajksoSLZja3HdwvYVZPuJSqAxFsib3adzRRt8= +github.com/mhale/smtpd v0.8.3/go.mod h1:MQl+y2hwIEQCXtNhe5+55n0GZOjSmeqORDIXbqUL3x4= +github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= +github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= +github.com/minio/sio v0.4.3 h1:JqyID1XM86KwBZox5RAdLD4MLPIDoCY2cke2CXCJCkg= +github.com/minio/sio v0.4.3/go.mod h1:4ANoe4CCXqnt1FCiLM0+vlBUhhWZzVOhYCz0069KtFc= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/oklog/run v1.2.0 h1:O8x3yXwah4A73hJdlrwo/2X6J62gE5qTMusH0dvz60E= +github.com/oklog/run v1.2.0/go.mod h1:mgDbKRSwPhJfesJ4PntqFUbKQRZ50NgmZTSPlFA0YFk= +github.com/otiai10/copy v1.14.1 h1:5/7E6qsUMBaH5AnQ0sSLzzTg1oTECmcCmT6lvF45Na8= +github.com/otiai10/copy v1.14.1/go.mod h1:oQwrEDDOci3IM8dJF0d8+jnbfPDllW6vUjNc3DoZm9I= +github.com/otiai10/mint v1.6.3 h1:87qsV/aw1F5as1eH1zS/yqHY85ANKVMgkDrf9rcxbQs= +github.com/otiai10/mint v1.6.3/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4= +github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod 
h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/sftp v1.13.10 h1:+5FbKNTe5Z9aspU88DPIKJ9z2KZoaGCu6Sr6kKR/5mU= +github.com/pkg/sftp v1.13.10/go.mod h1:bJ1a7uDhrX/4OII+agvy28lzRvQrmIQuaHrcI1HbeGA= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/common v0.0.0-20181113130724-41aa239b4cce/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/common v0.4.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/prometheus/procfs v0.0.0-20190507164030-5867b95ac084/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= -github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= -github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc= -github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.15.0 h1:uPRuwkWF4J6fGsJ2R0Gn2jB1EQiav9k3S6CSdygQJXY= 
-github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= -github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= -github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= -github.com/spf13/afero v1.1.2 h1:m8/z1t7/fwjysjQRYbP0RD+bUIF/8tJwPdEZsI83ACI= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/afero v1.2.2 h1:5jhuqJyZCZf2JRofRvN/nIFgIWNzPa3/Vz8mYylgbWc= -github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= -github.com/spf13/cast v1.3.0 h1:oget//CVOEoFewqQxwr0Ej5yjygnqGkvggSE/gB35Q8= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v0.0.5 h1:f0B+LkLX6DtmRH1isoNA9VTtNUK9K8xYd28JNNfOv/s= -github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/jwalterweatherman v1.0.0 h1:XHEdyB+EcvlqZamSM4ZOMGlc93t6AcsBEu9Gc1vn7yk= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= -github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmqnqMYADFk= -github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo= -github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= -github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= -github.com/spf13/viper v1.4.0 h1:yXHLWeravcrgGyFSyCgdYpXQ9dR9c/WED3pg1RhxqEU= -github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 
h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= +github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= +github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= +github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc= +github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= 
+github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= +github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= +github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4 h1:PT+ElG/UUFMfqy5HrxJxNzj3QBOf7dZwupeVC+mG1Lo= +github.com/secsy/goftp v0.0.0-20200609142545-aa2de14babf4/go.mod h1:MnkX001NG75g3p8bhFycnyIjeQoOjGL6CEIsdE/nKSY= +github.com/sftpgo/sdk v0.1.9 h1:onBWfibCt34xHeKC2KFYPZ1DBqXGl9um/cAw+AVdgzY= +github.com/sftpgo/sdk v0.1.9/go.mod h1:ehimvlTP+XTEiE3t1CPwWx9n7+6A6OGvMGlZ7ouvKFk= +github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= +github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= +github.com/shoenig/go-m1cpu v0.2.0 h1:t4GNqvPZ84Vjtpboo/kT3pIkbaK3vc+JIlD/Wz1zSFY= +github.com/shoenig/go-m1cpu v0.2.0/go.mod h1:KkDOw6m3ZJQAPHbrzkZki4hnx+pDRR1Lo+ldA56wD5w= +github.com/shoenig/test v1.7.0 h1:eWcHtTXa6QLnBvm0jgEabMRN/uJ4DMV3M8xUGgRkZmk= +github.com/shoenig/test v1.7.0/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod 
h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= +github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= -github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= -github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= -github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.etcd.io/bbolt v1.3.2 h1:Z/90sZLPOeCy2PwprqkFa25PdkusRzaj9P8zm/KNyvk= -go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk= -go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -go.uber.org/zap v1.10.0/go.mod 
h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= -golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/studio-b12/gowebdav v0.12.0 h1:kFRtQECt8jmVAvA6RHBz3geXUGJHUZA6/IKpOVUs5kM= +github.com/studio-b12/gowebdav v0.12.0/go.mod h1:bHA7t77X/QFExdeAnDzK6vKM34kEZAcE1OX4MfiwjkE= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA= +github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= +github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= +github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ= +github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU= +github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40= +github.com/wagslane/go-password-validator v0.3.0 h1:vfxOPzGHkz5S146HDpavl0cw1DSVP061Ry2PX0/ON6I= +github.com/wagslane/go-password-validator v0.3.0/go.mod h1:TI1XJ6T5fRdRnHqHt14pvy1tNVnrwe7m3/f1f2fDphQ= +github.com/wneessen/go-mail v0.7.2 h1:xxPnhZ6IZLSgxShebmZ6DPKh1b6OJcoHfzy7UjOkzS8= 
+github.com/wneessen/go-mail v0.7.2/go.mod h1:+TkW6QP3EVkgTEqHtVmnAE/1MRhmzb8Y9/W3pweuS+k= +github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a h1:XfF01GyP+0eWCaVp0y6rNN+kFp7pt9Da4UUYrJ5XPWA= +github.com/yl2chen/cidranger v1.0.3-0.20210928021809-d1cb2c52f37a/go.mod h1:aXb8yZQEWo1XHGMf1qQfnb83GR/EJ2EBlwtUgAaNBoE= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= +go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/detectors/gcp v1.42.0 h1:kpt2PEJuOuqYkPcktfJqWWDjTEd/FNgrxcniL7kQrXQ= +go.opentelemetry.io/contrib/detectors/gcp v1.42.0/go.mod h1:W9zQ439utxymRrXsUOzZbFX4JhLxXU4+ZnCt8GG7yA8= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg= +go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= +go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0 h1:5gn2urDL/FBnK8OkCfD1j3/ER79rUuTYmCvlXBKeYL8= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.39.0/go.mod 
h1:0fBG6ZJxhqByfFZDwSwpZGzJU671HkwpWaNe2t4VUPI= +go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= +go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= +go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= +go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= +go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= +go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= +go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= +go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ= +go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +gocloud.dev v0.45.0 h1:WknIK8IbRdmynDvara3Q7G6wQhmEiOGwpgJufbM39sY= +gocloud.dev v0.45.0/go.mod h1:0kXKmkCLG6d31N7NyLZWzt7jDSQura9zD/mWgiB6THI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= -golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM= -golang.org/x/crypto 
v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM= -golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= 
+golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 
+golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190830142957-1e83adbbebd0 h1:7z820YPX9pxWR59qM7BE5+fglp4D/mKqAwCvGt11b+8= -golang.org/x/sys v0.0.0-20190830142957-1e83adbbebd0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/text 
v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.6.2 h1:j8RI1yW0SkI+paT6uGwMlrMI/6zwYA6/CFil8rxOzGI= -google.golang.org/appengine v1.6.2/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod 
h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= +golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/api v0.271.0 h1:cIPN4qcUc61jlh7oXu6pwOQqbJW2GqYh5PS6rB2C/JY= +google.golang.org/api v0.271.0/go.mod h1:CGT29bhwkbF+i11qkRUJb2KMKqcJ1hdFceEIRd9u64Q= +google.golang.org/genproto v0.0.0-20260311181403-84a4fc48630c h1:ZhFDeBMmFc/4g8/GwxnJ4rzB3O4GwQVNr+8Mh7Y5z4g= +google.golang.org/genproto v0.0.0-20260311181403-84a4fc48630c/go.mod h1:hf4r/rBuzaTkLUWRO03771Xvcs6P5hwdQK3UUEJjqo0= +google.golang.org/genproto/googleapis/api v0.0.0-20260311181403-84a4fc48630c h1:OyQPd6I3pN/9gDxz6L13kYGJgqkpdrAohJRBeXyxlgI= +google.golang.org/genproto/googleapis/api v0.0.0-20260311181403-84a4fc48630c/go.mod h1:X2gu9Qwng7Nn009s/r3RUxqkzQNqOrAy79bluY7ojIg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c h1:xgCzyF2LFIO/0X2UAoVRiXKU5Xg6VjToG4i2/ecSswk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260311181403-84a4fc48630c/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU= +google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 
h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= -gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= -gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= -gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/img/7digital.png b/img/7digital.png new file mode 100644 index 00000000..c6c1c540 Binary files /dev/null and b/img/7digital.png differ diff --git 
a/img/Aledade_logo.png b/img/Aledade_logo.png new file mode 100644 index 00000000..1639a321 Binary files /dev/null and b/img/Aledade_logo.png differ diff --git a/img/IDCS.png b/img/IDCS.png new file mode 100644 index 00000000..0a890e89 Binary files /dev/null and b/img/IDCS.png differ diff --git a/img/jumptrading.png b/img/jumptrading.png new file mode 100644 index 00000000..88161521 Binary files /dev/null and b/img/jumptrading.png differ diff --git a/img/logo.png b/img/logo.png new file mode 100644 index 00000000..81b1f94b Binary files /dev/null and b/img/logo.png differ diff --git a/img/reui.png b/img/reui.png new file mode 100644 index 00000000..20973803 Binary files /dev/null and b/img/reui.png differ diff --git a/img/servinga.png b/img/servinga.png new file mode 100644 index 00000000..c5ccb638 Binary files /dev/null and b/img/servinga.png differ diff --git a/img/wpengine.png b/img/wpengine.png new file mode 100644 index 00000000..0ca2381f Binary files /dev/null and b/img/wpengine.png differ diff --git a/init/com.github.drakkan.sftpgo.plist b/init/com.github.drakkan.sftpgo.plist new file mode 100644 index 00000000..58b83f8d --- /dev/null +++ b/init/com.github.drakkan.sftpgo.plist @@ -0,0 +1,36 @@ + + + + + Label + com.github.drakkan.sftpgo + EnvironmentVariables + + SFTPGO_CONFIG_DIR + /usr/local/opt/sftpgo/etc + SFTPGO_LOG_FILE_PATH + /usr/local/opt/sftpgo/var/log/sftpgo.log + SFTPGO_HTTPD__TEMPLATES_PATH + /usr/local/opt/sftpgo/usr/share/templates + SFTPGO_HTTPD__STATIC_FILES_PATH + /usr/local/opt/sftpgo/usr/share/static + SFTPGO_HTTPD__OPENAPI_PATH + /usr/local/opt/sftpgo/usr/share/openapi + SFTPGO_HTTPD__BACKUPS_PATH + /usr/local/opt/sftpgo/var/lib/backups + SFTPGO_DATA_PROVIDER__CREDENTIALS_PATH + /usr/local/opt/sftpgo/var/lib/credentials + + WorkingDirectory + /usr/local/opt/sftpgo/etc + ProgramArguments + + /usr/local/opt/sftpgo/bin/sftpgo + serve + + KeepAlive + + ThrottleInterval + 10 + + diff --git a/init/sftpgo b/init/sftpgo new file mode 100755 index 
00000000..6da97d15 --- /dev/null +++ b/init/sftpgo @@ -0,0 +1,102 @@ +#! /bin/sh + +### BEGIN INIT INFO +# Provides: SFTPGo +# Required-Start: $remote_fs $syslog +# Required-Stop: $remote_fs $syslog +# Default-Start: 2 3 4 5 +# Default-Stop: +# Short-Description: SFTPGo server +### END INIT INFO + +set -e + +# /etc/init.d/sftpgo: start and stop the SFTPGo "server" daemon + +SFTPGO_USER="sftpgo" +SFTPGO_GROUP="sftpgo" +SFTPGO_BIN_NAME="sftpgo" +SFTPGO_BIN="/usr/bin/sftpgo" +SFTPGO_PID="/run/sftpgo.pid" +SFTPGO_CONF_DIR="/etc/sftpgo" +SFTPGO_CONF_FILE="sftpgo.json" +SFTPGO_OPTS="serve -c $SFTPGO_CONF_DIR --config-file $SFTPGO_CONF_FILE" + +umask 022 + +test -x $SFTPGO_BIN || exit 0 + + +if test -f /etc/default/$SFTPGO_BIN_NAME; then + . /etc/default/$SFTPGO_BIN_NAME +fi + +. /lib/lsb/init-functions + +if [ -n "$2" ]; then + SFTPGO_OPTS="$SFTPGO_OPTS $2" +fi + +# Are we running from init? +run_by_init() { + ([ "$previous" ] && [ "$runlevel" ]) || [ "$runlevel" = S ] +} + +check_dev_null() { + if [ ! -c /dev/null ]; then + if [ "$1" = log_end_msg ]; then + log_end_msg 1 || true + fi + if ! run_by_init; then + log_action_msg "/dev/null is not a character device!" 
|| true + fi + exit 1 + fi +} + +write_pid() { + sleep 0.25 + echo $(/bin/pidof $SFTPGO_BIN_NAME) > $SFTPGO_PID +} + +export PATH="${PATH:+$PATH:}/usr/sbin:/sbin" + +case "$1" in + start) + check_dev_null + log_daemon_msg "Starting SFTPGo server" "$SFTPGO_BIN_NAME" || true + if start-stop-daemon --start --background --quiet --oknodo --chuid $SFTPGO_USER:$SFTPGO_GROUP --pidfile $SFTPGO_PID --exec $SFTPGO_BIN -- $SFTPGO_OPTS; then + log_end_msg 0 || true + write_pid + else + log_end_msg 1 || true + fi + ;; + stop) + log_daemon_msg "Stopping SFTPGo server" "$SFTPGO_BIN_NAME" || true + if start-stop-daemon --stop --quiet --oknodo --pidfile $SFTPGO_PID --exec $SFTPGO_BIN; then + log_end_msg 0 || true + else + log_end_msg 1 || true + fi + ;; + + reload) + log_daemon_msg "Reloading SFTPGo server" "$SFTPGO_BIN_NAME" || true + if start-stop-daemon --stop --signal 1 --quiet --oknodo --pidfile $SFTPGO_PID --exec $SFTPGO_BIN; then + log_end_msg 0 || true + else + log_end_msg 1 || true + fi + ;; + + status) + status_of_proc -p $SFTPGO_PID $SFTPGO_BIN $SFTPGO_BIN_NAME && exit 0 || exit $? 
+ ;; + + *) + log_action_msg "Usage: /etc/init.d/$SFTPGO_BIN_NAME {start|stop|reload|status}" || true + exit 1 +esac + +exit 0 diff --git a/init/sftpgo.service b/init/sftpgo.service index 8dae1580..a8014c55 100644 --- a/init/sftpgo.service +++ b/init/sftpgo.service @@ -1,19 +1,29 @@ [Unit] -Description=SFTPGo sftp server +Description=SFTPGo Server After=network.target [Service] -User=root -Group=root +User=sftpgo +Group=sftpgo Type=simple WorkingDirectory=/etc/sftpgo +RuntimeDirectory=sftpgo Environment=SFTPGO_CONFIG_DIR=/etc/sftpgo/ -Environment=SFTPGO_LOG_FILE_PATH=/var/log/sftpgo.log +Environment=SFTPGO_LOG_FILE_PATH= EnvironmentFile=-/etc/sftpgo/sftpgo.env ExecStart=/usr/bin/sftpgo serve +ExecReload=/bin/kill -s HUP $MAINPID +LimitNOFILE=8192 KillMode=mixed +PrivateTmp=true Restart=always RestartSec=10s +NoNewPrivileges=yes +PrivateDevices=yes +DevicePolicy=closed +ProtectSystem=true +RestrictAddressFamilies=AF_INET AF_INET6 AF_UNIX +AmbientCapabilities=CAP_NET_BIND_SERVICE [Install] WantedBy=multi-user.target diff --git a/internal/acme/account.go b/internal/acme/account.go new file mode 100644 index 00000000..6f581a80 --- /dev/null +++ b/internal/acme/account.go @@ -0,0 +1,46 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package acme + +import ( + "crypto" + + "github.com/go-acme/lego/v4/registration" +) + +type account struct { + Email string `json:"email"` + Registration *registration.Resource `json:"registration"` + key crypto.PrivateKey +} + +/** Implementation of the registration.User interface **/ + +// GetEmail returns the email address for the account. +func (a *account) GetEmail() string { + return a.Email +} + +// GetRegistration returns the server registration. +func (a *account) GetRegistration() *registration.Resource { + return a.Registration +} + +// GetPrivateKey returns the private account key. +func (a *account) GetPrivateKey() crypto.PrivateKey { + return a.key +} + +/** End **/ diff --git a/internal/acme/acme.go b/internal/acme/acme.go new file mode 100644 index 00000000..6c4612e5 --- /dev/null +++ b/internal/acme/acme.go @@ -0,0 +1,859 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package acme provides automatic access to certificates from Let's Encrypt and any other ACME-based CA +// The code here is largely coiped from https://github.com/go-acme/lego/tree/master/cmd +// This package is intended to provide basic functionality for obtaining and renewing certificates +// and implements the "HTTP-01" and "TLSALPN-01" challenge types. 
+// For more advanced features use external tools such as "lego" +package acme + +import ( + "crypto" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "math/rand" + "net/url" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + "time" + + "github.com/go-acme/lego/v4/certcrypto" + "github.com/go-acme/lego/v4/certificate" + "github.com/go-acme/lego/v4/challenge" + "github.com/go-acme/lego/v4/challenge/http01" + "github.com/go-acme/lego/v4/challenge/tlsalpn01" + "github.com/go-acme/lego/v4/lego" + "github.com/go-acme/lego/v4/log" + "github.com/go-acme/lego/v4/providers/http/webroot" + "github.com/go-acme/lego/v4/registration" + "github.com/hashicorp/go-retryablehttp" + "github.com/robfig/cron/v3" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/ftpd" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/telemetry" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +const ( + logSender = "acme" +) + +var ( + config *Configuration + initialConfig Configuration + scheduler *cron.Cron + logMode int + supportedKeyTypes = []string{ + string(certcrypto.EC256), + string(certcrypto.EC384), + string(certcrypto.RSA2048), + string(certcrypto.RSA3072), + string(certcrypto.RSA4096), + string(certcrypto.RSA8192), + } + fnReloadHTTPDCerts func() error +) + +// SetReloadHTTPDCertsFn set the function to call to reload HTTPD certificates +func SetReloadHTTPDCertsFn(fn func() error) { + fnReloadHTTPDCerts = fn +} + +// GetCertificates tries to obtain the certificates using the global configuration +func GetCertificates() error { + if config == nil { + return errors.New("acme is disabled") + } + return config.getCertificates() +} + +// GetCertificatesForConfig tries to obtain the certificates using the provided +// 
configuration override. This is a NOOP if we already have certificates +func GetCertificatesForConfig(c *dataprovider.ACMEConfigs, configDir string) error { + if c.Domain == "" { + acmeLog(logger.LevelDebug, "no domain configured, nothing to do") + return nil + } + config := mergeConfig(getConfiguration(), c) + if err := config.Initialize(configDir); err != nil { + return err + } + hasCerts, err := config.hasCertificates(c.Domain) + if err != nil { + return fmt.Errorf("unable to check if we already have certificates for domain %q: %w", c.Domain, err) + } + if hasCerts { + return nil + } + return config.getCertificates() +} + +// GetHTTP01WebRoot returns the web root for HTTP-01 challenge +func GetHTTP01WebRoot() string { + return initialConfig.HTTP01Challenge.WebRoot +} + +func mergeConfig(config Configuration, c *dataprovider.ACMEConfigs) Configuration { + config.Domains = []string{c.Domain} + config.Email = c.Email + config.HTTP01Challenge.Port = c.HTTP01Challenge.Port + config.TLSALPN01Challenge.Port = 0 + return config +} + +// getConfiguration returns the configuration set using config file and env vars +func getConfiguration() Configuration { + return initialConfig +} + +func loadProviderConf(c Configuration) (Configuration, error) { + configs, err := dataprovider.GetConfigs() + if err != nil { + return c, fmt.Errorf("unable to load config from provider: %w", err) + } + configs.SetNilsToEmpty() + if configs.ACME.Domain == "" { + return c, nil + } + return mergeConfig(c, configs.ACME), nil +} + +// Initialize validates and set the configuration +func Initialize(c Configuration, configDir string, checkRenew bool) error { + config = nil + initialConfig = c + c, err := loadProviderConf(c) + if err != nil { + return err + } + util.CertsBasePath = "" + setLogMode(checkRenew) + + if err := c.Initialize(configDir); err != nil { + return err + } + if len(c.Domains) == 0 { + return nil + } + util.CertsBasePath = c.CertsPath + acmeLog(logger.LevelInfo, "configured 
domains: %+v, certs base path %q", c.Domains, c.CertsPath) + config = &c + if checkRenew { + return startScheduler() + } + return nil +} + +// HTTP01Challenge defines the configuration for HTTP-01 challenge type +type HTTP01Challenge struct { + Port int `json:"port" mapstructure:"port"` + WebRoot string `json:"webroot" mapstructure:"webroot"` + ProxyHeader string `json:"proxy_header" mapstructure:"proxy_header"` +} + +func (c *HTTP01Challenge) isEnabled() bool { + return c.Port > 0 || c.WebRoot != "" +} + +func (c *HTTP01Challenge) validate() error { + if !c.isEnabled() { + return nil + } + if c.WebRoot != "" { + if !filepath.IsAbs(c.WebRoot) { + return fmt.Errorf("invalid HTTP-01 challenge web root, please set an absolute path") + } + _, err := os.Stat(c.WebRoot) + if err != nil { + return fmt.Errorf("invalid HTTP-01 challenge web root: %w", err) + } + } else { + if c.Port > 65535 { + return fmt.Errorf("invalid HTTP-01 challenge port: %d", c.Port) + } + } + return nil +} + +// TLSALPN01Challenge defines the configuration for TLSALPN-01 challenge type +type TLSALPN01Challenge struct { + Port int `json:"port" mapstructure:"port"` +} + +func (c *TLSALPN01Challenge) isEnabled() bool { + return c.Port > 0 +} + +func (c *TLSALPN01Challenge) validate() error { + if !c.isEnabled() { + return nil + } + if c.Port > 65535 { + return fmt.Errorf("invalid TLSALPN-01 challenge port: %d", c.Port) + } + return nil +} + +// Configuration holds the ACME configuration +type Configuration struct { + Email string `json:"email" mapstructure:"email"` + KeyType string `json:"key_type" mapstructure:"key_type"` + CertsPath string `json:"certs_path" mapstructure:"certs_path"` + CAEndpoint string `json:"ca_endpoint" mapstructure:"ca_endpoint"` + // if a certificate is to be valid for multiple domains specify the names separated by commas, + // for example: example.com,www.example.com + Domains []string `json:"domains" mapstructure:"domains"` + RenewDays int `json:"renew_days" 
mapstructure:"renew_days"` + HTTP01Challenge HTTP01Challenge `json:"http01_challenge" mapstructure:"http01_challenge"` + TLSALPN01Challenge TLSALPN01Challenge `json:"tls_alpn01_challenge" mapstructure:"tls_alpn01_challenge"` + accountConfigPath string + accountKeyPath string + lockPath string + tempDir string +} + +// Initialize validates and initialize the configuration +func (c *Configuration) Initialize(configDir string) error { + c.checkDomains() + if len(c.Domains) == 0 { + acmeLog(logger.LevelInfo, "no domains configured, acme disabled") + return nil + } + if c.Email == "" || !util.IsEmailValid(c.Email) { + return util.NewI18nError( + fmt.Errorf("invalid email address %q", c.Email), + util.I18nErrorInvalidEmail, + ) + } + if c.RenewDays < 1 { + return fmt.Errorf("invalid number of days remaining before renewal: %d", c.RenewDays) + } + if !slices.Contains(supportedKeyTypes, c.KeyType) { + return fmt.Errorf("invalid key type %q", c.KeyType) + } + caURL, err := url.Parse(c.CAEndpoint) + if err != nil { + return fmt.Errorf("invalid CA endopoint: %w", err) + } + if !util.IsFileInputValid(c.CertsPath) { + return fmt.Errorf("invalid certs path %q", c.CertsPath) + } + if !filepath.IsAbs(c.CertsPath) { + c.CertsPath = filepath.Join(configDir, c.CertsPath) + } + err = os.MkdirAll(c.CertsPath, 0700) + if err != nil { + return fmt.Errorf("unable to create certs path %q: %w", c.CertsPath, err) + } + c.tempDir = filepath.Join(c.CertsPath, "temp") + err = os.MkdirAll(c.CertsPath, 0700) + if err != nil { + return fmt.Errorf("unable to create certs temp path %q: %w", c.tempDir, err) + } + serverPath := strings.NewReplacer(":", "_", "/", string(os.PathSeparator)).Replace(caURL.Host) + accountPath := filepath.Join(c.CertsPath, serverPath) + err = os.MkdirAll(accountPath, 0700) + if err != nil { + return fmt.Errorf("unable to create account path %q: %w", accountPath, err) + } + c.accountConfigPath = filepath.Join(accountPath, c.Email+".json") + c.accountKeyPath = 
filepath.Join(accountPath, c.Email+".key") + c.lockPath = filepath.Join(c.CertsPath, "lock") + + return c.validateChallenges() +} + +func (c *Configuration) validateChallenges() error { + if !c.HTTP01Challenge.isEnabled() && !c.TLSALPN01Challenge.isEnabled() { + return fmt.Errorf("no challenge type defined") + } + if err := c.HTTP01Challenge.validate(); err != nil { + return err + } + return c.TLSALPN01Challenge.validate() +} + +func (c *Configuration) checkDomains() { + var domains []string + for _, domain := range c.Domains { + domain = strings.TrimSpace(domain) + if domain == "" { + continue + } + if d, ok := isDomainValid(domain); ok { + domains = append(domains, d) + } + } + c.Domains = util.RemoveDuplicates(domains, true) +} + +func (c *Configuration) setLockTime() error { + lockTime := fmt.Sprintf("%v", util.GetTimeAsMsSinceEpoch(time.Now())) + err := os.WriteFile(c.lockPath, []byte(lockTime), 0600) + if err != nil { + acmeLog(logger.LevelError, "unable to save lock time to %q: %v", c.lockPath, err) + return fmt.Errorf("unable to save lock time: %w", err) + } + acmeLog(logger.LevelDebug, "lock time saved: %q", lockTime) + return nil +} + +func (c *Configuration) getLockTime() (time.Time, error) { + content, err := os.ReadFile(c.lockPath) + if err != nil { + if os.IsNotExist(err) { + acmeLog(logger.LevelDebug, "lock file %q not found", c.lockPath) + return time.Time{}, nil + } + acmeLog(logger.LevelError, "unable to read lock file %q: %v", c.lockPath, err) + return time.Time{}, err + } + msec, err := strconv.ParseInt(strings.TrimSpace(util.BytesToString(content)), 10, 64) + if err != nil { + acmeLog(logger.LevelError, "unable to parse lock time: %v", err) + return time.Time{}, fmt.Errorf("unable to parse lock time: %w", err) + } + return util.GetTimeFromMsecSinceEpoch(msec), nil +} + +func (c *Configuration) saveAccount(account *account) error { + jsonBytes, err := json.MarshalIndent(account, "", "\t") + if err != nil { + return err + } + err = 
os.WriteFile(c.accountConfigPath, jsonBytes, 0600) + if err != nil { + acmeLog(logger.LevelError, "unable to save account to file %q: %v", c.accountConfigPath, err) + return fmt.Errorf("unable to save account: %w", err) + } + return nil +} + +func (c *Configuration) getAccount(privateKey crypto.PrivateKey) (account, error) { + _, err := os.Stat(c.accountConfigPath) + if err != nil && os.IsNotExist(err) { + acmeLog(logger.LevelDebug, "account does not exist") + return account{Email: c.Email, key: privateKey}, nil + } + var account account + fileBytes, err := os.ReadFile(c.accountConfigPath) + if err != nil { + acmeLog(logger.LevelError, "unable to read account from file %q: %v", c.accountConfigPath, err) + return account, fmt.Errorf("unable to read account from file: %w", err) + } + err = json.Unmarshal(fileBytes, &account) + if err != nil { + acmeLog(logger.LevelError, "invalid account file content: %v", err) + return account, fmt.Errorf("unable to parse account file as JSON: %w", err) + } + account.key = privateKey + if account.Registration == nil || account.Registration.Body.Status == "" { + acmeLog(logger.LevelInfo, "couldn't load account but got a key. 
Try to look the account up") + reg, err := c.tryRecoverRegistration(privateKey) + if err != nil { + acmeLog(logger.LevelError, "unable to look the account up: %v", err) + return account, fmt.Errorf("unable to look the account up: %w", err) + } + account.Registration = reg + err = c.saveAccount(&account) + if err != nil { + return account, err + } + } + + return account, nil +} + +func (c *Configuration) loadPrivateKey() (crypto.PrivateKey, error) { + keyBytes, err := os.ReadFile(c.accountKeyPath) + if err != nil { + acmeLog(logger.LevelError, "unable to read account key from file %q: %v", c.accountKeyPath, err) + return nil, fmt.Errorf("unable to read account key: %w", err) + } + + keyBlock, _ := pem.Decode(keyBytes) + if keyBlock == nil { + acmeLog(logger.LevelError, "unable to parse private key from file %q: pem decoding failed", c.accountKeyPath) + return nil, errors.New("pem decoding failed") + } + + var privateKey crypto.PrivateKey + switch keyBlock.Type { + case "RSA PRIVATE KEY": + privateKey, err = x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + case "EC PRIVATE KEY": + privateKey, err = x509.ParseECPrivateKey(keyBlock.Bytes) + default: + err = fmt.Errorf("unknown private key type %q", keyBlock.Type) + } + if err != nil { + acmeLog(logger.LevelError, "unable to parse private key from file %q: %v", c.accountKeyPath, err) + return privateKey, fmt.Errorf("unable to parse private key: %w", err) + } + return privateKey, nil +} + +func (c *Configuration) generatePrivateKey() (crypto.PrivateKey, error) { + privateKey, err := certcrypto.GeneratePrivateKey(certcrypto.KeyType(c.KeyType)) + if err != nil { + acmeLog(logger.LevelError, "unable to generate private key: %v", err) + return nil, fmt.Errorf("unable to generate private key: %w", err) + } + certOut, err := os.Create(c.accountKeyPath) + if err != nil { + acmeLog(logger.LevelError, "unable to save private key to file %q: %v", c.accountKeyPath, err) + return nil, fmt.Errorf("unable to save private key: %w", err) + } 
+	defer certOut.Close()
+
+	pemKey := certcrypto.PEMBlock(privateKey)
+	err = pem.Encode(certOut, pemKey)
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to encode private key: %v", err)
+		return nil, fmt.Errorf("unable to encode private key: %w", err)
+	}
+	acmeLog(logger.LevelDebug, "new account private key generated")
+
+	return privateKey, nil
+}
+
+func (c *Configuration) getPrivateKey() (crypto.PrivateKey, error) {
+	_, err := os.Stat(c.accountKeyPath)
+	if err != nil && os.IsNotExist(err) {
+		acmeLog(logger.LevelDebug, "private key file %q does not exist, generating new private key", c.accountKeyPath)
+		return c.generatePrivateKey()
+	}
+	acmeLog(logger.LevelDebug, "loading private key from file %q, stat error: %v", c.accountKeyPath, err)
+	return c.loadPrivateKey()
+}
+
+func (c *Configuration) loadCertificatesForDomain(domain string) ([]*x509.Certificate, error) {
+	domain = util.SanitizeDomain(domain)
+	acmeLog(logger.LevelDebug, "loading certificates for domain %q", domain)
+	content, err := os.ReadFile(filepath.Join(c.CertsPath, domain+".crt"))
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to load certificates for domain %q: %v", domain, err)
+		return nil, fmt.Errorf("unable to load certificates for domain %q: %w", domain, err)
+	}
+	certs, err := certcrypto.ParsePEMBundle(content)
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to parse certificates for domain %q: %v", domain, err)
+		return certs, fmt.Errorf("unable to parse certificates for domain %q: %w", domain, err)
+	}
+	return certs, nil
+}
+
+func (c *Configuration) needRenewal(x509Cert *x509.Certificate, domain string) bool {
+	if x509Cert.IsCA {
+		acmeLog(logger.LevelError, "certificate bundle starts with a CA certificate, cannot renew domain %v", domain)
+		return false
+	}
+	notAfter := int(time.Until(x509Cert.NotAfter).Hours() / 24.0)
+	if notAfter > c.RenewDays {
+		acmeLog(logger.LevelDebug, "the certificate for domain %q expires in %d days, no renewal", domain,
+			notAfter)
+		return false
+	}
+	return true
+}
+
+func (c *Configuration) setup() (*account, *lego.Client, error) {
+	privateKey, err := c.getPrivateKey()
+	if err != nil {
+		return nil, nil, err
+	}
+	account, err := c.getAccount(privateKey)
+	if err != nil {
+		return nil, nil, err
+	}
+	config := lego.NewConfig(&account)
+	config.CADirURL = c.CAEndpoint
+	config.Certificate.KeyType = certcrypto.KeyType(c.KeyType)
+	config.Certificate.OverallRequestLimit = 6
+	config.UserAgent = version.GetServerVersion("/", false)
+
+	retryClient := retryablehttp.NewClient()
+	retryClient.Logger = &logger.LeveledLogger{Sender: "RetryableHTTPClient"}
+	retryClient.RetryMax = 5
+	retryClient.HTTPClient = config.HTTPClient
+
+	config.HTTPClient = retryClient.StandardClient()
+
+	client, err := lego.NewClient(config)
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to get ACME client: %v", err)
+		return nil, nil, fmt.Errorf("unable to get ACME client: %w", err)
+	}
+	err = c.setupChalleges(client)
+	if err != nil {
+		return nil, nil, err
+	}
+	return &account, client, nil
+}
+
+func (c *Configuration) setupChalleges(client *lego.Client) error {
+	client.Challenge.Remove(challenge.DNS01)
+	if c.HTTP01Challenge.isEnabled() {
+		if c.HTTP01Challenge.WebRoot != "" {
+			acmeLog(logger.LevelDebug, "configuring HTTP-01 web root challenge, path %q", c.HTTP01Challenge.WebRoot)
+			providerServer, err := webroot.NewHTTPProvider(c.HTTP01Challenge.WebRoot)
+			if err != nil {
+				acmeLog(logger.LevelError, "unable to create HTTP-01 web root challenge provider from path %q: %v",
+					c.HTTP01Challenge.WebRoot, err)
+				return fmt.Errorf("unable to create HTTP-01 web root challenge provider: %w", err)
+			}
+			err = client.Challenge.SetHTTP01Provider(providerServer)
+			if err != nil {
+				acmeLog(logger.LevelError, "unable to set HTTP-01 challenge provider: %v", err)
+				return fmt.Errorf("unable to set HTTP-01 challenge provider: %w", err)
+			}
+		} else {
+			acmeLog(logger.LevelDebug, "configuring HTTP-01 challenge, port %d", c.HTTP01Challenge.Port)
+			providerServer := http01.NewProviderServer("", fmt.Sprintf("%d", c.HTTP01Challenge.Port))
+			if c.HTTP01Challenge.ProxyHeader != "" {
+				acmeLog(logger.LevelDebug, "setting proxy header to \"%s\"", c.HTTP01Challenge.ProxyHeader)
+				providerServer.SetProxyHeader(c.HTTP01Challenge.ProxyHeader)
+			}
+			err := client.Challenge.SetHTTP01Provider(providerServer)
+			if err != nil {
+				acmeLog(logger.LevelError, "unable to set HTTP-01 challenge provider: %v", err)
+				return fmt.Errorf("unable to set HTTP-01 challenge provider: %w", err)
+			}
+		}
+	} else {
+		client.Challenge.Remove(challenge.HTTP01)
+	}
+	if c.TLSALPN01Challenge.isEnabled() {
+		acmeLog(logger.LevelDebug, "configuring TLSALPN-01 challenge, port %d", c.TLSALPN01Challenge.Port)
+		err := client.Challenge.SetTLSALPN01Provider(tlsalpn01.NewProviderServer("", fmt.Sprintf("%d", c.TLSALPN01Challenge.Port)))
+		if err != nil {
+			acmeLog(logger.LevelError, "unable to set TLSALPN-01 challenge provider: %v", err)
+			return fmt.Errorf("unable to set TLSALPN-01 challenge provider: %w", err)
+		}
+	} else {
+		client.Challenge.Remove(challenge.TLSALPN01)
+	}
+
+	return nil
+}
+
+func (c *Configuration) register(client *lego.Client) (*registration.Resource, error) {
+	return client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
+}
+
+func (c *Configuration) tryRecoverRegistration(privateKey crypto.PrivateKey) (*registration.Resource, error) {
+	config := lego.NewConfig(&account{key: privateKey})
+	config.CADirURL = c.CAEndpoint
+	config.UserAgent = version.GetServerVersion("/", false)
+
+	retryClient := retryablehttp.NewClient()
+	retryClient.Logger = &logger.LeveledLogger{Sender: "RetryableHTTPClient"}
+	retryClient.RetryMax = 5
+	retryClient.HTTPClient = config.HTTPClient
+
+	config.HTTPClient = retryClient.StandardClient()
+
+	client, err := lego.NewClient(config)
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to get the ACME client: %v", err)
+		return nil, err
+	}
+
+	return client.Registration.ResolveAccountByKey()
+}
+
+func (c *Configuration) getCrtPath(domain string) string {
+	return filepath.Join(c.CertsPath, domain+".crt")
+}
+
+func (c *Configuration) getKeyPath(domain string) string {
+	return filepath.Join(c.CertsPath, domain+".key")
+}
+
+func (c *Configuration) getResourcePath(domain string) string {
+	return filepath.Join(c.CertsPath, domain+".json")
+}
+
+func (c *Configuration) obtainAndSaveCertificate(client *lego.Client, domain string) error {
+	domains := getDomains(domain)
+	acmeLog(logger.LevelInfo, "requesting certificates for domains %+v", domains)
+	request := certificate.ObtainRequest{
+		Domains:                        domains,
+		Bundle:                         true,
+		MustStaple:                     false,
+		PreferredChain:                 "",
+		AlwaysDeactivateAuthorizations: false,
+	}
+	cert, err := client.Certificate.Obtain(request)
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to obtain certificates for domains %+v: %v", domains, err)
+		return fmt.Errorf("unable to obtain certificates: %w", err)
+	}
+	domain = util.SanitizeDomain(domain)
+	err = os.WriteFile(c.getCrtPath(domain), cert.Certificate, 0600)
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to save certificate for domain %s: %v", domain, err)
+		return fmt.Errorf("unable to save certificate: %w", err)
+	}
+	err = os.WriteFile(c.getKeyPath(domain), cert.PrivateKey, 0600)
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to save private key for domain %s: %v", domain, err)
+		return fmt.Errorf("unable to save private key: %w", err)
+	}
+	jsonBytes, err := json.MarshalIndent(cert, "", "\t")
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to marshal certificate resources for domain %v: %v", domain, err)
+		return err
+	}
+	err = os.WriteFile(c.getResourcePath(domain), jsonBytes, 0600)
+	if err != nil {
+		acmeLog(logger.LevelError, "unable to save certificate resources for domain %v: %v", domain, err)
+		return fmt.Errorf("unable to save certificate resources: %w", err)
+	}
+
+	acmeLog(logger.LevelInfo, "certificates for domains %+v saved", domains)
+	return nil
+}
+
+// hasCertificates returns true if certificates for the specified domain has already been issued
+func (c *Configuration) hasCertificates(domain string) (bool, error) {
+	domain = util.SanitizeDomain(domain)
+	if _, err := os.Stat(c.getCrtPath(domain)); err != nil {
+		if os.IsNotExist(err) {
+			return false, nil
+		}
+		return false, err
+	}
+	if _, err := os.Stat(c.getKeyPath(domain)); err != nil {
+		if os.IsNotExist(err) {
+			return false, nil
+		}
+		return false, err
+	}
+	return true, nil
+}
+
+// getCertificates tries to obtain the certificates for the configured domains
+func (c *Configuration) getCertificates() error {
+	account, client, err := c.setup()
+	if err != nil {
+		return err
+	}
+	if account.Registration == nil {
+		reg, err := c.register(client)
+		if err != nil {
+			acmeLog(logger.LevelError, "unable to register account: %v", err)
+			return fmt.Errorf("unable to register account: %w", err)
+		}
+		account.Registration = reg
+		err = c.saveAccount(account)
+		if err != nil {
+			return err
+		}
+	}
+	for _, domain := range c.Domains {
+		err = c.obtainAndSaveCertificate(client, domain)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (c *Configuration) notifyCertificateRenewal(domain string, err error) {
+	if domain == "" {
+		domain = strings.Join(c.Domains, ",")
+	}
+	params := common.EventParams{
+		Name:      domain,
+		Event:     "Certificate renewal",
+		Timestamp: time.Now(),
+	}
+	if err != nil {
+		params.Status = 2
+		params.AddError(err)
+	} else {
+		params.Status = 1
+	}
+	common.HandleCertificateEvent(params)
+}
+
+func (c *Configuration) renewCertificates() error {
+	lockTime, err := c.getLockTime()
+	if err != nil {
+		return err
+	}
+	acmeLog(logger.LevelDebug, "certificate renew lock time %v", lockTime)
+	if lockTime.Add(-30*time.Second).Before(time.Now()) && lockTime.Add(5*time.Minute).After(time.Now()) {
+		acmeLog(logger.LevelInfo, "certificate renew skipped, lock time too close: %v", lockTime)
+		return nil
+	}
+	err = c.setLockTime()
+	if err != nil {
+		c.notifyCertificateRenewal("", err)
+		return err
+	}
+	account, client, err := c.setup()
+	if err != nil {
+		c.notifyCertificateRenewal("", err)
+		return err
+	}
+	if account.Registration == nil {
+		acmeLog(logger.LevelError, "cannot renew certificates, your account is not registered")
+		err = errors.New("cannot renew certificates, your account is not registered")
+		c.notifyCertificateRenewal("", err)
+		return err
+	}
+	var errRenew error
+	needReload := false
+	for _, domain := range c.Domains {
+		certificates, err := c.loadCertificatesForDomain(domain)
+		if err != nil {
+			c.notifyCertificateRenewal(domain, err)
+			errRenew = err
+			continue
+		}
+		cert := certificates[0]
+		if !c.needRenewal(cert, domain) {
+			continue
+		}
+		err = c.obtainAndSaveCertificate(client, domain)
+		if err != nil {
+			c.notifyCertificateRenewal(domain, err)
+			errRenew = err
+		} else {
+			c.notifyCertificateRenewal(domain, nil)
+			needReload = true
+		}
+	}
+	if needReload {
+		// at least one certificate has been renewed, sends a reload to all services that may be using certificates
+		err = ftpd.ReloadCertificateMgr()
+		acmeLog(logger.LevelInfo, "ftpd certificate manager reloaded , error: %v", err)
+		if fnReloadHTTPDCerts != nil {
+			err = fnReloadHTTPDCerts()
+			acmeLog(logger.LevelInfo, "httpd certificates manager reloaded , error: %v", err)
+		}
+		err = webdavd.ReloadCertificateMgr()
+		acmeLog(logger.LevelInfo, "webdav certificates manager reloaded , error: %v", err)
+		err = telemetry.ReloadCertificateMgr()
+		acmeLog(logger.LevelInfo, "telemetry certificates manager reloaded , error: %v", err)
+	}
+
+	return errRenew
+}
+
+func isDomainValid(domain string) (string, bool) {
+	isValid := false
+	for d := range strings.SplitSeq(domain, ",") {
+		d = strings.TrimSpace(d)
+		if d != "" {
+			isValid = true
+			break
+		}
+	}
+	return domain, isValid
+}
+
+func getDomains(domain string) []string {
+	var domains []string
+
+	delimiter := ","
+	if !strings.Contains(domain, ",") && strings.Contains(domain, " ") {
+		delimiter = " "
+	}
+
+	for d := range strings.SplitSeq(domain, delimiter) {
+		d = strings.TrimSpace(d)
+		if d != "" {
+			domains = append(domains, d)
+		}
+	}
+	return util.RemoveDuplicates(domains, false)
+}
+
+func stopScheduler() {
+	if scheduler != nil {
+		scheduler.Stop()
+		scheduler = nil
+	}
+}
+
+func startScheduler() error {
+	stopScheduler()
+
+	randSecs := rand.Intn(59)
+	scheduler = cron.New(cron.WithLocation(time.UTC), cron.WithLogger(cron.DiscardLogger))
+	_, err := scheduler.AddFunc(fmt.Sprintf("@every 12h0m%ds", randSecs), renewCertificates)
+	if err != nil {
+		return fmt.Errorf("unable to schedule certificates renewal: %w", err)
+	}
+
+	acmeLog(logger.LevelInfo, "starting scheduler, initial certificates check in %d seconds", randSecs)
+	initialTimer := time.NewTimer(time.Duration(randSecs) * time.Second)
+	go func() {
+		<-initialTimer.C
+		renewCertificates()
+	}()
+
+	scheduler.Start()
+	return nil
+}
+
+func renewCertificates() {
+	if config != nil {
+		if err := config.renewCertificates(); err != nil {
+			acmeLog(logger.LevelError, "unable to renew certificates: %v", err)
+		}
+	}
+}
+
+func setLogMode(checkRenew bool) {
+	if checkRenew {
+		logMode = 1
+	} else {
+		logMode = 2
+	}
+	log.Logger = &logger.LegoAdapter{
+		LogToConsole: logMode != 1,
+	}
+}
+
+func acmeLog(level logger.LogLevel, format string, v ...any) {
+	if logMode == 1 {
+		logger.Log(level, logSender, "", format, v...)
+	} else {
+		switch level {
+		case logger.LevelDebug:
+			logger.DebugToConsole(format, v...)
+		case logger.LevelInfo:
+			logger.InfoToConsole(format, v...)
+		case logger.LevelWarn:
+			logger.WarnToConsole(format, v...)
+		default:
+			logger.ErrorToConsole(format, v...)
+		}
+	}
+}
diff --git a/internal/bundle/bundle.go b/internal/bundle/bundle.go
new file mode 100644
index 00000000..077fb360
--- /dev/null
+++ b/internal/bundle/bundle.go
@@ -0,0 +1,64 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+//go:build bundle
+
+package bundle
+
+import (
+	"embed"
+	"fmt"
+	"io/fs"
+	"net/http"
+
+	"github.com/drakkan/sftpgo/v2/internal/version"
+)
+
+func init() {
+	version.AddFeature("+bundle")
+}
+
+//go:embed templates/*
+var templatesFs embed.FS
+
+//go:embed static/*
+var staticFs embed.FS
+
+//go:embed openapi/*
+var openapiFs embed.FS
+
+// GetTemplatesFs returns the embedded filesystem with the SFTPGo templates
+func GetTemplatesFs() embed.FS {
+	return templatesFs
+}
+
+// GetStaticFs return the http Filesystem with the embedded static files
+func GetStaticFs() http.FileSystem {
+	fsys, err := fs.Sub(staticFs, "static")
+	if err != nil {
+		err = fmt.Errorf("unable to get embedded filesystem for static files: %w", err)
+		panic(err)
+	}
+	return http.FS(fsys)
+}
+
+// GetOpenAPIFs return the http Filesystem with the embedded static files
+func GetOpenAPIFs() http.FileSystem {
+	fsys, err := fs.Sub(openapiFs, "openapi")
+	if err != nil {
+		err = fmt.Errorf("unable to get embedded filesystem for OpenAPI files: %w", err)
+		panic(err)
+	}
+	return http.FS(fsys)
+}
diff --git a/internal/cmd/acme.go b/internal/cmd/acme.go
new file mode 100644
index 00000000..e98cf515
--- /dev/null
+++ b/internal/cmd/acme.go
@@ -0,0 +1,99 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package cmd
+
+import (
+	"os"
+
+	"github.com/rs/zerolog"
+	"github.com/spf13/cobra"
+
+	"github.com/drakkan/sftpgo/v2/internal/acme"
+	"github.com/drakkan/sftpgo/v2/internal/config"
+	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+	"github.com/drakkan/sftpgo/v2/internal/plugin"
+	"github.com/drakkan/sftpgo/v2/internal/util"
+)
+
+var (
+	acmeCmd = &cobra.Command{
+		Use:   "acme",
+		Short: "Obtain TLS certificates from ACME-based CAs like Let's Encrypt",
+	}
+	acmeRunCmd = &cobra.Command{
+		Use:   "run",
+		Short: "Register your account and obtain certificates",
+		Long: `This command must be run to obtain TLS certificates the first time or every
+time you add a new domain to your configuration file.
+Certificates are saved in the configured "certs_path".
+After this initial step, the certificates are automatically checked and
+renewed by the SFTPGo service
+`,
+		Run: func(_ *cobra.Command, _ []string) {
+			logger.DisableLogger()
+			logger.EnableConsoleLogger(zerolog.DebugLevel)
+			configDir = util.CleanDirInput(configDir)
+			err := config.LoadConfig(configDir, configFile)
+			if err != nil {
+				logger.ErrorToConsole("Unable to initialize ACME, config load error: %v", err)
+				return
+			}
+			kmsConfig := config.GetKMSConfig()
+			err = kmsConfig.Initialize()
+			if err != nil {
+				logger.ErrorToConsole("unable to initialize KMS: %v", err)
+				os.Exit(1)
+			}
+			if config.HasKMSPlugin() {
+				if err := plugin.Initialize(config.GetPluginsConfig(), "debug"); err != nil {
+					logger.ErrorToConsole("unable to initialize plugin system: %v", err)
+					os.Exit(1)
+				}
+				registerSignals()
+				defer plugin.Handler.Cleanup()
+			}
+
+			mfaConfig := config.GetMFAConfig()
+			err = mfaConfig.Initialize()
+			if err != nil {
+				logger.ErrorToConsole("Unable to initialize MFA: %v", err)
+				os.Exit(1)
+			}
+			providerConf := config.GetProviderConf()
+			err = dataprovider.Initialize(providerConf, configDir, false)
+			if err != nil {
+				logger.ErrorToConsole("error initializing data provider: %v", err)
+				os.Exit(1)
+			}
+			acmeConfig := config.GetACMEConfig()
+			err = acme.Initialize(acmeConfig, configDir, false)
+			if err != nil {
+				logger.ErrorToConsole("Unable to initialize ACME configuration: %v", err)
+				os.Exit(1)
+			}
+			if err = acme.GetCertificates(); err != nil {
+				logger.ErrorToConsole("Cannot get certificates: %v", err)
+				os.Exit(1)
+			}
+		},
+	}
+)
+
+func init() {
+	addConfigFlags(acmeRunCmd)
+	acmeCmd.AddCommand(acmeRunCmd)
+	rootCmd.AddCommand(acmeCmd)
+}
diff --git a/internal/cmd/gen.go b/internal/cmd/gen.go
new file mode 100644
index 00000000..c576fada
--- /dev/null
+++ b/internal/cmd/gen.go
@@ -0,0 +1,26 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package cmd
+
+import "github.com/spf13/cobra"
+
+var genCmd = &cobra.Command{
+	Use:   "gen",
+	Short: "A collection of useful generators",
+}
+
+func init() {
+	rootCmd.AddCommand(genCmd)
+}
diff --git a/internal/cmd/gencompletion.go b/internal/cmd/gencompletion.go
new file mode 100644
index 00000000..2f25193d
--- /dev/null
+++ b/internal/cmd/gencompletion.go
@@ -0,0 +1,133 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package cmd
+
+import (
+	"os"
+
+	"github.com/spf13/cobra"
+)
+
+var genCompletionCmd = &cobra.Command{
+	Use:   "completion [bash|zsh|fish|powershell]",
+	Short: "Generate the autocompletion script for the specified shell",
+	Long: `Generate the autocompletion script for sftpgo for the specified shell.
+
+See each sub-command's help for details on how to use the generated script.
+`,
+}
+
+var genCompletionBashCmd = &cobra.Command{
+	Use:   "bash",
+	Short: "Generate the autocompletion script for bash",
+	Long: `Generate the autocompletion script for the bash shell.
+
+This script depends on the 'bash-completion' package.
+If it is not installed already, you can install it via your OS's package
+manager.
+
+To load completions in your current shell session:
+
+$ source <(sftpgo gen completion bash)
+
+To load completions for every new session, execute once:
+
+Linux:
+  $ sudo sftpgo gen completion bash > /usr/share/bash-completion/completions/sftpgo
+
+MacOS:
+  $ sudo sftpgo gen completion bash > /usr/local/etc/bash_completion.d/sftpgo
+
+You will need to start a new shell for this setup to take effect.
+`,
+	DisableFlagsInUseLine: true,
+	RunE: func(cmd *cobra.Command, _ []string) error {
+		return cmd.Root().GenBashCompletionV2(os.Stdout, true)
+	},
+}
+
+var genCompletionZshCmd = &cobra.Command{
+	Use:   "zsh",
+	Short: "Generate the autocompletion script for zsh",
+	Long: `Generate the autocompletion script for the zsh shell.
+
+If shell completion is not already enabled in your environment you will need
+to enable it. You can execute the following once:
+
+$ echo "autoload -U compinit; compinit" >> ~/.zshrc
+
+To load completions for every new session, execute once:
+
+Linux:
+  $ sftpgo gen completion zsh > "${fpath[1]}/_sftpgo"
+
+macOS:
+  $ sudo sftpgo gen completion zsh > /usr/local/share/zsh/site-functions/_sftpgo
+
+You will need to start a new shell for this setup to take effect.
+`,
+	DisableFlagsInUseLine: true,
+	RunE: func(cmd *cobra.Command, _ []string) error {
+		return cmd.Root().GenZshCompletion(os.Stdout)
+	},
+}
+
+var genCompletionFishCmd = &cobra.Command{
+	Use:   "fish",
+	Short: "Generate the autocompletion script for fish",
+	Long: `Generate the autocompletion script for the fish shell.
+
+To load completions in your current shell session:
+
+$ sftpgo gen completion fish | source
+
+To load completions for every new session, execute once:
+
+$ sftpgo gen completion fish > ~/.config/fish/completions/sftpgo.fish
+
+You will need to start a new shell for this setup to take effect.
+`,
+	DisableFlagsInUseLine: true,
+	RunE: func(cmd *cobra.Command, _ []string) error {
+		return cmd.Root().GenFishCompletion(os.Stdout, true)
+	},
+}
+
+var genCompletionPowerShellCmd = &cobra.Command{
+	Use:   "powershell",
+	Short: "Generate the autocompletion script for powershell",
+	Long: `Generate the autocompletion script for powershell.
+
+To load completions in your current shell session:
+
+PS C:\> sftpgo gen completion powershell | Out-String | Invoke-Expression
+
+To load completions for every new session, add the output of the above command
+to your powershell profile.
+`,
+	DisableFlagsInUseLine: true,
+	RunE: func(cmd *cobra.Command, _ []string) error {
+		return cmd.Root().GenPowerShellCompletionWithDesc(os.Stdout)
+	},
+}
+
+func init() {
+	genCompletionCmd.AddCommand(genCompletionBashCmd)
+	genCompletionCmd.AddCommand(genCompletionZshCmd)
+	genCompletionCmd.AddCommand(genCompletionFishCmd)
+	genCompletionCmd.AddCommand(genCompletionPowerShellCmd)
+
+	genCmd.AddCommand(genCompletionCmd)
+}
diff --git a/internal/cmd/genman.go b/internal/cmd/genman.go
new file mode 100644
index 00000000..dba1f149
--- /dev/null
+++ b/internal/cmd/genman.go
@@ -0,0 +1,69 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package cmd
+
+import (
+	"errors"
+	"fmt"
+	"io/fs"
+	"os"
+
+	"github.com/rs/zerolog"
+	"github.com/spf13/cobra"
+	"github.com/spf13/cobra/doc"
+
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+	"github.com/drakkan/sftpgo/v2/internal/version"
+)
+
+var (
+	manDir    string
+	genManCmd = &cobra.Command{
+		Use:   "man",
+		Short: "Generate man pages for sftpgo",
+		Long: `This command automatically generates up-to-date man pages of SFTPGo's
+command-line interface.
+By default, it creates the man page files in the "man" directory under the
+current directory.
+`,
+		Run: func(cmd *cobra.Command, _ []string) {
+			logger.DisableLogger()
+			logger.EnableConsoleLogger(zerolog.DebugLevel)
+			if _, err := os.Stat(manDir); errors.Is(err, fs.ErrNotExist) {
+				err = os.MkdirAll(manDir, os.ModePerm)
+				if err != nil {
+					logger.WarnToConsole("Unable to generate man page files: %v", err)
+					os.Exit(1)
+				}
+			}
+			header := &doc.GenManHeader{
+				Section: "1",
+				Manual:  "SFTPGo Manual",
+				Source:  fmt.Sprintf("SFTPGo %v", version.Get().Version),
+			}
+			cmd.Root().DisableAutoGenTag = true
+			err := doc.GenManTree(cmd.Root(), header, manDir)
+			if err != nil {
+				logger.WarnToConsole("Unable to generate man page files: %v", err)
+				os.Exit(1)
+			}
+		},
+	}
+)
+
+func init() {
+	genManCmd.Flags().StringVarP(&manDir, "dir", "d", "man", "The directory to write the man pages")
+	genCmd.AddCommand(genManCmd)
+}
diff --git a/internal/cmd/initprovider.go b/internal/cmd/initprovider.go
new file mode 100644
index 00000000..8862d261
--- /dev/null
+++ b/internal/cmd/initprovider.go
@@ -0,0 +1,125 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package cmd
+
+import (
+	"os"
+
+	"github.com/rs/zerolog"
+	"github.com/spf13/cobra"
+	"github.com/spf13/viper"
+
+	"github.com/drakkan/sftpgo/v2/internal/common"
+	"github.com/drakkan/sftpgo/v2/internal/config"
+	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+	"github.com/drakkan/sftpgo/v2/internal/plugin"
+	"github.com/drakkan/sftpgo/v2/internal/service"
+	"github.com/drakkan/sftpgo/v2/internal/util"
+)
+
+var (
+	initProviderCmd = &cobra.Command{
+		Use:   "initprovider",
+		Short: "Initialize and/or updates the configured data provider",
+		Long: `This command reads the data provider connection details from the specified
+configuration file and creates the initial structure or update the existing one,
+as needed.
+
+Some data providers such as bolt and memory does not require an initialization
+but they could require an update to the existing data after upgrading SFTPGo.
+
+For SQLite/bolt providers the database file will be auto-created if missing.
+
+For PostgreSQL and MySQL providers you need to create the configured database,
+this command will create/update the required tables as needed.
+
+To initialize/update the data provider from the configuration directory simply use:
+
+$ sftpgo initprovider
+
+Any defined action is ignored.
+Please take a look at the usage below to customize the options.`,
+		Run: func(_ *cobra.Command, _ []string) {
+			logger.DisableLogger()
+			logger.EnableConsoleLogger(zerolog.DebugLevel)
+			configDir = util.CleanDirInput(configDir)
+			err := config.LoadConfig(configDir, configFile)
+			if err != nil {
+				logger.ErrorToConsole("Unable to initialize data provider, config load error: %v", err)
+				return
+			}
+			kmsConfig := config.GetKMSConfig()
+			err = kmsConfig.Initialize()
+			if err != nil {
+				logger.ErrorToConsole("Unable to initialize KMS: %v", err)
+				os.Exit(1)
+			}
+			if config.HasKMSPlugin() {
+				if err := plugin.Initialize(config.GetPluginsConfig(), "debug"); err != nil {
+					logger.ErrorToConsole("unable to initialize plugin system: %v", err)
+					os.Exit(1)
+				}
+				registerSignals()
+				defer plugin.Handler.Cleanup()
+			}
+
+			mfaConfig := config.GetMFAConfig()
+			err = mfaConfig.Initialize()
+			if err != nil {
+				logger.ErrorToConsole("Unable to initialize MFA: %v", err)
+				os.Exit(1)
+			}
+			providerConf := config.GetProviderConf()
+			// ignore actions
+			providerConf.Actions.Hook = ""
+			providerConf.Actions.ExecuteFor = nil
+			providerConf.Actions.ExecuteOn = nil
+			logger.InfoToConsole("Initializing provider: %q config file: %q", providerConf.Driver, viper.ConfigFileUsed())
+			err = dataprovider.InitializeDatabase(providerConf, configDir)
+			switch err {
+			case nil:
+				logger.InfoToConsole("Data provider successfully initialized/updated")
+			case dataprovider.ErrNoInitRequired:
+				logger.InfoToConsole("%v", err.Error())
+			default:
+				logger.ErrorToConsole("Unable to initialize/update the data provider: %v", err)
+				os.Exit(1)
+			}
+			if providerConf.Driver != dataprovider.MemoryDataProviderName && loadDataFrom != "" {
+				if err := common.Initialize(config.GetCommonConfig(), providerConf.GetShared()); err != nil {
+					logger.ErrorToConsole("%v", err)
+					os.Exit(1)
+				}
+				service := service.Service{
+					LoadDataFrom:      loadDataFrom,
+					LoadDataMode:      loadDataMode,
+					LoadDataQuotaScan: loadDataQuotaScan,
+					LoadDataClean:     loadDataClean,
+				}
+				if err = service.LoadInitialData(); err != nil {
+					logger.ErrorToConsole("Cannot load initial data: %v", err)
+					os.Exit(1)
+				}
+			}
+		},
+	}
+)
+
+func init() {
+	rootCmd.AddCommand(initProviderCmd)
+	addConfigFlags(initProviderCmd)
+	addBaseLoadDataFlags(initProviderCmd)
+}
diff --git a/internal/cmd/install_windows.go b/internal/cmd/install_windows.go
new file mode 100644
index 00000000..19a02fbf
--- /dev/null
+++ b/internal/cmd/install_windows.go
@@ -0,0 +1,117 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package cmd
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+
+	"github.com/spf13/cobra"
+
+	"github.com/drakkan/sftpgo/v2/internal/service"
+	"github.com/drakkan/sftpgo/v2/internal/util"
+)
+
+var (
+	installCmd = &cobra.Command{
+		Use:   "install",
+		Short: "Install SFTPGo as Windows Service",
+		Long: `To install the SFTPGo Windows Service with the default values for the command
+line flags simply use:
+
+sftpgo service install
+
+Please take a look at the usage below to customize the startup options`,
+		Run: func(_ *cobra.Command, _ []string) {
+			s := service.Service{
+				ConfigDir:     util.CleanDirInput(configDir),
+				ConfigFile:    configFile,
+				LogFilePath:   logFilePath,
+				LogMaxSize:    logMaxSize,
+				LogMaxBackups: logMaxBackups,
+				LogMaxAge:     logMaxAge,
+				LogCompress:   logCompress,
+				LogLevel:      logLevel,
+				LogUTCTime:    logUTCTime,
+				Shutdown:      make(chan bool),
+			}
+			winService := service.WindowsService{
+				Service: s,
+			}
+			serviceArgs := []string{"service", "start"}
+			customFlags := getCustomServeFlags()
+			if len(customFlags) > 0 {
+				serviceArgs = append(serviceArgs, customFlags...)
+			}
+			err := winService.Install(serviceArgs...)
+			if err != nil {
+				fmt.Printf("Error installing service: %v\r\n", err)
+				os.Exit(1)
+			} else {
+				fmt.Printf("Service installed!\r\n")
+			}
+		},
+	}
+)
+
+func init() {
+	serviceCmd.AddCommand(installCmd)
+	addServeFlags(installCmd)
+}
+
+func getCustomServeFlags() []string {
+	result := []string{}
+	if configDir != defaultConfigDir {
+		configDir = util.CleanDirInput(configDir)
+		result = append(result, "--"+configDirFlag)
+		result = append(result, configDir)
+	}
+	if configFile != defaultConfigFile {
+		result = append(result, "--"+configFileFlag)
+		result = append(result, configFile)
+	}
+	if logFilePath != defaultLogFile {
+		result = append(result, "--"+logFilePathFlag)
+		result = append(result, logFilePath)
+	}
+	if logMaxSize != defaultLogMaxSize {
+		result = append(result, "--"+logMaxSizeFlag)
+		result = append(result, strconv.Itoa(logMaxSize))
+	}
+	if logMaxBackups != defaultLogMaxBackup {
+		result = append(result, "--"+logMaxBackupFlag)
+		result = append(result, strconv.Itoa(logMaxBackups))
+	}
+	if logMaxAge != defaultLogMaxAge {
+		result = append(result, "--"+logMaxAgeFlag)
+		result = append(result, strconv.Itoa(logMaxAge))
+	}
+	if logLevel != defaultLogLevel {
+		result = append(result, "--"+logLevelFlag)
+		result = append(result, logLevel)
+	}
+	if logUTCTime != defaultLogUTCTime {
+		result = append(result, "--"+logUTCTimeFlag+"=true")
+	}
+	if logCompress != defaultLogCompress {
+		result = append(result, "--"+logCompressFlag+"=true")
+	}
+	if graceTime != defaultGraceTime {
+		result = append(result, "--"+graceTimeFlag)
+		result = append(result, strconv.Itoa(graceTime))
+	}
+	return result
+}
diff --git a/internal/cmd/ping.go b/internal/cmd/ping.go
new file mode 100644
index 00000000..525c0ccf
--- /dev/null
+++ b/internal/cmd/ping.go
@@ -0,0 +1,120 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package cmd
+
+import (
+	"fmt"
+	"net/http"
+	"os"
+
+	"github.com/rs/zerolog"
+	"github.com/spf13/cobra"
+
+	"github.com/drakkan/sftpgo/v2/internal/config"
+	"github.com/drakkan/sftpgo/v2/internal/httpclient"
+	"github.com/drakkan/sftpgo/v2/internal/httpd"
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+	"github.com/drakkan/sftpgo/v2/internal/util"
+)
+
+func getHealthzURLFromBindings(bindings []httpd.Binding) string {
+	for _, b := range bindings {
+		if b.Port > 0 && b.IsValid() {
+			var url string
+			if b.EnableHTTPS {
+				url = "https://"
+			} else {
+				url = "http://"
+			}
+			if b.Address == "" {
+				url += "127.0.0.1"
+			} else {
+				url += b.Address
+			}
+			url += fmt.Sprintf(":%d", b.Port)
+			url += "/healthz"
+			return url
+		}
+	}
+	return ""
+}
+
+var (
+	pingCmd = &cobra.Command{
+		Use:   "ping",
+		Short: "Issues an health check to SFTPGo",
+		Long: `This command is only useful in environments where system commands like
+"curl", "wget" and similar are not available.
+Checks over UNIX domain sockets are not supported`,
+		Run: func(_ *cobra.Command, _ []string) {
+			logger.DisableLogger()
+			logger.EnableConsoleLogger(zerolog.DebugLevel)
+			configDir = util.CleanDirInput(configDir)
+			err := config.LoadConfig(configDir, configFile)
+			if err != nil {
+				logger.WarnToConsole("Unable to load configuration: %v", err)
+				os.Exit(1)
+			}
+			httpConfig := config.GetHTTPConfig()
+			err = httpConfig.Initialize(configDir)
+			if err != nil {
+				logger.ErrorToConsole("error initializing http client: %v", err)
+				os.Exit(1)
+			}
+			telemetryConfig := config.GetTelemetryConfig()
+			var url string
+			if telemetryConfig.BindPort > 0 {
+				if telemetryConfig.CertificateFile != "" && telemetryConfig.CertificateKeyFile != "" {
+					url += "https://"
+				} else {
+					url += "http://"
+				}
+				if telemetryConfig.BindAddress == "" {
+					url += "127.0.0.1"
+				} else {
+					url += telemetryConfig.BindAddress
+				}
+				url += fmt.Sprintf(":%d", telemetryConfig.BindPort)
+				url += "/healthz"
+			}
+			if url == "" {
+				httpdConfig := config.GetHTTPDConfig()
+				url = getHealthzURLFromBindings(httpdConfig.Bindings)
+			}
+			if url == "" {
+				logger.ErrorToConsole("no suitable configuration found, please enable the telemetry server or REST API over HTTP/S")
+				os.Exit(1)
+			}
+
+			logger.DebugToConsole("Health Check URL %q", url)
+			resp, err := httpclient.RetryableGet(url)
+			if err != nil {
+				logger.ErrorToConsole("Unable to connect to SFTPGo: %v", err)
+				os.Exit(1)
+			}
+			defer resp.Body.Close()
+			if resp.StatusCode != http.StatusOK {
+				logger.ErrorToConsole("Unexpected status code %d", resp.StatusCode)
+				os.Exit(1)
+			}
+			logger.InfoToConsole("OK")
+		},
+	}
+)
+
+func init() {
+	addConfigFlags(pingCmd)
+	rootCmd.AddCommand(pingCmd)
+}
diff --git a/internal/cmd/portable.go b/internal/cmd/portable.go
new file mode 100644
index 00000000..356371a9
--- /dev/null
+++ b/internal/cmd/portable.go
@@ -0,0 +1,545 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it
and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build !noportable + +package cmd + +import ( + "fmt" + "os" + "path" + "path/filepath" + "strings" + + "github.com/sftpgo/sdk" + "github.com/spf13/cobra" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/service" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +var ( + directoryToServe string + portableSFTPDPort int + portableUsername string + portablePassword string + portablePasswordFile string + portableStartDir string + portableLogFile string + portableLogLevel string + portableLogUTCTime bool + portablePublicKeys []string + portablePermissions []string + portableSSHCommands []string + portableAllowedPatterns []string + portableDeniedPatterns []string + portableFsProvider string + portableS3Bucket string + portableS3Region string + portableS3AccessKey string + portableS3AccessSecret string + portableS3RoleARN string + portableS3Endpoint string + portableS3StorageClass string + portableS3ACL string + portableS3KeyPrefix string + portableS3ULPartSize int + portableS3ULConcurrency int + portableS3ForcePathStyle bool + portableS3SkipTLSVerify bool + portableGCSBucket string + portableGCSCredentialsFile string + portableGCSAutoCredentials int + 
portableGCSStorageClass string + portableGCSKeyPrefix string + portableFTPDPort int + portableFTPSCert string + portableFTPSKey string + portableWebDAVPort int + portableWebDAVCert string + portableWebDAVKey string + portableHTTPPort int + portableHTTPSCert string + portableHTTPSKey string + portableAzContainer string + portableAzAccountName string + portableAzAccountKey string + portableAzEndpoint string + portableAzAccessTier string + portableAzSASURL string + portableAzKeyPrefix string + portableAzULPartSize int + portableAzULConcurrency int + portableAzDLPartSize int + portableAzDLConcurrency int + portableAzUseEmulator bool + portableCryptPassphrase string + portableSFTPEndpoint string + portableSFTPUsername string + portableSFTPPassword string + portableSFTPPrivateKeyPath string + portableSFTPFingerprints []string + portableSFTPPrefix string + portableSFTPDisableConcurrentReads bool + portableSFTPDBufferSize int64 + portableCmd = &cobra.Command{ + Use: "portable", + Short: "Serve a single directory/account", + Long: `To serve the current working directory with auto generated credentials simply +use: + +$ sftpgo portable + +Please take a look at the usage below to customize the serving parameters`, + Run: func(_ *cobra.Command, _ []string) { + portableDir := directoryToServe + fsProvider := dataprovider.GetProviderFromValue(convertFsProvider()) + if !filepath.IsAbs(portableDir) { + if fsProvider == sdk.LocalFilesystemProvider { + portableDir, _ = filepath.Abs(portableDir) + } else { + portableDir = os.TempDir() + } + } + permissions := make(map[string][]string) + permissions["/"] = portablePermissions + portableGCSCredentials := "" + if fsProvider == sdk.GCSFilesystemProvider && portableGCSCredentialsFile != "" { + contents, err := getFileContents(portableGCSCredentialsFile) + if err != nil { + fmt.Printf("Unable to get GCS credentials: %v\n", err) + os.Exit(1) + } + portableGCSCredentials = contents + portableGCSAutoCredentials = 0 + } + 
portableSFTPPrivateKey := "" + if fsProvider == sdk.SFTPFilesystemProvider && portableSFTPPrivateKeyPath != "" { + contents, err := getFileContents(portableSFTPPrivateKeyPath) + if err != nil { + fmt.Printf("Unable to get SFTP private key: %v\n", err) + os.Exit(1) + } + portableSFTPPrivateKey = contents + } + if portableFTPDPort >= 0 && portableFTPSCert != "" && portableFTPSKey != "" { + keyPairs := []common.TLSKeyPair{ + { + Cert: portableFTPSCert, + Key: portableFTPSKey, + ID: common.DefaultTLSKeyPaidID, + }, + } + _, err := common.NewCertManager(keyPairs, filepath.Clean(defaultConfigDir), + "FTP portable") + if err != nil { + fmt.Printf("Unable to load FTPS key pair, cert file %q key file %q error: %v\n", + portableFTPSCert, portableFTPSKey, err) + os.Exit(1) + } + } + if portableWebDAVPort >= 0 && portableWebDAVCert != "" && portableWebDAVKey != "" { + keyPairs := []common.TLSKeyPair{ + { + Cert: portableWebDAVCert, + Key: portableWebDAVKey, + ID: common.DefaultTLSKeyPaidID, + }, + } + _, err := common.NewCertManager(keyPairs, filepath.Clean(defaultConfigDir), + "WebDAV portable") + if err != nil { + fmt.Printf("Unable to load WebDAV key pair, cert file %q key file %q error: %v\n", + portableWebDAVCert, portableWebDAVKey, err) + os.Exit(1) + } + } + if portableHTTPPort >= 0 && portableHTTPSCert != "" && portableHTTPSKey != "" { + keyPairs := []common.TLSKeyPair{ + { + Cert: portableHTTPSCert, + Key: portableHTTPSKey, + ID: common.DefaultTLSKeyPaidID, + }, + } + _, err := common.NewCertManager(keyPairs, filepath.Clean(defaultConfigDir), + "HTTP portable") + if err != nil { + fmt.Printf("Unable to load HTTPS key pair, cert file %q key file %q error: %v\n", + portableHTTPSCert, portableHTTPSKey, err) + os.Exit(1) + } + } + pwd := portablePassword + if portablePasswordFile != "" { + content, err := os.ReadFile(portablePasswordFile) + if err != nil { + fmt.Printf("Unable to read password file %q: %v", portablePasswordFile, err) + os.Exit(1) + } + pwd = 
strings.TrimSpace(util.BytesToString(content)) + } + service.SetGraceTime(graceTime) + service := service.Service{ + ConfigDir: util.CleanDirInput(configDir), + ConfigFile: configFile, + LogFilePath: portableLogFile, + LogMaxSize: defaultLogMaxSize, + LogMaxBackups: defaultLogMaxBackup, + LogMaxAge: defaultLogMaxAge, + LogCompress: defaultLogCompress, + LogLevel: portableLogLevel, + LogUTCTime: portableLogUTCTime, + Shutdown: make(chan bool), + PortableMode: 1, + PortableUser: dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: portableUsername, + Password: pwd, + PublicKeys: portablePublicKeys, + Permissions: permissions, + HomeDir: portableDir, + Status: 1, + }, + Filters: dataprovider.UserFilters{ + BaseUserFilters: sdk.BaseUserFilters{ + FilePatterns: parsePatternsFilesFilters(), + StartDirectory: portableStartDir, + }, + }, + FsConfig: vfs.Filesystem{ + Provider: fsProvider, + S3Config: vfs.S3FsConfig{ + BaseS3FsConfig: sdk.BaseS3FsConfig{ + Bucket: portableS3Bucket, + Region: portableS3Region, + AccessKey: portableS3AccessKey, + RoleARN: portableS3RoleARN, + Endpoint: portableS3Endpoint, + StorageClass: portableS3StorageClass, + ACL: portableS3ACL, + KeyPrefix: portableS3KeyPrefix, + UploadPartSize: int64(portableS3ULPartSize), + UploadConcurrency: portableS3ULConcurrency, + ForcePathStyle: portableS3ForcePathStyle, + SkipTLSVerify: portableS3SkipTLSVerify, + }, + AccessSecret: kms.NewPlainSecret(portableS3AccessSecret), + }, + GCSConfig: vfs.GCSFsConfig{ + BaseGCSFsConfig: sdk.BaseGCSFsConfig{ + Bucket: portableGCSBucket, + AutomaticCredentials: portableGCSAutoCredentials, + StorageClass: portableGCSStorageClass, + KeyPrefix: portableGCSKeyPrefix, + }, + Credentials: kms.NewPlainSecret(portableGCSCredentials), + }, + AzBlobConfig: vfs.AzBlobFsConfig{ + BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{ + Container: portableAzContainer, + AccountName: portableAzAccountName, + Endpoint: portableAzEndpoint, + AccessTier: portableAzAccessTier, + KeyPrefix: 
portableAzKeyPrefix, + UseEmulator: portableAzUseEmulator, + UploadPartSize: int64(portableAzULPartSize), + UploadConcurrency: portableAzULConcurrency, + DownloadPartSize: int64(portableAzDLPartSize), + DownloadConcurrency: portableAzDLConcurrency, + }, + AccountKey: kms.NewPlainSecret(portableAzAccountKey), + SASURL: kms.NewPlainSecret(portableAzSASURL), + }, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(portableCryptPassphrase), + }, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: portableSFTPEndpoint, + Username: portableSFTPUsername, + Fingerprints: portableSFTPFingerprints, + Prefix: portableSFTPPrefix, + DisableCouncurrentReads: portableSFTPDisableConcurrentReads, + BufferSize: portableSFTPDBufferSize, + }, + Password: kms.NewPlainSecret(portableSFTPPassword), + PrivateKey: kms.NewPlainSecret(portableSFTPPrivateKey), + KeyPassphrase: kms.NewEmptySecret(), + }, + }, + }, + } + err := service.StartPortableMode(portableSFTPDPort, portableFTPDPort, portableWebDAVPort, portableHTTPPort, + portableSSHCommands, portableFTPSCert, portableFTPSKey, portableWebDAVCert, portableWebDAVKey, + portableHTTPSCert, portableHTTPSKey) + if err == nil { + service.Wait() + if service.Error == nil { + os.Exit(0) + } + } + os.Exit(1) + }, + } +) + +func init() { + version.AddFeature("+portable") + + portableCmd.Flags().StringVarP(&directoryToServe, "directory", "d", ".", `Path to the directory to serve. +This can be an absolute path or a path +relative to the current directory +`) + portableCmd.Flags().StringVar(&portableStartDir, "start-directory", "/", `Alternate start directory. 
+This is a virtual path not a filesystem +path`) + portableCmd.Flags().IntVarP(&portableSFTPDPort, "sftpd-port", "s", 0, `0 means a random unprivileged port, +< 0 disabled`) + portableCmd.Flags().IntVar(&portableFTPDPort, "ftpd-port", -1, `0 means a random unprivileged port, +< 0 disabled`) + portableCmd.Flags().IntVar(&portableWebDAVPort, "webdav-port", -1, `0 means a random unprivileged port, +< 0 disabled`) + portableCmd.Flags().IntVar(&portableHTTPPort, "httpd-port", -1, `0 means a random unprivileged port, +< 0 disabled`) + portableCmd.Flags().StringSliceVar(&portableSSHCommands, "ssh-commands", sftpd.GetDefaultSSHCommands(), + `SSH commands to enable. +"*" means any supported SSH command +including scp +`) + portableCmd.Flags().StringVarP(&portableUsername, "username", "u", "", `Leave empty to use an auto generated +value`) + portableCmd.Flags().StringVarP(&portablePassword, "password", "p", "", `Leave empty to use an auto generated +value`) + portableCmd.Flags().StringVar(&portablePasswordFile, "password-file", "", `Read the password from the specified +file path. Leave empty to use an auto +generated value`) + portableCmd.Flags().StringVarP(&portableLogFile, logFilePathFlag, "l", "", "Leave empty to disable logging") + portableCmd.Flags().StringVar(&portableLogLevel, logLevelFlag, defaultLogLevel, `Set the log level. +Supported values: + +debug, info, warn, error. +`) + portableCmd.Flags().BoolVar(&portableLogUTCTime, logUTCTimeFlag, false, "Use UTC time for logging") + portableCmd.Flags().StringSliceVarP(&portablePublicKeys, "public-key", "k", []string{}, "") + portableCmd.Flags().StringSliceVarP(&portablePermissions, "permissions", "g", []string{"list", "download"}, + `User's permissions. "*" means any +permission`) + portableCmd.Flags().StringArrayVar(&portableAllowedPatterns, "allowed-patterns", []string{}, + `Allowed file patterns case insensitive. +The format is: +/dir::pattern1,pattern2. 
+For example: "/somedir::*.jpg,a*b?.png"`) + portableCmd.Flags().StringArrayVar(&portableDeniedPatterns, "denied-patterns", []string{}, + `Denied file patterns case insensitive. +The format is: +/dir::pattern1,pattern2. +For example: "/somedir::*.jpg,a*b?.png"`) + portableCmd.Flags().StringVarP(&portableFsProvider, "fs-provider", "f", "osfs", `osfs => local filesystem (legacy value: 0) +s3fs => AWS S3 compatible (legacy: 1) +gcsfs => Google Cloud Storage (legacy: 2) +azblobfs => Azure Blob Storage (legacy: 3) +cryptfs => Encrypted local filesystem (legacy: 4) +sftpfs => SFTP (legacy: 5)`) + portableCmd.Flags().StringVar(&portableS3Bucket, "s3-bucket", "", "") + portableCmd.Flags().StringVar(&portableS3Region, "s3-region", "", "") + portableCmd.Flags().StringVar(&portableS3AccessKey, "s3-access-key", "", "") + portableCmd.Flags().StringVar(&portableS3AccessSecret, "s3-access-secret", "", "") + portableCmd.Flags().StringVar(&portableS3RoleARN, "s3-role-arn", "", "") + portableCmd.Flags().StringVar(&portableS3Endpoint, "s3-endpoint", "", "") + portableCmd.Flags().StringVar(&portableS3StorageClass, "s3-storage-class", "", "") + portableCmd.Flags().StringVar(&portableS3ACL, "s3-acl", "", "") + portableCmd.Flags().StringVar(&portableS3KeyPrefix, "s3-key-prefix", "", `Allows to restrict access to the +virtual folder identified by this +prefix and its contents`) + portableCmd.Flags().IntVar(&portableS3ULPartSize, "s3-upload-part-size", 5, `The buffer size for multipart uploads +(MB)`) + portableCmd.Flags().IntVar(&portableS3ULConcurrency, "s3-upload-concurrency", 2, `How many parts are uploaded in +parallel`) + portableCmd.Flags().BoolVar(&portableS3ForcePathStyle, "s3-force-path-style", false, `Force path style bucket URL`) + portableCmd.Flags().BoolVar(&portableS3SkipTLSVerify, "s3-skip-tls-verify", false, `If enabled the S3 client accepts any TLS +certificate presented by the server and +any host name in that certificate. 
+In this mode, TLS is susceptible to +man-in-the-middle attacks. +This should be used only for testing. +`) + portableCmd.Flags().StringVar(&portableGCSBucket, "gcs-bucket", "", "") + portableCmd.Flags().StringVar(&portableGCSStorageClass, "gcs-storage-class", "", "") + portableCmd.Flags().StringVar(&portableGCSKeyPrefix, "gcs-key-prefix", "", `Allows to restrict access to the +virtual folder identified by this +prefix and its contents`) + portableCmd.Flags().StringVar(&portableGCSCredentialsFile, "gcs-credentials-file", "", `Google Cloud Storage JSON credentials +file`) + portableCmd.Flags().IntVar(&portableGCSAutoCredentials, "gcs-automatic-credentials", 1, `0 means explicit credentials using +a JSON credentials file, 1 automatic +`) + portableCmd.Flags().StringVar(&portableFTPSCert, "ftpd-cert", "", "Path to the certificate file for FTPS") + portableCmd.Flags().StringVar(&portableFTPSKey, "ftpd-key", "", "Path to the key file for FTPS") + portableCmd.Flags().StringVar(&portableWebDAVCert, "webdav-cert", "", `Path to the certificate file for WebDAV +over HTTPS`) + portableCmd.Flags().StringVar(&portableWebDAVKey, "webdav-key", "", `Path to the key file for WebDAV over +HTTPS`) + portableCmd.Flags().StringVar(&portableHTTPSCert, "httpd-cert", "", `Path to the certificate file for WebClient +over HTTPS`) + portableCmd.Flags().StringVar(&portableHTTPSKey, "httpd-key", "", `Path to the key file for WebClient over +HTTPS`) + portableCmd.Flags().StringVar(&portableAzContainer, "az-container", "", "") + portableCmd.Flags().StringVar(&portableAzAccountName, "az-account-name", "", "") + portableCmd.Flags().StringVar(&portableAzAccountKey, "az-account-key", "", "") + portableCmd.Flags().StringVar(&portableAzSASURL, "az-sas-url", "", `Shared access signature URL`) + portableCmd.Flags().StringVar(&portableAzEndpoint, "az-endpoint", "", `Leave empty to use the default: +"blob.core.windows.net"`) + portableCmd.Flags().StringVar(&portableAzAccessTier, "az-access-tier", "", 
`Leave empty to use the default +container setting`) + portableCmd.Flags().StringVar(&portableAzKeyPrefix, "az-key-prefix", "", `Allows to restrict access to the +virtual folder identified by this +prefix and its contents`) + portableCmd.Flags().IntVar(&portableAzULPartSize, "az-upload-part-size", 5, `The buffer size for multipart uploads +(MB)`) + portableCmd.Flags().IntVar(&portableAzULConcurrency, "az-upload-concurrency", 5, `How many parts are uploaded in +parallel`) + portableCmd.Flags().IntVar(&portableAzDLPartSize, "az-download-part-size", 5, `The buffer size for multipart downloads +(MB)`) + portableCmd.Flags().IntVar(&portableAzDLConcurrency, "az-download-concurrency", 5, `How many parts are downloaded in +parallel`) + portableCmd.Flags().BoolVar(&portableAzUseEmulator, "az-use-emulator", false, "") + portableCmd.Flags().StringVar(&portableCryptPassphrase, "crypto-passphrase", "", `Passphrase for encryption/decryption`) + portableCmd.Flags().StringVar(&portableSFTPEndpoint, "sftp-endpoint", "", `SFTP endpoint as host:port for SFTP +provider`) + portableCmd.Flags().StringVar(&portableSFTPUsername, "sftp-username", "", `SFTP user for SFTP provider`) + portableCmd.Flags().StringVar(&portableSFTPPassword, "sftp-password", "", `SFTP password for SFTP provider`) + portableCmd.Flags().StringVar(&portableSFTPPrivateKeyPath, "sftp-key-path", "", `SFTP private key path for SFTP provider`) + portableCmd.Flags().StringSliceVar(&portableSFTPFingerprints, "sftp-fingerprints", []string{}, `SFTP fingerprints to verify remote host +key for SFTP provider`) + portableCmd.Flags().StringVar(&portableSFTPPrefix, "sftp-prefix", "", `SFTP prefix allows restrict all +operations to a given path within the +remote SFTP server`) + portableCmd.Flags().BoolVar(&portableSFTPDisableConcurrentReads, "sftp-disable-concurrent-reads", false, `Concurrent reads are safe to use and +disabling them will degrade performance. 
+Disable for read once servers`) + portableCmd.Flags().Int64Var(&portableSFTPDBufferSize, "sftp-buffer-size", 0, `The size of the buffer (in MB) to use +for transfers. By enabling buffering, +the reads and writes, from/to the +remote SFTP server, are split in +multiple concurrent requests and this +allows data to be transferred at a +faster rate, over high latency networks, +by overlapping round-trip times`) + portableCmd.Flags().IntVar(&graceTime, graceTimeFlag, 0, + `This grace time defines the number of +seconds allowed for existing transfers +to get completed before shutting down. +A graceful shutdown is triggered by an +interrupt signal. +`) + addConfigFlags(portableCmd) + rootCmd.AddCommand(portableCmd) +} + +func parsePatternsFilesFilters() []sdk.PatternsFilter { + var patterns []sdk.PatternsFilter + for _, val := range portableAllowedPatterns { + p, exts := getPatternsFilterValues(strings.TrimSpace(val)) + if p != "" { + patterns = append(patterns, sdk.PatternsFilter{ + Path: path.Clean(p), + AllowedPatterns: exts, + DeniedPatterns: []string{}, + }) + } + } + for _, val := range portableDeniedPatterns { + p, exts := getPatternsFilterValues(strings.TrimSpace(val)) + if p != "" { + found := false + for index, e := range patterns { + if path.Clean(e.Path) == path.Clean(p) { + patterns[index].DeniedPatterns = append(patterns[index].DeniedPatterns, exts...) 
+ found = true + break + } + } + if !found { + patterns = append(patterns, sdk.PatternsFilter{ + Path: path.Clean(p), + AllowedPatterns: []string{}, + DeniedPatterns: exts, + }) + } + } + } + return patterns +} + +func getPatternsFilterValues(value string) (string, []string) { + if strings.Contains(value, "::") { + dirExts := strings.Split(value, "::") + if len(dirExts) > 1 { + dir := strings.TrimSpace(dirExts[0]) + exts := []string{} + for e := range strings.SplitSeq(dirExts[1], ",") { + cleanedExt := strings.TrimSpace(e) + if cleanedExt != "" { + exts = append(exts, cleanedExt) + } + } + if dir != "" && len(exts) > 0 { + return dir, exts + } + } + } + return "", nil +} + +func getFileContents(name string) (string, error) { + fi, err := os.Stat(name) + if err != nil { + return "", err + } + if fi.Size() > 1048576 { + return "", fmt.Errorf("%q is too big %v/1048576 bytes", name, fi.Size()) + } + contents, err := os.ReadFile(name) + if err != nil { + return "", err + } + return util.BytesToString(contents), nil +} + +func convertFsProvider() string { + switch portableFsProvider { + case "osfs", "6": // httpfs (6) is not supported in portable mode, so return the default + return "0" + case "s3fs": + return "1" + case "gcsfs": + return "2" + case "azblobfs": + return "3" + case "cryptfs": + return "4" + case "sftpfs": + return "5" + default: + return portableFsProvider + } +} diff --git a/internal/cmd/portable_disabled.go b/internal/cmd/portable_disabled.go new file mode 100644 index 00000000..f043ee7e --- /dev/null +++ b/internal/cmd/portable_disabled.go @@ -0,0 +1,23 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build noportable + +package cmd + +import "github.com/drakkan/sftpgo/v2/internal/version" + +func init() { + version.AddFeature("-portable") +} diff --git a/internal/cmd/reload_windows.go b/internal/cmd/reload_windows.go new file mode 100644 index 00000000..359f631b --- /dev/null +++ b/internal/cmd/reload_windows.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/drakkan/sftpgo/v2/internal/service" +) + +var ( + reloadCmd = &cobra.Command{ + Use: "reload", + Short: "Reload the SFTPGo Windows Service sending a \"paramchange\" request", + Run: func(_ *cobra.Command, _ []string) { + s := service.WindowsService{ + Service: service.Service{ + Shutdown: make(chan bool), + }, + } + err := s.Reload() + if err != nil { + fmt.Printf("Error sending reload signal: %v\r\n", err) + os.Exit(1) + } else { + fmt.Printf("Reload signal sent!\r\n") + } + }, + } +) + +func init() { + serviceCmd.AddCommand(reloadCmd) +} diff --git a/internal/cmd/resetprovider.go b/internal/cmd/resetprovider.go new file mode 100644 index 00000000..192ec4ff --- /dev/null +++ b/internal/cmd/resetprovider.go @@ -0,0 +1,89 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package cmd + +import ( + "bufio" + "os" + "strings" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + resetProviderForce bool + resetProviderCmd = &cobra.Command{ + Use: "resetprovider", + Short: "Reset the configured provider, any data will be lost", + Long: `This command reads the data provider connection details from the specified +configuration file and resets the provider by deleting all data and schemas. +This command is not supported for the memory provider. + +Please take a look at the usage below to customize the options.`, + Run: func(_ *cobra.Command, _ []string) { + logger.DisableLogger() + logger.EnableConsoleLogger(zerolog.DebugLevel) + configDir = util.CleanDirInput(configDir) + err := config.LoadConfig(configDir, configFile) + if err != nil { + logger.WarnToConsole("Unable to load configuration: %v", err) + os.Exit(1) + } + kmsConfig := config.GetKMSConfig() + err = kmsConfig.Initialize() + if err != nil { + logger.ErrorToConsole("unable to initialize KMS: %v", err) + os.Exit(1) + } + providerConf := config.GetProviderConf() + if !resetProviderForce { + logger.WarnToConsole("You are about to delete all the SFTPGo data for provider %q, config file: %q", + providerConf.Driver, viper.ConfigFileUsed()) + logger.WarnToConsole("Are you sure? 
(Y/n)") + reader := bufio.NewReader(os.Stdin) + answer, err := reader.ReadString('\n') + if err != nil { + logger.ErrorToConsole("unable to read your answer: %v", err) + os.Exit(1) + } + if strings.ToUpper(strings.TrimSpace(answer)) != "Y" { + logger.InfoToConsole("command aborted") + os.Exit(1) + } + } + logger.InfoToConsole("Resetting provider: %q, config file: %q", providerConf.Driver, viper.ConfigFileUsed()) + err = dataprovider.ResetDatabase(providerConf, configDir) + if err != nil { + logger.WarnToConsole("Error resetting provider: %v", err) + os.Exit(1) + } + logger.InfoToConsole("The data provider was successfully reset") + }, + } +) + +func init() { + addConfigFlags(resetProviderCmd) + resetProviderCmd.Flags().BoolVar(&resetProviderForce, "force", false, `reset the provider without asking for confirmation`) + + rootCmd.AddCommand(resetProviderCmd) +} diff --git a/internal/cmd/resetpwd.go b/internal/cmd/resetpwd.go new file mode 100644 index 00000000..eb009e31 --- /dev/null +++ b/internal/cmd/resetpwd.go @@ -0,0 +1,128 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see .
+ +package cmd + +import ( + "bytes" + "fmt" + "os" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + "github.com/spf13/viper" + "golang.org/x/term" + + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + resetPwdAdmin string + resetPwdCmd = &cobra.Command{ + Use: "resetpwd", + Short: "Reset the password for the specified administrator", + Long: `This command reads the data provider connection details from the specified +configuration file and resets the password for the specified administrator. +Two-factor authentication is also disabled. +This command is not supported for the memory provider. +For embedded providers like bolt and SQLite you should stop the running SFTPGo +instance to avoid database corruption. + +Please take a look at the usage below to customize the options.`, + Run: func(_ *cobra.Command, _ []string) { + logger.DisableLogger() + logger.EnableConsoleLogger(zerolog.DebugLevel) + configDir = util.CleanDirInput(configDir) + err := config.LoadConfig(configDir, configFile) + if err != nil { + logger.WarnToConsole("Unable to load configuration: %v", err) + os.Exit(1) + } + kmsConfig := config.GetKMSConfig() + err = kmsConfig.Initialize() + if err != nil { + logger.ErrorToConsole("unable to initialize KMS: %v", err) + os.Exit(1) + } + if config.HasKMSPlugin() { + if err := plugin.Initialize(config.GetPluginsConfig(), "debug"); err != nil { + logger.ErrorToConsole("unable to initialize plugin system: %v", err) + os.Exit(1) + } + registerSignals() + defer plugin.Handler.Cleanup() + } + + mfaConfig := config.GetMFAConfig() + err = mfaConfig.Initialize() + if err != nil { + logger.ErrorToConsole("Unable to initialize MFA: %v", err) + os.Exit(1) + } + providerConf := config.GetProviderConf() + if providerConf.Driver == 
dataprovider.MemoryDataProviderName { + logger.ErrorToConsole("memory provider is not supported") + os.Exit(1) + } + logger.InfoToConsole("Initializing provider: %q config file: %q", providerConf.Driver, viper.ConfigFileUsed()) + err = dataprovider.Initialize(providerConf, configDir, false) + if err != nil { + logger.ErrorToConsole("Unable to initialize data provider: %v", err) + os.Exit(1) + } + admin, err := dataprovider.AdminExists(resetPwdAdmin) + if err != nil { + logger.ErrorToConsole("Unable to get admin %q: %v", resetPwdAdmin, err) + os.Exit(1) + } + fmt.Printf("Enter Password: ") + pwd, err := term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + logger.ErrorToConsole("Unable to read the password: %v", err) + os.Exit(1) + } + fmt.Println("") + fmt.Printf("Confirm Password: ") + confirmPwd, err := term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + logger.ErrorToConsole("Unable to read the password: %v", err) + os.Exit(1) + } + fmt.Println("") + if !bytes.Equal(pwd, confirmPwd) { + logger.ErrorToConsole("Passwords do not match") + os.Exit(1) + } + admin.Password = string(pwd) + admin.Filters.TOTPConfig.Enabled = false + if err := dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSystem, "", ""); err != nil { + logger.ErrorToConsole("Unable to update password: %v", err) + os.Exit(1) + } + logger.InfoToConsole("Password updated for admin %q", resetPwdAdmin) + }, + } +) + +func init() { + addConfigFlags(resetPwdCmd) + resetPwdCmd.Flags().StringVar(&resetPwdAdmin, "admin", "", `Administrator username whose password to reset`) + resetPwdCmd.MarkFlagRequired("admin") //nolint:errcheck + + rootCmd.AddCommand(resetPwdCmd) +} diff --git a/internal/cmd/revertprovider.go b/internal/cmd/revertprovider.go new file mode 100644 index 00000000..09b8d1a8 --- /dev/null +++ b/internal/cmd/revertprovider.go @@ -0,0 +1,93 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the 
terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cmd + +import ( + "os" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + revertProviderTargetVersion int + revertProviderCmd = &cobra.Command{ + Use: "revertprovider", + Short: "Revert the configured data provider to a previous version", + Long: `This command reads the data provider connection details from the specified +configuration file and restore the provider schema and/or data to a previous version. +This command is not supported for the memory provider. 
+ +Please take a look at the usage below to customize the options.`, + Run: func(_ *cobra.Command, _ []string) { + logger.DisableLogger() + logger.EnableConsoleLogger(zerolog.DebugLevel) + if revertProviderTargetVersion != 33 { + logger.WarnToConsole("Unsupported target version, 33 is the only supported one") + os.Exit(1) + } + configDir = util.CleanDirInput(configDir) + err := config.LoadConfig(configDir, configFile) + if err != nil { + logger.WarnToConsole("Unable to load configuration: %v", err) + os.Exit(1) + } + kmsConfig := config.GetKMSConfig() + err = kmsConfig.Initialize() + if err != nil { + logger.ErrorToConsole("unable to initialize KMS: %v", err) + os.Exit(1) + } + if config.HasKMSPlugin() { + if err := plugin.Initialize(config.GetPluginsConfig(), "debug"); err != nil { + logger.ErrorToConsole("unable to initialize plugin system: %v", err) + os.Exit(1) + } + registerSignals() + defer plugin.Handler.Cleanup() + } + + mfaConfig := config.GetMFAConfig() + err = mfaConfig.Initialize() + if err != nil { + logger.ErrorToConsole("Unable to initialize MFA: %v", err) + os.Exit(1) + } + providerConf := config.GetProviderConf() + logger.InfoToConsole("Reverting provider: %q config file: %q target version %d", providerConf.Driver, + viper.ConfigFileUsed(), revertProviderTargetVersion) + err = dataprovider.RevertDatabase(providerConf, configDir, revertProviderTargetVersion) + if err != nil { + logger.WarnToConsole("Error reverting provider: %v", err) + os.Exit(1) + } + logger.InfoToConsole("Data provider successfully reverted") + }, + } +) + +func init() { + addConfigFlags(revertProviderCmd) + revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 33, `33 means the version supported in v2.7.x`) + + rootCmd.AddCommand(revertProviderCmd) +} diff --git a/internal/cmd/root.go b/internal/cmd/root.go new file mode 100644 index 00000000..00759236 --- /dev/null +++ b/internal/cmd/root.go @@ -0,0 +1,283 @@ +// Copyright (C) 2019 Nicola Murino +// +// 
This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package cmd provides Command Line Interface support +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/drakkan/sftpgo/v2/internal/version" +) + +const ( + configDirFlag = "config-dir" + configDirKey = "config_dir" + configFileFlag = "config-file" + configFileKey = "config_file" + logFilePathFlag = "log-file-path" + logFilePathKey = "log_file_path" + logMaxSizeFlag = "log-max-size" + logMaxSizeKey = "log_max_size" + logMaxBackupFlag = "log-max-backups" + logMaxBackupKey = "log_max_backups" + logMaxAgeFlag = "log-max-age" + logMaxAgeKey = "log_max_age" + logCompressFlag = "log-compress" + logCompressKey = "log_compress" + logLevelFlag = "log-level" + logLevelKey = "log_level" + logUTCTimeFlag = "log-utc-time" + logUTCTimeKey = "log_utc_time" + loadDataFromFlag = "loaddata-from" + loadDataFromKey = "loaddata_from" + loadDataModeFlag = "loaddata-mode" + loadDataModeKey = "loaddata_mode" + loadDataQuotaScanFlag = "loaddata-scan" + loadDataQuotaScanKey = "loaddata_scan" + loadDataCleanFlag = "loaddata-clean" + loadDataCleanKey = "loaddata_clean" + graceTimeFlag = "grace-time" + graceTimeKey = "grace_time" + defaultConfigDir = "." 
+ defaultConfigFile = "" + defaultLogFile = "sftpgo.log" + defaultLogMaxSize = 10 + defaultLogMaxBackup = 5 + defaultLogMaxAge = 28 + defaultLogCompress = false + defaultLogLevel = "debug" + defaultLogUTCTime = false + defaultLoadDataFrom = "" + defaultLoadDataMode = 1 + defaultLoadDataQuotaScan = 0 + defaultLoadDataClean = false + defaultGraceTime = 0 +) + +var ( + configDir string + configFile string + logFilePath string + logMaxSize int + logMaxBackups int + logMaxAge int + logCompress bool + logLevel string + logUTCTime bool + loadDataFrom string + loadDataMode int + loadDataQuotaScan int + loadDataClean bool + graceTime int + + rootCmd = &cobra.Command{ + Use: "sftpgo", + Short: "Full-featured and highly configurable file transfer server", + } +) + +func init() { + rootCmd.CompletionOptions.DisableDefaultCmd = true + rootCmd.Flags().BoolP("version", "v", false, "") + rootCmd.Version = version.GetAsString() + rootCmd.SetVersionTemplate(`{{printf "SFTPGo "}}{{printf "%s" .Version}} +`) +} + +// Execute adds all child commands to the root command and sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() { + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func addConfigFlags(cmd *cobra.Command) { + viper.SetDefault(configDirKey, defaultConfigDir) + viper.BindEnv(configDirKey, "SFTPGO_CONFIG_DIR") //nolint:errcheck // err is not nil only if the key to bind is missing + cmd.Flags().StringVarP(&configDir, configDirFlag, "c", viper.GetString(configDirKey), + `Location of the config dir. This directory +is used as the base for files with a relative +path, e.g. the private keys for the SFTP +server or the database file if you use a +file-based data provider. +The configuration file, if not explicitly set, +is looked for in this dir. We support reading +from JSON, TOML, YAML, HCL, envfile and Java +properties config files. 
The default config +file name is "sftpgo" and therefore +"sftpgo.json", "sftpgo.yaml" and so on are +searched. +This flag can be set using SFTPGO_CONFIG_DIR +env var too.`) + viper.BindPFlag(configDirKey, cmd.Flags().Lookup(configDirFlag)) //nolint:errcheck + + viper.SetDefault(configFileKey, defaultConfigFile) + viper.BindEnv(configFileKey, "SFTPGO_CONFIG_FILE") //nolint:errcheck + cmd.Flags().StringVar(&configFile, configFileFlag, viper.GetString(configFileKey), + `Path to SFTPGo configuration file. +This flag explicitly defines the path, name +and extension of the config file. If must be +an absolute path or a path relative to the +configuration directory. The specified file +name must have a supported extension (JSON, +YAML, TOML, HCL or Java properties). +This flag can be set using SFTPGO_CONFIG_FILE +env var too.`) + viper.BindPFlag(configFileKey, cmd.Flags().Lookup(configFileFlag)) //nolint:errcheck +} + +func addBaseLoadDataFlags(cmd *cobra.Command) { + viper.SetDefault(loadDataFromKey, defaultLoadDataFrom) + viper.BindEnv(loadDataFromKey, "SFTPGO_LOADDATA_FROM") //nolint:errcheck + cmd.Flags().StringVar(&loadDataFrom, loadDataFromFlag, viper.GetString(loadDataFromKey), + `Load users and folders from this file. +The file must be specified as absolute path +and it must contain a backup obtained using +the "dumpdata" REST API or compatible content. +This flag can be set using SFTPGO_LOADDATA_FROM +env var too. +`) + viper.BindPFlag(loadDataFromKey, cmd.Flags().Lookup(loadDataFromFlag)) //nolint:errcheck + + viper.SetDefault(loadDataModeKey, defaultLoadDataMode) + viper.BindEnv(loadDataModeKey, "SFTPGO_LOADDATA_MODE") //nolint:errcheck + cmd.Flags().IntVar(&loadDataMode, loadDataModeFlag, viper.GetInt(loadDataModeKey), + `Restore mode for data to load: + 0 - new users are added, existing users are + updated + 1 - New users are added, existing users are + not modified +This flag can be set using SFTPGO_LOADDATA_MODE +env var too. 
+`) + viper.BindPFlag(loadDataModeKey, cmd.Flags().Lookup(loadDataModeFlag)) //nolint:errcheck + + viper.SetDefault(loadDataCleanKey, defaultLoadDataClean) + viper.BindEnv(loadDataCleanKey, "SFTPGO_LOADDATA_CLEAN") //nolint:errcheck + cmd.Flags().BoolVar(&loadDataClean, loadDataCleanFlag, viper.GetBool(loadDataCleanKey), + `Determine if the loaddata-from file should +be removed after a successful load. This flag +can be set using SFTPGO_LOADDATA_CLEAN env var +too. (default "false") +`) + viper.BindPFlag(loadDataCleanKey, cmd.Flags().Lookup(loadDataCleanFlag)) //nolint:errcheck +} + +func addServeFlags(cmd *cobra.Command) { + addConfigFlags(cmd) + + viper.SetDefault(logFilePathKey, defaultLogFile) + viper.BindEnv(logFilePathKey, "SFTPGO_LOG_FILE_PATH") //nolint:errcheck + cmd.Flags().StringVarP(&logFilePath, logFilePathFlag, "l", viper.GetString(logFilePathKey), + `Location for the log file. Leave empty to write +logs to the standard output. This flag can be +set using SFTPGO_LOG_FILE_PATH env var too. +`) + viper.BindPFlag(logFilePathKey, cmd.Flags().Lookup(logFilePathFlag)) //nolint:errcheck + + viper.SetDefault(logMaxSizeKey, defaultLogMaxSize) + viper.BindEnv(logMaxSizeKey, "SFTPGO_LOG_MAX_SIZE") //nolint:errcheck + cmd.Flags().IntVarP(&logMaxSize, logMaxSizeFlag, "s", viper.GetInt(logMaxSizeKey), + `Maximum size in megabytes of the log file +before it gets rotated. This flag can be set +using SFTPGO_LOG_MAX_SIZE env var too. It is +unused if log-file-path is empty. +`) + viper.BindPFlag(logMaxSizeKey, cmd.Flags().Lookup(logMaxSizeFlag)) //nolint:errcheck + + viper.SetDefault(logMaxBackupKey, defaultLogMaxBackup) + viper.BindEnv(logMaxBackupKey, "SFTPGO_LOG_MAX_BACKUPS") //nolint:errcheck + cmd.Flags().IntVarP(&logMaxBackups, "log-max-backups", "b", viper.GetInt(logMaxBackupKey), + `Maximum number of old log files to retain. +This flag can be set using SFTPGO_LOG_MAX_BACKUPS +env var too. 
It is unused if log-file-path is +empty.`) + viper.BindPFlag(logMaxBackupKey, cmd.Flags().Lookup(logMaxBackupFlag)) //nolint:errcheck + + viper.SetDefault(logMaxAgeKey, defaultLogMaxAge) + viper.BindEnv(logMaxAgeKey, "SFTPGO_LOG_MAX_AGE") //nolint:errcheck + cmd.Flags().IntVarP(&logMaxAge, "log-max-age", "a", viper.GetInt(logMaxAgeKey), + `Maximum number of days to retain old log files. +This flag can be set using SFTPGO_LOG_MAX_AGE env +var too. It is unused if log-file-path is empty. +`) + viper.BindPFlag(logMaxAgeKey, cmd.Flags().Lookup(logMaxAgeFlag)) //nolint:errcheck + + viper.SetDefault(logCompressKey, defaultLogCompress) + viper.BindEnv(logCompressKey, "SFTPGO_LOG_COMPRESS") //nolint:errcheck + cmd.Flags().BoolVarP(&logCompress, logCompressFlag, "z", viper.GetBool(logCompressKey), + `Determine if the rotated log files +should be compressed using gzip. This flag can +be set using SFTPGO_LOG_COMPRESS env var too. +It is unused if log-file-path is empty. +`) + viper.BindPFlag(logCompressKey, cmd.Flags().Lookup(logCompressFlag)) //nolint:errcheck + + viper.SetDefault(logLevelKey, defaultLogLevel) + viper.BindEnv(logLevelKey, "SFTPGO_LOG_LEVEL") //nolint:errcheck + cmd.Flags().StringVar(&logLevel, logLevelFlag, viper.GetString(logLevelKey), + `Set the log level. Supported values: + +debug, info, warn, error. + +This flag can be set +using SFTPGO_LOG_LEVEL env var too. +`) + viper.BindPFlag(logLevelKey, cmd.Flags().Lookup(logLevelFlag)) //nolint:errcheck + + viper.SetDefault(logUTCTimeKey, defaultLogUTCTime) + viper.BindEnv(logUTCTimeKey, "SFTPGO_LOG_UTC_TIME") //nolint:errcheck + cmd.Flags().BoolVar(&logUTCTime, logUTCTimeFlag, viper.GetBool(logUTCTimeKey), + `Use UTC time for logging. This flag can be set +using SFTPGO_LOG_UTC_TIME env var too. 
+`) + viper.BindPFlag(logUTCTimeKey, cmd.Flags().Lookup(logUTCTimeFlag)) //nolint:errcheck + + addBaseLoadDataFlags(cmd) + + viper.SetDefault(loadDataQuotaScanKey, defaultLoadDataQuotaScan) + viper.BindEnv(loadDataQuotaScanKey, "SFTPGO_LOADDATA_QUOTA_SCAN") //nolint:errcheck + cmd.Flags().IntVar(&loadDataQuotaScan, loadDataQuotaScanFlag, viper.GetInt(loadDataQuotaScanKey), + `Quota scan mode after data load: + 0 - no quota scan + 1 - scan quota + 2 - scan quota if the user has quota restrictions +This flag can be set using SFTPGO_LOADDATA_QUOTA_SCAN +env var too. +(default 0)`) + viper.BindPFlag(loadDataQuotaScanKey, cmd.Flags().Lookup(loadDataQuotaScanFlag)) //nolint:errcheck + + viper.SetDefault(graceTimeKey, defaultGraceTime) + viper.BindEnv(graceTimeKey, "SFTPGO_GRACE_TIME") //nolint:errcheck + cmd.Flags().IntVar(&graceTime, graceTimeFlag, viper.GetInt(graceTimeKey), + `Graceful shutdown is an option to initiate a +shutdown without abrupt cancellation of the +currently ongoing client-initiated transfer +sessions. +This grace time defines the number of seconds +allowed for existing transfers to get +completed before shutting down. +A graceful shutdown is triggered by an +interrupt signal. +This flag can be set using SFTPGO_GRACE_TIME env +var too. 0 means disabled. (default 0)`) + viper.BindPFlag(graceTimeKey, cmd.Flags().Lookup(graceTimeFlag)) //nolint:errcheck +} diff --git a/internal/cmd/rotatelogs_windows.go b/internal/cmd/rotatelogs_windows.go new file mode 100644 index 00000000..1eba93a8 --- /dev/null +++ b/internal/cmd/rotatelogs_windows.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/drakkan/sftpgo/v2/internal/service" +) + +var ( + rotateLogCmd = &cobra.Command{ + Use: "rotatelogs", + Short: "Signal to the running service to rotate the logs", + Run: func(_ *cobra.Command, _ []string) { + s := service.WindowsService{ + Service: service.Service{ + Shutdown: make(chan bool), + }, + } + err := s.RotateLogFile() + if err != nil { + fmt.Printf("Error sending rotate log file signal to the service: %v\r\n", err) + os.Exit(1) + } else { + fmt.Printf("Rotate log file signal sent!\r\n") + } + }, + } +) + +func init() { + serviceCmd.AddCommand(rotateLogCmd) +} diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go new file mode 100644 index 00000000..348fb834 --- /dev/null +++ b/internal/cmd/serve.go @@ -0,0 +1,147 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package cmd + +import ( + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/spf13/cobra" + "github.com/subosito/gotenv" + + "github.com/drakkan/sftpgo/v2/internal/service" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + envFileMaxSize = 1048576 +) + +var ( + serveCmd = &cobra.Command{ + Use: "serve", + Short: "Start the SFTPGo service", + Long: `To start the SFTPGo with the default values for the command line flags simply +use: + +$ sftpgo serve + +Please take a look at the usage below to customize the startup options`, + Run: func(_ *cobra.Command, _ []string) { + configDir := util.CleanDirInput(configDir) + checkServeParamsFromEnvFiles(configDir) + service.SetGraceTime(graceTime) + service := service.Service{ + ConfigDir: configDir, + ConfigFile: configFile, + LogFilePath: logFilePath, + LogMaxSize: logMaxSize, + LogMaxBackups: logMaxBackups, + LogMaxAge: logMaxAge, + LogCompress: logCompress, + LogLevel: logLevel, + LogUTCTime: logUTCTime, + LoadDataFrom: loadDataFrom, + LoadDataMode: loadDataMode, + LoadDataQuotaScan: loadDataQuotaScan, + LoadDataClean: loadDataClean, + Shutdown: make(chan bool), + } + if err := service.Start(); err == nil { + service.Wait() + if service.Error == nil { + os.Exit(0) + } + } + os.Exit(1) + }, + } +) + +func setIntFromEnv(receiver *int, val string) { + converted, err := strconv.Atoi(val) + if err == nil { + *receiver = converted + } +} + +func setBoolFromEnv(receiver *bool, val string) { + converted, err := strconv.ParseBool(strings.TrimSpace(val)) + if err == nil { + *receiver = converted + } +} + +func checkServeParamsFromEnvFiles(configDir string) { //nolint:gocyclo + // The logger is not yet initialized here, we have no way to report errors. 
+ envd := filepath.Join(configDir, "env.d") + entries, err := os.ReadDir(envd) + if err != nil { + return + } + for _, entry := range entries { + info, err := entry.Info() + if err == nil && info.Mode().IsRegular() { + envFile := filepath.Join(envd, entry.Name()) + if info.Size() > envFileMaxSize { + continue + } + envVars, err := gotenv.Read(envFile) + if err != nil { + return + } + for k, v := range envVars { + if _, isSet := os.LookupEnv(k); isSet { + continue + } + switch k { + case "SFTPGO_LOG_FILE_PATH": + logFilePath = v + case "SFTPGO_LOG_MAX_SIZE": + setIntFromEnv(&logMaxSize, v) + case "SFTPGO_LOG_MAX_BACKUPS": + setIntFromEnv(&logMaxBackups, v) + case "SFTPGO_LOG_MAX_AGE": + setIntFromEnv(&logMaxAge, v) + case "SFTPGO_LOG_COMPRESS": + setBoolFromEnv(&logCompress, v) + case "SFTPGO_LOG_LEVEL": + logLevel = v + case "SFTPGO_LOG_UTC_TIME": + setBoolFromEnv(&logUTCTime, v) + case "SFTPGO_CONFIG_FILE": + configFile = v + case "SFTPGO_LOADDATA_FROM": + loadDataFrom = v + case "SFTPGO_LOADDATA_MODE": + setIntFromEnv(&loadDataMode, v) + case "SFTPGO_LOADDATA_CLEAN": + setBoolFromEnv(&loadDataClean, v) + case "SFTPGO_LOADDATA_QUOTA_SCAN": + setIntFromEnv(&loadDataQuotaScan, v) + case "SFTPGO_GRACE_TIME": + setIntFromEnv(&graceTime, v) + } + } + } + } +} + +func init() { + rootCmd.AddCommand(serveCmd) + addServeFlags(serveCmd) +} diff --git a/internal/cmd/service_windows.go b/internal/cmd/service_windows.go new file mode 100644 index 00000000..9deff9a1 --- /dev/null +++ b/internal/cmd/service_windows.go @@ -0,0 +1,30 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cmd + +import ( + "github.com/spf13/cobra" +) + +var ( + serviceCmd = &cobra.Command{ + Use: "service", + Short: "Manage the SFTPGo Windows Service", + } +) + +func init() { + rootCmd.AddCommand(serviceCmd) +} diff --git a/internal/cmd/signals_unix.go b/internal/cmd/signals_unix.go new file mode 100644 index 00000000..91849b9a --- /dev/null +++ b/internal/cmd/signals_unix.go @@ -0,0 +1,41 @@ +// Copyright (C) 2025 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +//go:build !windows + +package cmd + +import ( + "os" + "os/signal" + "syscall" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" +) + +func registerSignals() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + go func() { + for sig := range c { + switch sig { + case syscall.SIGINT, syscall.SIGTERM: + logger.DebugToConsole("Received interrupt request") + plugin.Handler.Cleanup() + os.Exit(0) + } + } + }() +} diff --git a/internal/cmd/signals_windows.go b/internal/cmd/signals_windows.go new file mode 100644 index 00000000..3ea1e6ff --- /dev/null +++ b/internal/cmd/signals_windows.go @@ -0,0 +1,36 @@ +// Copyright (C) 2025 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package cmd + +import ( + "os" + "os/signal" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" +) + +func registerSignals() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + go func() { + for range c { + logger.DebugToConsole("Received interrupt request") + plugin.Handler.Cleanup() + os.Exit(0) + } + }() +} diff --git a/internal/cmd/smtptest.go b/internal/cmd/smtptest.go new file mode 100644 index 00000000..2a448afe --- /dev/null +++ b/internal/cmd/smtptest.go @@ -0,0 +1,76 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cmd + +import ( + "os" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + smtpTestRecipient string + smtpTestCmd = &cobra.Command{ + Use: "smtptest", + Short: "Test the SMTP configuration", + Long: `SFTPGo will try to send a test email to the specified recipient. 
+If the SMTP configuration is correct you should receive this email.`, + Run: func(_ *cobra.Command, _ []string) { + logger.DisableLogger() + logger.EnableConsoleLogger(zerolog.DebugLevel) + configDir = util.CleanDirInput(configDir) + err := config.LoadConfig(configDir, configFile) + if err != nil { + logger.ErrorToConsole("Unable to load configuration: %v", err) + os.Exit(1) + } + providerConf := config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, false) + if err != nil { + logger.ErrorToConsole("error initializing data provider: %v", err) + os.Exit(1) + } + smtpConfig := config.GetSMTPConfig() + smtpConfig.Debug = 1 + err = smtpConfig.Initialize(configDir, false) + if err != nil { + logger.ErrorToConsole("unable to initialize SMTP configuration: %v", err) + os.Exit(1) + } + err = smtp.SendEmail([]string{smtpTestRecipient}, nil, "SFTPGo - Testing Email Settings", "It appears your SFTPGo email is setup correctly!", + smtp.EmailContentTypeTextPlain) + if err != nil { + logger.WarnToConsole("Error sending email: %v", err) + os.Exit(1) + } + logger.InfoToConsole("No errors were reported while sending the test email. Please check your inbox to make sure.") + }, + } +) + +func init() { + addConfigFlags(smtpTestCmd) + smtpTestCmd.Flags().StringVar(&smtpTestRecipient, "recipient", "", `email address to send the test e-mail to`) + smtpTestCmd.MarkFlagRequired("recipient") //nolint:errcheck + + rootCmd.AddCommand(smtpTestCmd) +} diff --git a/internal/cmd/start_windows.go b/internal/cmd/start_windows.go new file mode 100644 index 00000000..5e4dde64 --- /dev/null +++ b/internal/cmd/start_windows.go @@ -0,0 +1,68 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/drakkan/sftpgo/v2/internal/service" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + startCmd = &cobra.Command{ + Use: "start", + Short: "Start the SFTPGo Windows Service", + Run: func(_ *cobra.Command, _ []string) { + configDir = util.CleanDirInput(configDir) + checkServeParamsFromEnvFiles(configDir) + service.SetGraceTime(graceTime) + s := service.Service{ + ConfigDir: configDir, + ConfigFile: configFile, + LogFilePath: logFilePath, + LogMaxSize: logMaxSize, + LogMaxBackups: logMaxBackups, + LogMaxAge: logMaxAge, + LogCompress: logCompress, + LogLevel: logLevel, + LogUTCTime: logUTCTime, + LoadDataFrom: loadDataFrom, + LoadDataMode: loadDataMode, + LoadDataQuotaScan: loadDataQuotaScan, + LoadDataClean: loadDataClean, + Shutdown: make(chan bool), + } + winService := service.WindowsService{ + Service: s, + } + err := winService.RunService() + if err != nil { + fmt.Printf("Error starting service: %v\r\n", err) + os.Exit(1) + } else { + fmt.Printf("Service started!\r\n") + } + }, + } +) + +func init() { + serviceCmd.AddCommand(startCmd) + addServeFlags(startCmd) +} diff --git a/internal/cmd/status_windows.go b/internal/cmd/status_windows.go new file mode 100644 index 00000000..466ca2a6 --- /dev/null +++ b/internal/cmd/status_windows.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/drakkan/sftpgo/v2/internal/service" +) + +var ( + statusCmd = &cobra.Command{ + Use: "status", + Short: "Retrieve the status for the SFTPGo Windows Service", + Run: func(_ *cobra.Command, _ []string) { + s := service.WindowsService{ + Service: service.Service{ + Shutdown: make(chan bool), + }, + } + status, err := s.Status() + if err != nil { + fmt.Printf("Error querying service status: %v\r\n", err) + os.Exit(1) + } else { + fmt.Printf("Service status: %q\r\n", status.String()) + } + }, + } +) + +func init() { + serviceCmd.AddCommand(statusCmd) +} diff --git a/internal/cmd/stop_windows.go b/internal/cmd/stop_windows.go new file mode 100644 index 00000000..8c4e987f --- /dev/null +++ b/internal/cmd/stop_windows.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/drakkan/sftpgo/v2/internal/service" +) + +var ( + stopCmd = &cobra.Command{ + Use: "stop", + Short: "Stop the SFTPGo Windows Service", + Run: func(_ *cobra.Command, _ []string) { + s := service.WindowsService{ + Service: service.Service{ + Shutdown: make(chan bool), + }, + } + err := s.Stop() + if err != nil { + fmt.Printf("Error stopping service: %v\r\n", err) + os.Exit(1) + } else { + fmt.Printf("Service stopped!\r\n") + } + }, + } +) + +func init() { + serviceCmd.AddCommand(stopCmd) +} diff --git a/internal/cmd/uninstall_windows.go b/internal/cmd/uninstall_windows.go new file mode 100644 index 00000000..5aa04d80 --- /dev/null +++ b/internal/cmd/uninstall_windows.go @@ -0,0 +1,49 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + + "github.com/drakkan/sftpgo/v2/internal/service" +) + +var ( + uninstallCmd = &cobra.Command{ + Use: "uninstall", + Short: "Uninstall the SFTPGo Windows Service", + Run: func(_ *cobra.Command, _ []string) { + s := service.WindowsService{ + Service: service.Service{ + Shutdown: make(chan bool), + }, + } + err := s.Uninstall() + if err != nil { + fmt.Printf("Error removing service: %v\r\n", err) + os.Exit(1) + } else { + fmt.Printf("Service uninstalled\r\n") + } + }, + } +) + +func init() { + serviceCmd.AddCommand(uninstallCmd) +} diff --git a/internal/command/command.go b/internal/command/command.go new file mode 100644 index 00000000..512c2d4e --- /dev/null +++ b/internal/command/command.go @@ -0,0 +1,145 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +// Package command provides command configuration for SFTPGo hooks +package command + +import ( + "fmt" + "slices" + "strings" + "time" +) + +const ( + minTimeout = 1 + maxTimeout = 300 + defaultTimeout = 30 +) + +// Supported hook names +const ( + HookFsActions = "fs_actions" + HookProviderActions = "provider_actions" + HookStartup = "startup" + HookPostConnect = "post_connect" + HookPostDisconnect = "post_disconnect" + HookCheckPassword = "check_password" + HookPreLogin = "pre_login" + HookPostLogin = "post_login" + HookExternalAuth = "external_auth" + HookKeyboardInteractive = "keyboard_interactive" +) + +var ( + config Config + supportedHooks = []string{HookFsActions, HookProviderActions, HookStartup, HookPostConnect, HookPostDisconnect, + HookCheckPassword, HookPreLogin, HookPostLogin, HookExternalAuth, HookKeyboardInteractive} +) + +// Command define the configuration for a specific commands +type Command struct { + // Path is the command path as defined in the hook configuration + Path string `json:"path" mapstructure:"path"` + // Timeout specifies a time limit, in seconds, for the command execution. + // This value overrides the global timeout if set. + // Do not use variables with the SFTPGO_ prefix to avoid conflicts with env + // vars that SFTPGo sets + Timeout int `json:"timeout" mapstructure:"timeout"` + // Env defines environment variable for the command. + // Each entry is of the form "key=value". 
+ // These values are added to the global environment variables if any + Env []string `json:"env" mapstructure:"env"` + // Args defines arguments to pass to the specified command + Args []string `json:"args" mapstructure:"args"` + // if not empty both command path and hook name must match + Hook string `json:"hook" mapstructure:"hook"` +} + +// Config defines the configuration for external commands such as +// program based hooks +type Config struct { + // Timeout specifies a global time limit, in seconds, for the external commands execution + Timeout int `json:"timeout" mapstructure:"timeout"` + // Env defines environment variable for the commands. + // Each entry is of the form "key=value". + // Do not use variables with the SFTPGO_ prefix to avoid conflicts with env + // vars that SFTPGo sets + Env []string `json:"env" mapstructure:"env"` + // Commands defines configuration for specific commands + Commands []Command `json:"commands" mapstructure:"commands"` +} + +func init() { + config = Config{ + Timeout: defaultTimeout, + } +} + +// Initialize configures commands +func (c Config) Initialize() error { + if c.Timeout < minTimeout || c.Timeout > maxTimeout { + return fmt.Errorf("invalid timeout %v", c.Timeout) + } + for _, env := range c.Env { + if len(strings.SplitN(env, "=", 2)) != 2 { + return fmt.Errorf("invalid env var %q", env) + } + } + for idx, cmd := range c.Commands { + if cmd.Path == "" { + return fmt.Errorf("invalid path %q", cmd.Path) + } + if cmd.Timeout == 0 { + c.Commands[idx].Timeout = c.Timeout + } else { + if cmd.Timeout < minTimeout || cmd.Timeout > maxTimeout { + return fmt.Errorf("invalid timeout %v for command %q", cmd.Timeout, cmd.Path) + } + } + for _, env := range cmd.Env { + if len(strings.SplitN(env, "=", 2)) != 2 { + return fmt.Errorf("invalid env var %q for command %q", env, cmd.Path) + } + } + // don't validate args, we allow to pass empty arguments + if cmd.Hook != "" { + if !slices.Contains(supportedHooks, cmd.Hook) { + return 
fmt.Errorf("invalid hook name %q, supported values: %+v", cmd.Hook, supportedHooks) + } + } + } + config = c + return nil +} + +// GetConfig returns the configuration for the specified command +func GetConfig(command, hook string) (time.Duration, []string, []string) { + env := []string{} + var args []string + timeout := time.Duration(config.Timeout) * time.Second + env = append(env, config.Env...) + for _, cmd := range config.Commands { + if cmd.Path == command { + if cmd.Hook == "" || cmd.Hook == hook { + timeout = time.Duration(cmd.Timeout) * time.Second + env = append(env, cmd.Env...) + args = cmd.Args + break + } + } + } + + return timeout, env, args +} diff --git a/internal/command/command_test.go b/internal/command/command_test.go new file mode 100644 index 00000000..cea9b17d --- /dev/null +++ b/internal/command/command_test.go @@ -0,0 +1,181 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package command + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCommandConfig(t *testing.T) { + require.Equal(t, defaultTimeout, config.Timeout) + cfg := Config{ + Timeout: 10, + Env: []string{"a=b"}, + } + err := cfg.Initialize() + require.NoError(t, err) + assert.Equal(t, cfg.Timeout, config.Timeout) + assert.Equal(t, cfg.Env, config.Env) + assert.Len(t, cfg.Commands, 0) + timeout, env, args := GetConfig("cmd", "") + assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout) + assert.Contains(t, env, "a=b") + assert.Len(t, args, 0) + + cfg.Commands = []Command{ + { + Path: "cmd1", + Timeout: 30, + Env: []string{"c=d"}, + Args: []string{"1", "", "2"}, + }, + { + Path: "cmd2", + Timeout: 0, + Env: []string{"e=f"}, + }, + } + err = cfg.Initialize() + require.NoError(t, err) + assert.Equal(t, cfg.Timeout, config.Timeout) + assert.Equal(t, cfg.Env, config.Env) + if assert.Len(t, config.Commands, 2) { + assert.Equal(t, cfg.Commands[0].Path, config.Commands[0].Path) + assert.Equal(t, cfg.Commands[0].Timeout, config.Commands[0].Timeout) + assert.Equal(t, cfg.Commands[0].Env, config.Commands[0].Env) + assert.Equal(t, cfg.Commands[0].Args, config.Commands[0].Args) + assert.Equal(t, cfg.Commands[1].Path, config.Commands[1].Path) + assert.Equal(t, cfg.Timeout, config.Commands[1].Timeout) + assert.Equal(t, cfg.Commands[1].Env, config.Commands[1].Env) + assert.Equal(t, cfg.Commands[1].Args, config.Commands[1].Args) + } + timeout, env, args = GetConfig("cmd1", "") + assert.Equal(t, time.Duration(config.Commands[0].Timeout)*time.Second, timeout) + assert.Contains(t, env, "a=b") + assert.Contains(t, env, "c=d") + assert.NotContains(t, env, "e=f") + if assert.Len(t, args, 3) { + assert.Equal(t, "1", args[0]) + assert.Empty(t, args[1]) + assert.Equal(t, "2", args[2]) + } + timeout, env, args = GetConfig("cmd2", "") + assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout) + 
assert.Contains(t, env, "a=b") + assert.NotContains(t, env, "c=d") + assert.Contains(t, env, "e=f") + assert.Len(t, args, 0) + + cfg.Commands = []Command{ + { + Path: "cmd1", + Timeout: 30, + Env: []string{"c=d"}, + Args: []string{"1", "", "2"}, + Hook: HookCheckPassword, + }, + { + Path: "cmd1", + Timeout: 0, + Env: []string{"e=f"}, + Hook: HookExternalAuth, + }, + } + err = cfg.Initialize() + require.NoError(t, err) + timeout, env, args = GetConfig("cmd1", "") + assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout) + assert.Contains(t, env, "a=b") + assert.NotContains(t, env, "c=d") + assert.NotContains(t, env, "e=f") + assert.Len(t, args, 0) + timeout, env, args = GetConfig("cmd1", HookCheckPassword) + assert.Equal(t, time.Duration(config.Commands[0].Timeout)*time.Second, timeout) + assert.Contains(t, env, "a=b") + assert.Contains(t, env, "c=d") + assert.NotContains(t, env, "e=f") + if assert.Len(t, args, 3) { + assert.Equal(t, "1", args[0]) + assert.Empty(t, args[1]) + assert.Equal(t, "2", args[2]) + } + timeout, env, args = GetConfig("cmd1", HookExternalAuth) + assert.Equal(t, time.Duration(cfg.Timeout)*time.Second, timeout) + assert.Contains(t, env, "a=b") + assert.NotContains(t, env, "c=d") + assert.Contains(t, env, "e=f") + assert.Len(t, args, 0) +} + +func TestConfigErrors(t *testing.T) { + c := Config{} + err := c.Initialize() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid timeout") + } + c.Timeout = 10 + c.Env = []string{"a"} + err = c.Initialize() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid env var") + } + c.Env = nil + c.Commands = []Command{ + { + Path: "", + }, + } + err = c.Initialize() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid path") + } + c.Commands = []Command{ + { + Path: "path", + Timeout: 10000, + }, + } + err = c.Initialize() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid timeout") + } + c.Commands = []Command{ + { + Path: 
"path", + Timeout: 30, + Env: []string{"b"}, + }, + } + err = c.Initialize() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid env var") + } + c.Commands = []Command{ + { + Path: "path", + Timeout: 30, + Env: []string{"a=b"}, + Hook: "invali", + }, + } + err = c.Initialize() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid hook name") + } +} diff --git a/internal/common/actions.go b/internal/common/actions.go new file mode 100644 index 00000000..d6fd04b0 --- /dev/null +++ b/internal/common/actions.go @@ -0,0 +1,355 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package common + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "os/exec" + "path" + "path/filepath" + "slices" + "strings" + "sync/atomic" + "time" + + "github.com/sftpgo/sdk" + "github.com/sftpgo/sdk/plugin/notifier" + + "github.com/drakkan/sftpgo/v2/internal/command" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" +) + +var ( + errUnexpectedHTTResponse = errors.New("unexpected HTTP hook response code") + hooksConcurrencyGuard = make(chan struct{}, 150) + activeHooks atomic.Int32 +) + +func startNewHook() { + activeHooks.Add(1) + hooksConcurrencyGuard <- struct{}{} +} + +func hookEnded() { + activeHooks.Add(-1) + <-hooksConcurrencyGuard +} + +// ProtocolActions defines the action to execute on file operations and SSH commands +type ProtocolActions struct { + // Valid values are download, upload, pre-delete, delete, rename, ssh_cmd. Empty slice to disable + ExecuteOn []string `json:"execute_on" mapstructure:"execute_on"` + // Actions to be performed synchronously. + // The pre-delete action is always executed synchronously while the other ones are asynchronous. + // Executing an action synchronously means that SFTPGo will not return a result code to the client + // (which is waiting for it) until your hook have completed its execution. + ExecuteSync []string `json:"execute_sync" mapstructure:"execute_sync"` + // Absolute path to an external program or an HTTP URL + Hook string `json:"hook" mapstructure:"hook"` +} + +var actionHandler ActionHandler = &defaultActionHandler{} + +// InitializeActionHandler lets the user choose an action handler implementation. +// +// Do NOT call this function after application initialization. 
+func InitializeActionHandler(handler ActionHandler) { + actionHandler = handler +} + +// ExecutePreAction executes a pre-* action and returns the result. +// The returned status has the following meaning: +// - 0 not executed +// - 1 executed using an external hook +// - 2 executed using the event manager +func ExecutePreAction(conn *BaseConnection, operation, filePath, virtualPath string, fileSize int64, openFlags int) (int, error) { + var event *notifier.FsEvent + hasNotifiersPlugin := plugin.Handler.HasNotifiers() + hasHook := slices.Contains(Config.Actions.ExecuteOn, operation) + hasRules := eventManager.hasFsRules() + if !hasHook && !hasNotifiersPlugin && !hasRules { + return 0, nil + } + dateTime := time.Now() + event = newActionNotification(&conn.User, operation, filePath, virtualPath, "", "", "", + conn.protocol, conn.GetRemoteIP(), conn.ID, fileSize, openFlags, conn.getNotificationStatus(nil), 0, dateTime, nil) + if hasNotifiersPlugin { + plugin.Handler.NotifyFsEvent(event) + } + if hasRules { + params := EventParams{ + Name: event.Username, + Groups: conn.User.Groups, + Event: event.Action, + Status: event.Status, + VirtualPath: event.VirtualPath, + FsPath: event.Path, + VirtualTargetPath: event.VirtualTargetPath, + FsTargetPath: event.TargetPath, + ObjectName: path.Base(event.VirtualPath), + Extension: path.Ext(event.VirtualPath), + FileSize: event.FileSize, + Protocol: event.Protocol, + IP: event.IP, + Role: event.Role, + Timestamp: dateTime, + Email: conn.User.Email, + Object: nil, + } + executedSync, err := eventManager.handleFsEvent(params) + if executedSync { + return 2, err + } + } + if !hasHook { + return 0, nil + } + return actionHandler.Handle(event) +} + +// ExecuteActionNotification executes the defined hook, if any, for the specified action +func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtualPath, target, virtualTarget, sshCmd string, + fileSize int64, err error, elapsed int64, metadata map[string]string, +) 
error { + hasNotifiersPlugin := plugin.Handler.HasNotifiers() + hasHook := slices.Contains(Config.Actions.ExecuteOn, operation) + hasRules := eventManager.hasFsRules() + if !hasHook && !hasNotifiersPlugin && !hasRules { + return nil + } + dateTime := time.Now() + notification := newActionNotification(&conn.User, operation, filePath, virtualPath, target, virtualTarget, sshCmd, + conn.protocol, conn.GetRemoteIP(), conn.ID, fileSize, 0, conn.getNotificationStatus(err), elapsed, dateTime, metadata) + if hasNotifiersPlugin { + plugin.Handler.NotifyFsEvent(notification) + } + if hasRules { + params := EventParams{ + Name: notification.Username, + Groups: conn.User.Groups, + Event: notification.Action, + Status: notification.Status, + VirtualPath: notification.VirtualPath, + FsPath: notification.Path, + VirtualTargetPath: notification.VirtualTargetPath, + FsTargetPath: notification.TargetPath, + ObjectName: path.Base(notification.VirtualPath), + Extension: path.Ext(notification.VirtualPath), + FileSize: notification.FileSize, + Elapsed: notification.Elapsed, + Protocol: notification.Protocol, + IP: notification.IP, + Role: notification.Role, + Timestamp: dateTime, + Email: conn.User.Email, + Object: nil, + Metadata: metadata, + } + if err != nil { + params.AddError(fmt.Errorf("%q failed: %w", params.Event, err)) + } + executedSync, err := eventManager.handleFsEvent(params) + if executedSync { + return err + } + } + if hasHook { + if slices.Contains(Config.Actions.ExecuteSync, operation) { + _, err := actionHandler.Handle(notification) + return err + } + go func() { + startNewHook() + defer hookEnded() + + actionHandler.Handle(notification) //nolint:errcheck + }() + } + return nil +} + +// ActionHandler handles a notification for a Protocol Action. 
+type ActionHandler interface { + Handle(notification *notifier.FsEvent) (int, error) +} + +func newActionNotification( + user *dataprovider.User, + operation, filePath, virtualPath, target, virtualTarget, sshCmd, protocol, ip, sessionID string, + fileSize int64, + openFlags, status int, elapsed int64, + datetime time.Time, + metadata map[string]string, +) *notifier.FsEvent { + var bucket, endpoint string + + fsConfig := user.GetFsConfigForPath(virtualPath) + + switch fsConfig.Provider { + case sdk.S3FilesystemProvider: + bucket = fsConfig.S3Config.Bucket + endpoint = fsConfig.S3Config.Endpoint + case sdk.GCSFilesystemProvider: + bucket = fsConfig.GCSConfig.Bucket + case sdk.AzureBlobFilesystemProvider: + bucket = fsConfig.AzBlobConfig.Container + if fsConfig.AzBlobConfig.Endpoint != "" { + endpoint = fsConfig.AzBlobConfig.Endpoint + } + case sdk.SFTPFilesystemProvider: + endpoint = fsConfig.SFTPConfig.Endpoint + case sdk.HTTPFilesystemProvider: + endpoint = fsConfig.HTTPConfig.Endpoint + } + + return ¬ifier.FsEvent{ + Action: operation, + Username: user.Username, + Path: filePath, + TargetPath: target, + VirtualPath: virtualPath, + VirtualTargetPath: virtualTarget, + SSHCmd: sshCmd, + FileSize: fileSize, + FsProvider: int(fsConfig.Provider), + Bucket: bucket, + Endpoint: endpoint, + Status: status, + Protocol: protocol, + IP: ip, + SessionID: sessionID, + OpenFlags: openFlags, + Role: user.Role, + Timestamp: datetime.UnixNano(), + Elapsed: elapsed, + Metadata: metadata, + } +} + +type defaultActionHandler struct{} + +func (h *defaultActionHandler) Handle(event *notifier.FsEvent) (int, error) { + if !slices.Contains(Config.Actions.ExecuteOn, event.Action) { + return 0, nil + } + + if Config.Actions.Hook == "" { + logger.Warn(event.Protocol, "", "Unable to send notification, no hook is defined") + + return 0, nil + } + + if strings.HasPrefix(Config.Actions.Hook, "http") { + err := h.handleHTTP(event) + return 1, err + } + + err := h.handleCommand(event) + return 1, 
err +} + +func (h *defaultActionHandler) handleHTTP(event *notifier.FsEvent) error { + u, err := url.Parse(Config.Actions.Hook) + if err != nil { + logger.Error(event.Protocol, "", "Invalid hook %q for operation %q: %v", + Config.Actions.Hook, event.Action, err) + return err + } + + startTime := time.Now() + respCode := 0 + + var b bytes.Buffer + _ = json.NewEncoder(&b).Encode(event) + + resp, err := httpclient.RetryablePost(Config.Actions.Hook, "application/json", &b) + if err == nil { + respCode = resp.StatusCode + resp.Body.Close() + + if respCode != http.StatusOK { + err = errUnexpectedHTTResponse + } + } + + logger.Debug(event.Protocol, "", "notified operation %q to URL: %s status code: %d, elapsed: %s err: %v", + event.Action, u.Redacted(), respCode, time.Since(startTime), err) + + return err +} + +func (h *defaultActionHandler) handleCommand(event *notifier.FsEvent) error { + if !filepath.IsAbs(Config.Actions.Hook) { + err := fmt.Errorf("invalid notification command %q", Config.Actions.Hook) + logger.Warn(event.Protocol, "", "unable to execute notification command: %v", err) + + return err + } + + timeout, env, args := command.GetConfig(Config.Actions.Hook, command.HookFsActions) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, Config.Actions.Hook, args...) + cmd.Env = append(env, notificationAsEnvVars(event)...) 
+ + startTime := time.Now() + err := cmd.Run() + + logger.Debug(event.Protocol, "", "executed command %q, elapsed: %s, error: %v", + Config.Actions.Hook, time.Since(startTime), err) + + return err +} + +func notificationAsEnvVars(event *notifier.FsEvent) []string { + result := []string{ + fmt.Sprintf("SFTPGO_ACTION=%s", event.Action), + fmt.Sprintf("SFTPGO_ACTION_USERNAME=%s", event.Username), + fmt.Sprintf("SFTPGO_ACTION_PATH=%s", event.Path), + fmt.Sprintf("SFTPGO_ACTION_TARGET=%s", event.TargetPath), + fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_PATH=%s", event.VirtualPath), + fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_TARGET=%s", event.VirtualTargetPath), + fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%s", event.SSHCmd), + fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%d", event.FileSize), + fmt.Sprintf("SFTPGO_ACTION_ELAPSED=%d", event.Elapsed), + fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%d", event.FsProvider), + fmt.Sprintf("SFTPGO_ACTION_BUCKET=%s", event.Bucket), + fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%s", event.Endpoint), + fmt.Sprintf("SFTPGO_ACTION_STATUS=%d", event.Status), + fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%s", event.Protocol), + fmt.Sprintf("SFTPGO_ACTION_IP=%s", event.IP), + fmt.Sprintf("SFTPGO_ACTION_SESSION_ID=%s", event.SessionID), + fmt.Sprintf("SFTPGO_ACTION_OPEN_FLAGS=%d", event.OpenFlags), + fmt.Sprintf("SFTPGO_ACTION_TIMESTAMP=%d", event.Timestamp), + fmt.Sprintf("SFTPGO_ACTION_ROLE=%s", event.Role), + } + if len(event.Metadata) > 0 { + data, err := json.Marshal(event.Metadata) + if err == nil { + result = append(result, fmt.Sprintf("SFTPGO_ACTION_METADATA=%s", data)) + } + } + return result +} diff --git a/internal/common/actions_test.go b/internal/common/actions_test.go new file mode 100644 index 00000000..6ac0463c --- /dev/null +++ b/internal/common/actions_test.go @@ -0,0 +1,343 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// 
by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/lithammer/shortuuid/v4" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + "github.com/sftpgo/sdk/plugin/notifier" + "github.com/stretchr/testify/assert" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +func TestNewActionNotification(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "username", + }, + } + user.FsConfig.Provider = sdk.LocalFilesystemProvider + user.FsConfig.S3Config = vfs.S3FsConfig{ + BaseS3FsConfig: sdk.BaseS3FsConfig{ + Bucket: "s3bucket", + Endpoint: "endpoint", + }, + } + user.FsConfig.GCSConfig = vfs.GCSFsConfig{ + BaseGCSFsConfig: sdk.BaseGCSFsConfig{ + Bucket: "gcsbucket", + }, + } + user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{ + BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{ + Container: "azcontainer", + Endpoint: "azendpoint", + }, + } + user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: "sftpendpoint", + }, + } + user.FsConfig.HTTPConfig = vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: "httpendpoint", + }, + } + c := NewBaseConnection("id", ProtocolSSH, "", "", user) + sessionID := xid.New().String() + a := newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "", sessionID, + 123, 0, c.getNotificationStatus(errors.New("fake 
error")), 0, time.Now(), nil) + assert.Equal(t, user.Username, a.Username) + assert.Equal(t, 0, len(a.Bucket)) + assert.Equal(t, 0, len(a.Endpoint)) + assert.Equal(t, 2, a.Status) + + user.FsConfig.Provider = sdk.S3FilesystemProvider + a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSSH, "", sessionID, + 123, 0, c.getNotificationStatus(nil), 0, time.Now(), nil) + assert.Equal(t, "s3bucket", a.Bucket) + assert.Equal(t, "endpoint", a.Endpoint) + assert.Equal(t, 1, a.Status) + + user.FsConfig.Provider = sdk.GCSFilesystemProvider + a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "", sessionID, + 123, 0, c.getNotificationStatus(ErrQuotaExceeded), 0, time.Now(), nil) + assert.Equal(t, "gcsbucket", a.Bucket) + assert.Equal(t, 0, len(a.Endpoint)) + assert.Equal(t, 3, a.Status) + a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "", sessionID, + 123, 0, c.getNotificationStatus(fmt.Errorf("wrapper quota error: %w", ErrQuotaExceeded)), 0, time.Now(), nil) + assert.Equal(t, "gcsbucket", a.Bucket) + assert.Equal(t, 0, len(a.Endpoint)) + assert.Equal(t, 3, a.Status) + + user.FsConfig.Provider = sdk.HTTPFilesystemProvider + a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSSH, "", sessionID, + 123, 0, c.getNotificationStatus(nil), 0, time.Now(), nil) + assert.Equal(t, "httpendpoint", a.Endpoint) + assert.Equal(t, 1, a.Status) + + user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider + a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "", sessionID, + 123, 0, c.getNotificationStatus(nil), 0, time.Now(), nil) + assert.Equal(t, "azcontainer", a.Bucket) + assert.Equal(t, "azendpoint", a.Endpoint) + assert.Equal(t, 1, a.Status) + + a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "", 
sessionID, + 123, os.O_APPEND, c.getNotificationStatus(nil), 0, time.Now(), nil) + assert.Equal(t, "azcontainer", a.Bucket) + assert.Equal(t, "azendpoint", a.Endpoint) + assert.Equal(t, 1, a.Status) + assert.Equal(t, os.O_APPEND, a.OpenFlags) + + user.FsConfig.Provider = sdk.SFTPFilesystemProvider + a = newActionNotification(&user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "", sessionID, + 123, 0, c.getNotificationStatus(nil), 0, time.Now(), nil) + assert.Equal(t, "sftpendpoint", a.Endpoint) +} + +func TestActionHTTP(t *testing.T) { + actionsCopy := Config.Actions + + Config.Actions = ProtocolActions{ + ExecuteOn: []string{operationDownload}, + Hook: fmt.Sprintf("http://%v", httpAddr), + } + user := &dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "username", + }, + } + a := newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "", + xid.New().String(), 123, 0, 1, 0, time.Now(), nil) + status, err := actionHandler.Handle(a) + assert.NoError(t, err) + assert.Equal(t, 1, status) + + Config.Actions.Hook = "http://invalid:1234" + status, err = actionHandler.Handle(a) + assert.Error(t, err) + assert.Equal(t, 1, status) + + Config.Actions.Hook = fmt.Sprintf("http://%v/404", httpAddr) + status, err = actionHandler.Handle(a) + if assert.Error(t, err) { + assert.EqualError(t, err, errUnexpectedHTTResponse.Error()) + } + assert.Equal(t, 1, status) + + Config.Actions = actionsCopy +} + +func TestActionCMD(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + actionsCopy := Config.Actions + + hookCmd, err := exec.LookPath("true") + assert.NoError(t, err) + + Config.Actions = ProtocolActions{ + ExecuteOn: []string{operationDownload}, + Hook: hookCmd, + } + user := &dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "username", + }, + } + sessionID := shortuuid.New() + a := newActionNotification(user, operationDownload, "path", "vpath", "target", 
"", "", ProtocolSFTP, "", sessionID, + 123, 0, 1, 0, time.Now(), map[string]string{"key": "value"}) + status, err := actionHandler.Handle(a) + assert.NoError(t, err) + assert.Equal(t, 1, status) + + c := NewBaseConnection("id", ProtocolSFTP, "", "", *user) + err = ExecuteActionNotification(c, OperationSSHCmd, "path", "vpath", "target", "vtarget", "sha1sum", 0, nil, 0, nil) + assert.NoError(t, err) + + err = ExecuteActionNotification(c, operationDownload, "path", "vpath", "", "", "", 0, nil, 0, nil) + assert.NoError(t, err) + + Config.Actions = actionsCopy +} + +func TestWrongActions(t *testing.T) { + actionsCopy := Config.Actions + + badCommand := "/bad/command" + if runtime.GOOS == osWindows { + badCommand = "C:\\bad\\command" + } + Config.Actions = ProtocolActions{ + ExecuteOn: []string{operationUpload}, + Hook: badCommand, + } + user := &dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "username", + }, + } + + a := newActionNotification(user, operationUpload, "", "", "", "", "", ProtocolSFTP, "", xid.New().String(), + 123, 0, 1, 0, time.Now(), nil) + status, err := actionHandler.Handle(a) + assert.Error(t, err, "action with bad command must fail") + assert.Equal(t, 1, status) + + a.Action = operationDelete + status, err = actionHandler.Handle(a) + assert.NoError(t, err) + assert.Equal(t, 0, status) + + Config.Actions.Hook = "http://foo\x7f.com/" + a.Action = operationUpload + status, err = actionHandler.Handle(a) + assert.Error(t, err, "action with bad url must fail") + assert.Equal(t, 1, status) + + Config.Actions.Hook = "" + status, err = actionHandler.Handle(a) + assert.NoError(t, err) + assert.Equal(t, 0, status) + + Config.Actions.Hook = "relative path" + status, err = actionHandler.Handle(a) + if assert.Error(t, err) { + assert.EqualError(t, err, fmt.Sprintf("invalid notification command %q", Config.Actions.Hook)) + } + assert.Equal(t, 1, status) + + Config.Actions = actionsCopy +} + +func TestPreDeleteAction(t *testing.T) { + if runtime.GOOS == 
osWindows { + t.Skip("this test is not available on Windows") + } + actionsCopy := Config.Actions + + hookCmd, err := exec.LookPath("true") + assert.NoError(t, err) + Config.Actions = ProtocolActions{ + ExecuteOn: []string{operationPreDelete}, + Hook: "missing hook", + } + homeDir := filepath.Join(os.TempDir(), "test_user") + err = os.MkdirAll(homeDir, os.ModePerm) + assert.NoError(t, err) + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "username", + HomeDir: homeDir, + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + fs := vfs.NewOsFs("id", homeDir, "", nil) + c := NewBaseConnection("id", ProtocolSFTP, "", "", user) + + testfile := filepath.Join(user.HomeDir, "testfile") + err = os.WriteFile(testfile, []byte("test"), os.ModePerm) + assert.NoError(t, err) + info, err := os.Stat(testfile) + assert.NoError(t, err) + err = c.RemoveFile(fs, testfile, "testfile", info) + assert.ErrorIs(t, err, c.GetPermissionDeniedError()) + assert.FileExists(t, testfile) + Config.Actions.Hook = hookCmd + err = c.RemoveFile(fs, testfile, "testfile", info) + assert.NoError(t, err) + assert.NoFileExists(t, testfile) + + os.RemoveAll(homeDir) + + Config.Actions = actionsCopy +} + +func TestUnconfiguredHook(t *testing.T) { + actionsCopy := Config.Actions + + Config.Actions = ProtocolActions{ + ExecuteOn: []string{operationDownload}, + Hook: "", + } + pluginsConfig := []plugin.Config{ + { + Type: "notifier", + }, + } + err := plugin.Initialize(pluginsConfig, "debug") + assert.Error(t, err) + assert.True(t, plugin.Handler.HasNotifiers()) + + c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{}) + status, err := ExecutePreAction(c, OperationPreDownload, "", "", 0, 0) + assert.NoError(t, err) + assert.Equal(t, status, 0) + status, err = ExecutePreAction(c, operationPreDelete, "", "", 0, 0) + assert.NoError(t, err) + assert.Equal(t, status, 0) + + err = ExecuteActionNotification(c, 
operationDownload, "", "", "", "", "", 0, nil, 0, nil) + assert.NoError(t, err) + + err = plugin.Initialize(nil, "debug") + assert.NoError(t, err) + assert.False(t, plugin.Handler.HasNotifiers()) + + Config.Actions = actionsCopy +} + +type actionHandlerStub struct { + called bool +} + +func (h *actionHandlerStub) Handle(_ *notifier.FsEvent) (int, error) { + h.called = true + + return 1, nil +} + +func TestInitializeActionHandler(t *testing.T) { + handler := &actionHandlerStub{} + + InitializeActionHandler(handler) + t.Cleanup(func() { + InitializeActionHandler(&defaultActionHandler{}) + }) + + status, err := actionHandler.Handle(¬ifier.FsEvent{}) + assert.NoError(t, err) + assert.True(t, handler.called) + assert.Equal(t, 1, status) +} diff --git a/internal/common/clientsmap.go b/internal/common/clientsmap.go new file mode 100644 index 00000000..2401af5f --- /dev/null +++ b/internal/common/clientsmap.go @@ -0,0 +1,65 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+
+package common
+
+import (
+	"sync"
+	"sync/atomic"
+
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+)
+
+// clientsMap is a struct containing the map of the connected clients
+type clientsMap struct {
+	totalConnections atomic.Int32   // total connections across all sources
+	mu               sync.RWMutex   // protects clients
+	clients          map[string]int // number of connections per source (typically a client IP)
+}
+
+func (c *clientsMap) add(source string) { // records a new connection from source
+	c.totalConnections.Add(1)
+
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.clients[source]++
+}
+
+func (c *clientsMap) remove(source string) { // removes a connection from source; no-op with a warning if unmapped
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if val, ok := c.clients[source]; ok {
+		c.totalConnections.Add(-1) // only decremented for mapped sources, mirroring the unconditional add
+		c.clients[source]--
+		if val > 1 { // other connections from this source remain
+			return
+		}
+		delete(c.clients, source) // last connection from this source: drop the map entry
+	} else {
+		logger.Warn(logSender, "", "cannot remove client %v it is not mapped", source)
+	}
+}
+
+func (c *clientsMap) getTotal() int32 { // returns the total number of tracked connections
+	return c.totalConnections.Load()
+}
+
+func (c *clientsMap) getTotalFrom(source string) int { // returns the connection count for source (0 if none)
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	return c.clients[source]
+}
diff --git a/internal/common/clientsmap_test.go b/internal/common/clientsmap_test.go
new file mode 100644
index 00000000..d3df2a48
--- /dev/null
+++ b/internal/common/clientsmap_test.go
@@ -0,0 +1,73 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package common
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestClientsMap(t *testing.T) { // exercises the add/remove/getTotal/getTotalFrom accounting
+	m := clientsMap{
+		clients: make(map[string]int),
+	}
+	ip1 := "192.168.1.1"
+	ip2 := "192.168.1.2"
+	m.add(ip1)
+	assert.Equal(t, int32(1), m.getTotal())
+	assert.Equal(t, 1, m.getTotalFrom(ip1))
+	assert.Equal(t, 0, m.getTotalFrom(ip2))
+
+	m.add(ip1)
+	m.add(ip2)
+	assert.Equal(t, int32(3), m.getTotal())
+	assert.Equal(t, 2, m.getTotalFrom(ip1))
+	assert.Equal(t, 1, m.getTotalFrom(ip2))
+
+	m.add(ip1)
+	m.add(ip1)
+	m.add(ip2)
+	assert.Equal(t, int32(6), m.getTotal())
+	assert.Equal(t, 4, m.getTotalFrom(ip1))
+	assert.Equal(t, 2, m.getTotalFrom(ip2))
+
+	m.remove(ip2)
+	assert.Equal(t, int32(5), m.getTotal())
+	assert.Equal(t, 4, m.getTotalFrom(ip1))
+	assert.Equal(t, 1, m.getTotalFrom(ip2))
+
+	m.remove("unknown") // removing an unmapped source must be a no-op (only a warning is logged)
+	assert.Equal(t, int32(5), m.getTotal())
+	assert.Equal(t, 4, m.getTotalFrom(ip1))
+	assert.Equal(t, 1, m.getTotalFrom(ip2))
+
+	m.remove(ip2)
+	assert.Equal(t, int32(4), m.getTotal())
+	assert.Equal(t, 4, m.getTotalFrom(ip1))
+	assert.Equal(t, 0, m.getTotalFrom(ip2))
+
+	m.remove(ip1)
+	m.remove(ip1)
+	m.remove(ip1)
+	assert.Equal(t, int32(1), m.getTotal())
+	assert.Equal(t, 1, m.getTotalFrom(ip1))
+	assert.Equal(t, 0, m.getTotalFrom(ip2))
+
+	m.remove(ip1)
+	assert.Equal(t, int32(0), m.getTotal())
+	assert.Equal(t, 0, m.getTotalFrom(ip1))
+	assert.Equal(t, 0, m.getTotalFrom(ip2))
+}
diff --git a/internal/common/common.go b/internal/common/common.go
new file mode 100644
index 00000000..99b5e8d1
--- /dev/null
+++ b/internal/common/common.go
@@ -0,0 +1,1518 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package common defines code shared among file transfer packages and protocols +package common + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pires/go-proxyproto" + "github.com/sftpgo/sdk/plugin/notifier" + + "github.com/drakkan/sftpgo/v2/internal/command" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/metric" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +// constants +const ( + logSender = "common" + uploadLogSender = "Upload" + downloadLogSender = "Download" + renameLogSender = "Rename" + rmdirLogSender = "Rmdir" + mkdirLogSender = "Mkdir" + symlinkLogSender = "Symlink" + removeLogSender = "Remove" + chownLogSender = "Chown" + chmodLogSender = "Chmod" + chtimesLogSender = "Chtimes" + copyLogSender = "Copy" + truncateLogSender = "Truncate" + operationDownload = "download" + operationUpload = "upload" + operationFirstDownload = "first-download" + operationFirstUpload = "first-upload" + operationDelete = "delete" + operationCopy = "copy" + // Pre-download action name + OperationPreDownload = "pre-download" + // Pre-upload action name + OperationPreUpload = "pre-upload" + 
operationPreDelete = "pre-delete" + operationRename = "rename" + operationMkdir = "mkdir" + operationRmdir = "rmdir" + // SSH command action name + OperationSSHCmd = "ssh_cmd" + chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS + idleTimeoutCheckInterval = 3 * time.Minute + periodicTimeoutCheckInterval = 1 * time.Minute +) + +// Stat flags +const ( + StatAttrUIDGID = 1 + StatAttrPerms = 2 + StatAttrTimes = 4 + StatAttrSize = 8 +) + +// Transfer types +const ( + TransferUpload = iota + TransferDownload +) + +// Supported protocols +const ( + ProtocolSFTP = "SFTP" + ProtocolSCP = "SCP" + ProtocolSSH = "SSH" + ProtocolFTP = "FTP" + ProtocolWebDAV = "DAV" + ProtocolHTTP = "HTTP" + ProtocolHTTPShare = "HTTPShare" + ProtocolDataRetention = "DataRetention" + ProtocolOIDC = "OIDC" + protocolEventAction = "EventAction" +) + +// Upload modes +const ( + UploadModeStandard = 0 + UploadModeAtomic = 1 + UploadModeAtomicWithResume = 2 + UploadModeS3StoreOnError = 4 + UploadModeGCSStoreOnError = 8 + UploadModeAzureBlobStoreOnError = 16 +) + +func init() { + Connections.clients = clientsMap{ + clients: make(map[string]int), + } + Connections.transfers = clientsMap{ + clients: make(map[string]int), + } + Connections.perUserConns = make(map[string]int) + Connections.mapping = make(map[string]int) + Connections.sshMapping = make(map[string]int) +} + +// errors definitions +var ( + ErrPermissionDenied = errors.New("permission denied") + ErrNotExist = errors.New("no such file or directory") + ErrOpUnsupported = errors.New("operation unsupported") + ErrGenericFailure = errors.New("failure") + ErrQuotaExceeded = errors.New("denying write due to space limit") + ErrReadQuotaExceeded = errors.New("denying read due to quota limit") + ErrConnectionDenied = errors.New("you are not allowed to connect") + ErrNoBinding = errors.New("no binding configured") + ErrCrtRevoked = errors.New("your certificate has been revoked") + ErrNoCredentials = errors.New("no credential provided") + 
ErrInternalFailure = errors.New("internal failure") + ErrTransferAborted = errors.New("transfer aborted") + ErrShuttingDown = errors.New("the service is shutting down") + errNoTransfer = errors.New("requested transfer not found") + errTransferMismatch = errors.New("transfer mismatch") +) + +var ( + // Config is the configuration for the supported protocols + Config Configuration + // Connections is the list of active connections + Connections ActiveConnections + // QuotaScans is the list of active quota scans + QuotaScans ActiveScans + transfersChecker TransfersChecker + supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV, + ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC} + disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP} + // the map key is the protocol, for each protocol we can have multiple rate limiters + rateLimiters map[string][]*rateLimiter + isShuttingDown atomic.Bool + ftpLoginCommands = []string{"PASS", "USER"} + fnUpdateBranding func(*dataprovider.BrandingConfigs) +) + +// SetUpdateBrandingFn sets the function to call to update branding configs. 
+func SetUpdateBrandingFn(fn func(*dataprovider.BrandingConfigs)) { + fnUpdateBranding = fn +} + +// Initialize sets the common configuration +func Initialize(c Configuration, isShared int) error { + isShuttingDown.Store(false) + util.SetUmask(c.Umask) + version.SetConfig(c.ServerVersion) + dataprovider.SetTZ(c.TZ) + Config = c + Config.Actions.ExecuteOn = util.RemoveDuplicates(Config.Actions.ExecuteOn, true) + Config.Actions.ExecuteSync = util.RemoveDuplicates(Config.Actions.ExecuteSync, true) + Config.ProxyAllowed = util.RemoveDuplicates(Config.ProxyAllowed, true) + Config.idleLoginTimeout = 2 * time.Minute + Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute + startPeriodicChecks(periodicTimeoutCheckInterval, isShared) + Config.defender = nil + Config.allowList = nil + Config.rateLimitersList = nil + rateLimiters = make(map[string][]*rateLimiter) + for _, rlCfg := range c.RateLimitersConfig { + if rlCfg.isEnabled() { + if err := rlCfg.validate(); err != nil { + return fmt.Errorf("rate limiters initialization error: %w", err) + } + rateLimiter := rlCfg.getLimiter() + for _, protocol := range rlCfg.Protocols { + rateLimiters[protocol] = append(rateLimiters[protocol], rateLimiter) + } + } + } + if len(rateLimiters) > 0 { + rateLimitersList, err := dataprovider.NewIPList(dataprovider.IPListTypeRateLimiterSafeList) + if err != nil { + return fmt.Errorf("unable to initialize ratelimiters list: %w", err) + } + Config.rateLimitersList = rateLimitersList + } + if c.DefenderConfig.Enabled { + if !slices.Contains(supportedDefenderDrivers, c.DefenderConfig.Driver) { + return fmt.Errorf("unsupported defender driver %q", c.DefenderConfig.Driver) + } + var defender Defender + var err error + switch c.DefenderConfig.Driver { + case DefenderDriverProvider: + defender, err = newDBDefender(&c.DefenderConfig) + default: + defender, err = newInMemoryDefender(&c.DefenderConfig) + } + if err != nil { + return fmt.Errorf("defender initialization error: %v", 
err) + } + logger.Info(logSender, "", "defender initialized with config %+v", c.DefenderConfig) + Config.defender = defender + } + if c.AllowListStatus > 0 { + allowList, err := dataprovider.NewIPList(dataprovider.IPListTypeAllowList) + if err != nil { + return fmt.Errorf("unable to initialize the allow list: %w", err) + } + logger.Info(logSender, "", "allow list initialized") + Config.allowList = allowList + } + if err := c.initializeProxyProtocol(); err != nil { + return err + } + if err := c.EventManager.validate(); err != nil { + return err + } + vfs.SetTempPath(c.TempPath) + dataprovider.SetTempPath(c.TempPath) + vfs.SetAllowSelfConnections(c.AllowSelfConnections) + vfs.SetRenameMode(c.RenameMode) + vfs.SetReadMetadataMode(c.Metadata.Read) + vfs.SetResumeMaxSize(c.ResumeMaxSize) + vfs.SetUploadMode(c.UploadMode) + dataprovider.SetAllowSelfConnections(c.AllowSelfConnections) + dataprovider.EnabledActionCommands = c.EventManager.EnabledCommands + transfersChecker = getTransfersChecker(isShared) + return nil +} + +// CheckClosing returns an error if the service is closing +func CheckClosing() error { + if isShuttingDown.Load() { + return ErrShuttingDown + } + return nil +} + +// WaitForTransfers waits, for the specified grace time, for currently ongoing +// client-initiated transfer sessions to completes. 
+// A zero graceTime means no wait +func WaitForTransfers(graceTime int) { + if graceTime == 0 { + return + } + if isShuttingDown.Swap(true) { + return + } + + if activeHooks.Load() == 0 && getActiveConnections() == 0 { + return + } + + graceTimer := time.NewTimer(time.Duration(graceTime) * time.Second) + ticker := time.NewTicker(3 * time.Second) + + for { + select { + case <-ticker.C: + hooks := activeHooks.Load() + logger.Info(logSender, "", "active hooks: %d", hooks) + if hooks == 0 && getActiveConnections() == 0 { + logger.Info(logSender, "", "no more active connections, graceful shutdown") + ticker.Stop() + graceTimer.Stop() + return + } + case <-graceTimer.C: + logger.Info(logSender, "", "grace time expired, hard shutdown") + ticker.Stop() + return + } + } +} + +// getActiveConnections returns the number of connections with active transfers +func getActiveConnections() int { + var activeConns int + + Connections.RLock() + for _, c := range Connections.connections { + if len(c.GetTransfers()) > 0 { + activeConns++ + } + } + Connections.RUnlock() + + logger.Info(logSender, "", "number of connections with active transfers: %d", activeConns) + return activeConns +} + +// LimitRate blocks until all the configured rate limiters +// allow one event to happen. 
+// It returns an error if the time to wait exceeds the max +// allowed delay +func LimitRate(protocol, ip string) (time.Duration, error) { + if Config.rateLimitersList != nil { + isListed, _, err := Config.rateLimitersList.IsListed(ip, protocol) + if err == nil && isListed { + return 0, nil + } + } + for _, limiter := range rateLimiters[protocol] { + if delay, err := limiter.Wait(ip, protocol); err != nil { + logger.Debug(logSender, "", "protocol %s ip %s: %v", protocol, ip, err) + return delay, err + } + } + return 0, nil +} + +// Reload reloads the whitelist, the IP filter plugin and the defender's block and safe lists +func Reload() error { + plugin.Handler.ReloadFilter() + return nil +} + +// DelayLogin applies the configured login delay +func DelayLogin(err error) { + if Config.defender != nil { + Config.defender.DelayLogin(err) + } +} + +// IsBanned returns true if the specified IP address is banned +func IsBanned(ip, protocol string) bool { + if plugin.Handler.IsIPBanned(ip, protocol) { + return true + } + if Config.defender == nil { + return false + } + + return Config.defender.IsBanned(ip, protocol) +} + +// GetDefenderBanTime returns the ban time for the given IP +// or nil if the IP is not banned or the defender is disabled +func GetDefenderBanTime(ip string) (*time.Time, error) { + if Config.defender == nil { + return nil, nil + } + + return Config.defender.GetBanTime(ip) +} + +// GetDefenderHosts returns hosts that are banned or for which some violations have been detected +func GetDefenderHosts() ([]dataprovider.DefenderEntry, error) { + if Config.defender == nil { + return nil, nil + } + + return Config.defender.GetHosts() +} + +// GetDefenderHost returns a defender host by ip, if any +func GetDefenderHost(ip string) (dataprovider.DefenderEntry, error) { + if Config.defender == nil { + return dataprovider.DefenderEntry{}, errors.New("defender is disabled") + } + + return Config.defender.GetHost(ip) +} + +// DeleteDefenderHost removes the specified 
IP address from the defender lists +func DeleteDefenderHost(ip string) bool { + if Config.defender == nil { + return false + } + + return Config.defender.DeleteHost(ip) +} + +// GetDefenderScore returns the score for the given IP +func GetDefenderScore(ip string) (int, error) { + if Config.defender == nil { + return 0, nil + } + + return Config.defender.GetScore(ip) +} + +// AddDefenderEvent adds the specified defender event for the given IP. +// Returns true if the IP is in the defender's safe list. +func AddDefenderEvent(ip, protocol string, event HostEvent) bool { + if Config.defender == nil { + return false + } + + return Config.defender.AddEvent(ip, protocol, event) +} + +func reloadProviderConfigs() { + configs, err := dataprovider.GetConfigs() + if err != nil { + logger.Error(logSender, "", "unable to load config from provider: %v", err) + return + } + configs.SetNilsToEmpty() + if fnUpdateBranding != nil { + fnUpdateBranding(configs.Branding) + } + if err := configs.SMTP.TryDecrypt(); err != nil { + logger.Error(logSender, "", "unable to decrypt smtp config: %v", err) + return + } + smtp.Activate(configs.SMTP) +} + +func startPeriodicChecks(duration time.Duration, isShared int) { + startEventScheduler() + spec := fmt.Sprintf("@every %s", duration) + _, err := eventScheduler.AddFunc(spec, Connections.checkTransfers) + util.PanicOnError(err) + logger.Info(logSender, "", "scheduled overquota transfers check, schedule %q", spec) + if isShared == 1 { + logger.Info(logSender, "", "add reload configs task") + _, err := eventScheduler.AddFunc("@every 10m", reloadProviderConfigs) + util.PanicOnError(err) + } + if Config.IdleTimeout > 0 { + ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval + spec = fmt.Sprintf("@every %s", duration*ratio) + _, err = eventScheduler.AddFunc(spec, Connections.checkIdles) + util.PanicOnError(err) + logger.Info(logSender, "", "scheduled idle connections check, schedule %q", spec) + } +} + +// ActiveTransfer defines the 
interface for the current active transfers +type ActiveTransfer interface { + GetID() int64 + GetType() int + GetSize() int64 + GetDownloadedSize() int64 + GetUploadedSize() int64 + GetVirtualPath() string + GetFsPath() string + GetStartTime() time.Time + SignalClose(err error) + Truncate(fsPath string, size int64) (int64, error) + GetRealFsPath(fsPath string) string + SetTimes(fsPath string, atime time.Time, mtime time.Time) bool + GetTruncatedSize() int64 + HasSizeLimit() bool +} + +// ActiveConnection defines the interface for the current active connections +type ActiveConnection interface { + GetID() string + GetUsername() string + GetRole() string + GetMaxSessions() int + GetLocalAddress() string + GetRemoteAddress() string + GetClientVersion() string + GetProtocol() string + GetConnectionTime() time.Time + GetLastActivity() time.Time + GetCommand() string + Disconnect() error + AddTransfer(t ActiveTransfer) + RemoveTransfer(t ActiveTransfer) + GetTransfers() []ConnectionTransfer + SignalTransferClose(transferID int64, err error) + CloseFS() error + isAccessAllowed() bool +} + +// StatAttributes defines the attributes for set stat commands +type StatAttributes struct { + Mode os.FileMode + Atime time.Time + Mtime time.Time + UID int + GID int + Flags int + Size int64 +} + +// ConnectionTransfer defines the trasfer details +type ConnectionTransfer struct { + ID int64 `json:"-"` + OperationType string `json:"operation_type"` + StartTime int64 `json:"start_time"` + Size int64 `json:"size"` + VirtualPath string `json:"path"` + HasSizeLimit bool `json:"-"` + ULSize int64 `json:"-"` + DLSize int64 `json:"-"` +} + +// EventManagerConfig defines the configuration for the EventManager +type EventManagerConfig struct { + // EnabledCommands defines the system commands that can be executed via EventManager, + // an empty list means that any command is allowed to be executed. 
+ // Commands must be set as an absolute path + EnabledCommands []string `json:"enabled_commands" mapstructure:"enabled_commands"` +} + +func (c *EventManagerConfig) validate() error { + for _, c := range c.EnabledCommands { + if !filepath.IsAbs(c) { + return fmt.Errorf("invalid command %q: it must be an absolute path", c) + } + } + return nil +} + +// MetadataConfig defines how to handle metadata for cloud storage backends +type MetadataConfig struct { + // If not zero the metadata will be read before downloads and will be + // available in notifications + Read int `json:"read" mapstructure:"read"` +} + +// Configuration defines configuration parameters common to all supported protocols +type Configuration struct { + // Maximum idle timeout as minutes. If a client is idle for a time that exceeds this setting it will be disconnected. + // 0 means disabled + IdleTimeout int `json:"idle_timeout" mapstructure:"idle_timeout"` + // UploadMode 0 means standard, the files are uploaded directly to the requested path. + // 1 means atomic: the files are uploaded to a temporary path and renamed to the requested path + // when the client ends the upload. Atomic mode avoid problems such as a web server that + // serves partial files when the files are being uploaded. + // In atomic mode if there is an upload error the temporary file is deleted and so the requested + // upload path will not contain a partial file. + // 2 means atomic with resume support: as atomic but if there is an upload error the temporary + // file is renamed to the requested path and not deleted, this way a client can reconnect and resume + // the upload. + // 4 means files for S3 backend are stored even if a client-side upload error is detected. + // 8 means files for Google Cloud Storage backend are stored even if a client-side upload error is detected. + // 16 means files for Azure Blob backend are stored even if a client-side upload error is detected. 
+ UploadMode int `json:"upload_mode" mapstructure:"upload_mode"` + // Actions to execute for SFTP file operations and SSH commands + Actions ProtocolActions `json:"actions" mapstructure:"actions"` + // SetstatMode 0 means "normal mode": requests for changing permissions and owner/group are executed. + // 1 means "ignore mode": requests for changing permissions and owner/group are silently ignored. + // 2 means "ignore mode for cloud fs": requests for changing permissions and owner/group are + // silently ignored for cloud based filesystem such as S3, GCS, Azure Blob. Requests for changing + // modification times are ignored for cloud based filesystem if they are not supported. + SetstatMode int `json:"setstat_mode" mapstructure:"setstat_mode"` + // RenameMode defines how to handle directory renames. By default, renaming of non-empty directories + // is not allowed for cloud storage providers (S3, GCS, Azure Blob). Set to 1 to enable recursive + // renames for these providers, they may be slow, there is no atomic rename API like for local + // filesystem, so SFTPGo will recursively list the directory contents and do a rename for each entry + RenameMode int `json:"rename_mode" mapstructure:"rename_mode"` + // ResumeMaxSize defines the maximum size allowed, in bytes, to resume uploads on storage backends + // with immutable objects. By default, resuming uploads is not allowed for cloud storage providers + // (S3, GCS, Azure Blob) because SFTPGo must rewrite the entire file. + // Set to a value greater than 0 to allow resuming uploads of files smaller than or equal to the + // defined size. + ResumeMaxSize int64 `json:"resume_max_size" mapstructure:"resume_max_size"` + // TempPath defines the path for temporary files such as those used for atomic uploads or file pipes. 
+ // If you set this option you must make sure that the defined path exists, is accessible for writing + // by the user running SFTPGo, and is on the same filesystem as the users home directories otherwise + // the renaming for atomic uploads will become a copy and therefore may take a long time. + // The temporary files are not namespaced. The default is generally fine. Leave empty for the default. + TempPath string `json:"temp_path" mapstructure:"temp_path"` + // Support for HAProxy PROXY protocol. + // If you are running SFTPGo behind a proxy server such as HAProxy, AWS ELB or NGNIX, you can enable + // the proxy protocol. It provides a convenient way to safely transport connection information + // such as a client's address across multiple layers of NAT or TCP proxies to get the real + // client IP address instead of the proxy IP. Both protocol versions 1 and 2 are supported. + // - 0 means disabled + // - 1 means proxy protocol enabled. Proxy header will be used and requests without proxy header will be accepted. + // - 2 means proxy protocol required. Proxy header will be used and requests without proxy header will be rejected. + // If the proxy protocol is enabled in SFTPGo then you have to enable the protocol in your proxy configuration too, + // for example for HAProxy add "send-proxy" or "send-proxy-v2" to each server configuration line. + ProxyProtocol int `json:"proxy_protocol" mapstructure:"proxy_protocol"` + // List of IP addresses and IP ranges allowed to send the proxy header. + // If proxy protocol is set to 1 and we receive a proxy header from an IP that is not in the list then the + // connection will be accepted and the header will be ignored. + // If proxy protocol is set to 2 and we receive a proxy header from an IP that is not in the list then the + // connection will be rejected. 
+ ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"` + // List of IP addresses and IP ranges for which not to read the proxy header + ProxySkipped []string `json:"proxy_skipped" mapstructure:"proxy_skipped"` + // Absolute path to an external program or an HTTP URL to invoke as soon as SFTPGo starts. + // If you define an HTTP URL it will be invoked using a `GET` request. + // Please note that SFTPGo services may not yet be available when this hook is run. + // Leave empty do disable. + StartupHook string `json:"startup_hook" mapstructure:"startup_hook"` + // Absolute path to an external program or an HTTP URL to invoke after a user connects + // and before he tries to login. It allows you to reject the connection based on the source + // ip address. Leave empty do disable. + PostConnectHook string `json:"post_connect_hook" mapstructure:"post_connect_hook"` + // Absolute path to an external program or an HTTP URL to invoke after an SSH/FTP connection ends. + // Leave empty do disable. + PostDisconnectHook string `json:"post_disconnect_hook" mapstructure:"post_disconnect_hook"` + // Maximum number of concurrent client connections. 0 means unlimited + MaxTotalConnections int `json:"max_total_connections" mapstructure:"max_total_connections"` + // Maximum number of concurrent client connections from the same host (IP). 0 means unlimited + MaxPerHostConnections int `json:"max_per_host_connections" mapstructure:"max_per_host_connections"` + // Defines the status of the global allow list. 0 means disabled, 1 enabled. + // If enabled, only the listed IPs/networks can access the configured services, all other + // client connections will be dropped before they even try to authenticate. 
+ // Ensure to enable this setting only after adding some allowed ip/networks from the WebAdmin/REST API + AllowListStatus int `json:"allowlist_status" mapstructure:"allowlist_status"` + // Allow users on this instance to use other users/virtual folders on this instance as storage backend. + // Enable this setting if you know what you are doing. + AllowSelfConnections int `json:"allow_self_connections" mapstructure:"allow_self_connections"` + // Defender configuration + DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"` + // Rate limiter configurations + RateLimitersConfig []RateLimiterConfig `json:"rate_limiters" mapstructure:"rate_limiters"` + // Umask for new uploads. Leave blank to use the system default. + Umask string `json:"umask" mapstructure:"umask"` + // Defines the server version + ServerVersion string `json:"server_version" mapstructure:"server_version"` + // TZ defines the time zone to use for the EventManager scheduler and to + // control time-based access restrictions. Set to "local" to use the + // server's local time, otherwise UTC will be used. 
+ TZ string `json:"tz" mapstructure:"tz"` + // Metadata configuration + Metadata MetadataConfig `json:"metadata" mapstructure:"metadata"` + // EventManager configuration + EventManager EventManagerConfig `json:"event_manager" mapstructure:"event_manager"` + idleTimeoutAsDuration time.Duration + idleLoginTimeout time.Duration + defender Defender + allowList *dataprovider.IPList + rateLimitersList *dataprovider.IPList + proxyAllowed []func(net.IP) bool + proxySkipped []func(net.IP) bool +} + +// IsAtomicUploadEnabled returns true if atomic upload is enabled +func (c *Configuration) IsAtomicUploadEnabled() bool { + return c.UploadMode&UploadModeAtomic != 0 || c.UploadMode&UploadModeAtomicWithResume != 0 +} + +func (c *Configuration) initializeProxyProtocol() error { + if c.ProxyProtocol > 0 { + allowed, err := util.ParseAllowedIPAndRanges(c.ProxyAllowed) + if err != nil { + return fmt.Errorf("invalid proxy allowed: %w", err) + } + skipped, err := util.ParseAllowedIPAndRanges(c.ProxySkipped) + if err != nil { + return fmt.Errorf("invalid proxy skipped: %w", err) + } + Config.proxyAllowed = allowed + Config.proxySkipped = skipped + } + return nil +} + +// GetProxyListener returns a wrapper for the given listener that supports the +// HAProxy Proxy Protocol +func (c *Configuration) GetProxyListener(listener net.Listener) (net.Listener, error) { + if c.ProxyProtocol > 0 { + defaultPolicy := proxyproto.REQUIRE + if c.ProxyProtocol == 1 { + defaultPolicy = proxyproto.IGNORE + } + + return &proxyproto.Listener{ + Listener: listener, + ConnPolicy: getProxyPolicy(c.proxyAllowed, c.proxySkipped, defaultPolicy), + ReadHeaderTimeout: 10 * time.Second, + }, nil + } + return nil, errors.New("proxy protocol not configured") +} + +// GetRateLimitersStatus returns the rate limiters status +func (c *Configuration) GetRateLimitersStatus() (bool, []string) { + enabled := false + var protocols []string + for _, rlCfg := range c.RateLimitersConfig { + if rlCfg.isEnabled() { + enabled = 
true + protocols = append(protocols, rlCfg.Protocols...) + } + } + return enabled, util.RemoveDuplicates(protocols, false) +} + +// IsAllowListEnabled returns true if the global allow list is enabled +func (c *Configuration) IsAllowListEnabled() bool { + return c.AllowListStatus > 0 +} + +// ExecuteStartupHook runs the startup hook if defined +func (c *Configuration) ExecuteStartupHook() error { + if c.StartupHook == "" { + return nil + } + if strings.HasPrefix(c.StartupHook, "http") { + var url *url.URL + url, err := url.Parse(c.StartupHook) + if err != nil { + logger.Warn(logSender, "", "Invalid startup hook %q: %v", c.StartupHook, err) + return err + } + startTime := time.Now() + resp, err := httpclient.RetryableGet(url.String()) + if err != nil { + logger.Warn(logSender, "", "Error executing startup hook: %v", err) + return err + } + defer resp.Body.Close() + logger.Debug(logSender, "", "Startup hook executed, elapsed: %v, response code: %v", time.Since(startTime), resp.StatusCode) + return nil + } + if !filepath.IsAbs(c.StartupHook) { + err := fmt.Errorf("invalid startup hook %q", c.StartupHook) + logger.Warn(logSender, "", "Invalid startup hook %q", c.StartupHook) + return err + } + startTime := time.Now() + timeout, env, args := command.GetConfig(c.StartupHook, command.HookStartup) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, c.StartupHook, args...) 
+ cmd.Env = env + err := cmd.Run() + logger.Debug(logSender, "", "Startup hook executed, elapsed: %s, error: %v", time.Since(startTime), err) + return nil +} + +func (c *Configuration) executePostDisconnectHook(remoteAddr, protocol, username, connID string, connectionTime time.Time) { + startNewHook() + defer hookEnded() + + ipAddr := util.GetIPFromRemoteAddress(remoteAddr) + connDuration := int64(time.Since(connectionTime) / time.Millisecond) + + if strings.HasPrefix(c.PostDisconnectHook, "http") { + var url *url.URL + url, err := url.Parse(c.PostDisconnectHook) + if err != nil { + logger.Warn(protocol, connID, "Invalid post disconnect hook %q: %v", c.PostDisconnectHook, err) + return + } + q := url.Query() + q.Add("ip", ipAddr) + q.Add("protocol", protocol) + q.Add("username", username) + q.Add("connection_duration", strconv.FormatInt(connDuration, 10)) + url.RawQuery = q.Encode() + startTime := time.Now() + resp, err := httpclient.RetryableGet(url.String()) + respCode := 0 + if err == nil { + respCode = resp.StatusCode + resp.Body.Close() + } + logger.Debug(protocol, connID, "Post disconnect hook response code: %v, elapsed: %v, err: %v", + respCode, time.Since(startTime), err) + return + } + if !filepath.IsAbs(c.PostDisconnectHook) { + logger.Debug(protocol, connID, "invalid post disconnect hook %q", c.PostDisconnectHook) + return + } + timeout, env, args := command.GetConfig(c.PostDisconnectHook, command.HookPostDisconnect) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + startTime := time.Now() + cmd := exec.CommandContext(ctx, c.PostDisconnectHook, args...) 
+ cmd.Env = append(env, + fmt.Sprintf("SFTPGO_CONNECTION_IP=%s", ipAddr), + fmt.Sprintf("SFTPGO_CONNECTION_USERNAME=%s", username), + fmt.Sprintf("SFTPGO_CONNECTION_DURATION=%d", connDuration), + fmt.Sprintf("SFTPGO_CONNECTION_PROTOCOL=%s", protocol)) + err := cmd.Run() + logger.Debug(protocol, connID, "Post disconnect hook executed, elapsed: %s error: %v", time.Since(startTime), err) +} + +func (c *Configuration) checkPostDisconnectHook(remoteAddr, protocol, username, connID string, connectionTime time.Time) { + if c.PostDisconnectHook == "" { + return + } + if !slices.Contains(disconnHookProtocols, protocol) { + return + } + go c.executePostDisconnectHook(remoteAddr, protocol, username, connID, connectionTime) +} + +// ExecutePostConnectHook executes the post connect hook if defined +func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error { + if c.PostConnectHook == "" { + return nil + } + if strings.HasPrefix(c.PostConnectHook, "http") { + var url *url.URL + url, err := url.Parse(c.PostConnectHook) + if err != nil { + logger.Warn(protocol, "", "Login from ip %q denied, invalid post connect hook %q: %v", + ipAddr, c.PostConnectHook, err) + return getPermissionDeniedError(protocol) + } + q := url.Query() + q.Add("ip", ipAddr) + q.Add("protocol", protocol) + url.RawQuery = q.Encode() + + resp, err := httpclient.RetryableGet(url.String()) + if err != nil { + logger.Warn(protocol, "", "Login from ip %q denied, error executing post connect hook: %v", ipAddr, err) + return getPermissionDeniedError(protocol) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + logger.Warn(protocol, "", "Login from ip %q denied, post connect hook response code: %v", ipAddr, resp.StatusCode) + return getPermissionDeniedError(protocol) + } + return nil + } + if !filepath.IsAbs(c.PostConnectHook) { + err := fmt.Errorf("invalid post connect hook %q", c.PostConnectHook) + logger.Warn(protocol, "", "Login from ip %q denied: %v", ipAddr, err) + return 
getPermissionDeniedError(protocol) + } + timeout, env, args := command.GetConfig(c.PostConnectHook, command.HookPostConnect) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, c.PostConnectHook, args...) + cmd.Env = append(env, + fmt.Sprintf("SFTPGO_CONNECTION_IP=%s", ipAddr), + fmt.Sprintf("SFTPGO_CONNECTION_PROTOCOL=%s", protocol)) + err := cmd.Run() + if err != nil { + logger.Warn(protocol, "", "Login from ip %q denied, connect hook error: %v", ipAddr, err) + return getPermissionDeniedError(protocol) + } + return nil +} + +func getProxyPolicy(allowed, skipped []func(net.IP) bool, def proxyproto.Policy) proxyproto.ConnPolicyFunc { + return func(connPolicyOptions proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) { + upstreamIP, err := util.GetIPFromNetAddr(connPolicyOptions.Upstream) + if err != nil { + // Something is wrong with the source IP, better reject the + // connection. + logger.Error(logSender, "", "reject connection from ip %q, err: %v", connPolicyOptions.Upstream, err) + return proxyproto.REJECT, proxyproto.ErrInvalidUpstream + } + + for _, skippedFrom := range skipped { + if skippedFrom(upstreamIP) { + return proxyproto.SKIP, nil + } + } + + for _, allowFrom := range allowed { + if allowFrom(upstreamIP) { + if def == proxyproto.REQUIRE { + return proxyproto.REQUIRE, nil + } + return proxyproto.USE, nil + } + } + + if def == proxyproto.REQUIRE { + logger.Debug(logSender, "", "reject connection from ip %q: proxy protocol signature required and not set", + upstreamIP) + return proxyproto.REJECT, proxyproto.ErrInvalidUpstream + } + return def, nil + } +} + +// SSHConnection defines an ssh connection. 
+// Each SSH connection can open several channels for SFTP or SSH commands +type SSHConnection struct { + id string + conn io.Closer + lastActivity atomic.Int64 +} + +// NewSSHConnection returns a new SSHConnection +func NewSSHConnection(id string, conn io.Closer) *SSHConnection { + c := &SSHConnection{ + id: id, + conn: conn, + } + c.lastActivity.Store(time.Now().UnixNano()) + return c +} + +// GetID returns the ID for this SSHConnection +func (c *SSHConnection) GetID() string { + return c.id +} + +// UpdateLastActivity updates last activity for this connection +func (c *SSHConnection) UpdateLastActivity() { + c.lastActivity.Store(time.Now().UnixNano()) +} + +// GetLastActivity returns the last connection activity +func (c *SSHConnection) GetLastActivity() time.Time { + return time.Unix(0, c.lastActivity.Load()) +} + +// Close closes the underlying network connection +func (c *SSHConnection) Close() error { + return c.conn.Close() +} + +// ActiveConnections holds the current active connections with the associated transfers +type ActiveConnections struct { + // clients contains both authenticated and established connections and the ones waiting + // for authentication + clients clientsMap + // transfers contains active transfers, total and per-user + transfers clientsMap + transfersCheckStatus atomic.Bool + sync.RWMutex + connections []ActiveConnection + mapping map[string]int + sshConnections []*SSHConnection + sshMapping map[string]int + perUserConns map[string]int +} + +// internal method, must be called within a locked block +func (conns *ActiveConnections) addUserConnection(username string) { + if username == "" { + return + } + conns.perUserConns[username]++ +} + +// internal method, must be called within a locked block +func (conns *ActiveConnections) removeUserConnection(username string) { + if username == "" { + return + } + if val, ok := conns.perUserConns[username]; ok { + conns.perUserConns[username]-- + if val > 1 { + return + } + 
delete(conns.perUserConns, username) + } +} + +// GetActiveSessions returns the number of active sessions for the given username. +// We return the open sessions for any protocol +func (conns *ActiveConnections) GetActiveSessions(username string) int { + conns.RLock() + defer conns.RUnlock() + + return conns.perUserConns[username] +} + +// Add adds a new connection to the active ones +func (conns *ActiveConnections) Add(c ActiveConnection) error { + conns.Lock() + defer conns.Unlock() + + if username := c.GetUsername(); username != "" { + if maxSessions := c.GetMaxSessions(); maxSessions > 0 { + if val := conns.perUserConns[username]; val >= maxSessions { + return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions) + } + if val := conns.transfers.getTotalFrom(username); val >= maxSessions { + return fmt.Errorf("too many open transfers: %d/%d", val, maxSessions) + } + } + conns.addUserConnection(username) + } + conns.mapping[c.GetID()] = len(conns.connections) + conns.connections = append(conns.connections, c) + metric.UpdateActiveConnectionsSize(len(conns.connections)) + logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %q, remote address %q, num open connections: %d", + c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections)) + return nil +} + +// Swap replaces an existing connection with the given one. 
+// This method is useful if you have to change some connection details +// for example for FTP is used to update the connection once the user +// authenticates +func (conns *ActiveConnections) Swap(c ActiveConnection) error { + conns.Lock() + defer conns.Unlock() + + if idx, ok := conns.mapping[c.GetID()]; ok { + conn := conns.connections[idx] + conns.removeUserConnection(conn.GetUsername()) + if username := c.GetUsername(); username != "" { + if maxSessions := c.GetMaxSessions(); maxSessions > 0 { + if val, ok := conns.perUserConns[username]; ok && val >= maxSessions { + conns.addUserConnection(conn.GetUsername()) + return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions) + } + } + conns.addUserConnection(username) + } + err := conn.CloseFS() + conns.connections[idx] = c + logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err) + conn = nil + return nil + } + + return errors.New("connection to swap not found") +} + +// Remove removes a connection from the active ones +func (conns *ActiveConnections) Remove(connectionID string) { + conns.Lock() + defer conns.Unlock() + + if idx, ok := conns.mapping[connectionID]; ok { + conn := conns.connections[idx] + err := conn.CloseFS() + lastIdx := len(conns.connections) - 1 + conns.connections[idx] = conns.connections[lastIdx] + conns.connections[lastIdx] = nil + conns.connections = conns.connections[:lastIdx] + delete(conns.mapping, connectionID) + if idx != lastIdx { + conns.mapping[conns.connections[idx].GetID()] = idx + } + conns.removeUserConnection(conn.GetUsername()) + metric.UpdateActiveConnectionsSize(lastIdx) + logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %q, remote address %q close fs error: %v, num open connections: %d", + conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx) + if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" && !slices.Contains(ftpLoginCommands, conn.GetCommand()) { + ip := 
util.GetIPFromRemoteAddress(conn.GetRemoteAddress()) + logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTried, ProtocolFTP, + dataprovider.ErrNoAuthTried.Error()) + metric.AddNoAuthTried() + AddDefenderEvent(ip, ProtocolFTP, HostEventNoLoginTried) + dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTried, ip, + ProtocolFTP, dataprovider.ErrNoAuthTried) + plugin.Handler.NotifyLogEvent(notifier.LogEventTypeNoLoginTried, ProtocolFTP, "", ip, "", + dataprovider.ErrNoAuthTried) + } + Config.checkPostDisconnectHook(conn.GetRemoteAddress(), conn.GetProtocol(), conn.GetUsername(), + conn.GetID(), conn.GetConnectionTime()) + return + } + + logger.Debug(logSender, "", "connection id %q to remove not found!", connectionID) +} + +// Close closes an active connection. +// It returns true on success +func (conns *ActiveConnections) Close(connectionID, role string) bool { + conns.RLock() + + var result bool + + if idx, ok := conns.mapping[connectionID]; ok { + c := conns.connections[idx] + + if role == "" || c.GetRole() == role { + defer func(conn ActiveConnection) { + err := conn.Disconnect() + logger.Debug(conn.GetProtocol(), conn.GetID(), "close connection requested, close err: %v", err) + }(c) + result = true + } + } + + conns.RUnlock() + return result +} + +// AddSSHConnection adds a new ssh connection to the active ones +func (conns *ActiveConnections) AddSSHConnection(c *SSHConnection) { + conns.Lock() + defer conns.Unlock() + + conns.sshMapping[c.GetID()] = len(conns.sshConnections) + conns.sshConnections = append(conns.sshConnections, c) + logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %d", len(conns.sshConnections)) +} + +// RemoveSSHConnection removes a connection from the active ones +func (conns *ActiveConnections) RemoveSSHConnection(connectionID string) { + conns.Lock() + defer conns.Unlock() + + if idx, ok := conns.sshMapping[connectionID]; ok { + lastIdx := 
len(conns.sshConnections) - 1 + conns.sshConnections[idx] = conns.sshConnections[lastIdx] + conns.sshConnections[lastIdx] = nil + conns.sshConnections = conns.sshConnections[:lastIdx] + delete(conns.sshMapping, connectionID) + if idx != lastIdx { + conns.sshMapping[conns.sshConnections[idx].GetID()] = idx + } + logger.Debug(logSender, connectionID, "ssh connection removed, num open ssh connections: %d", lastIdx) + return + } + logger.Warn(logSender, "", "ssh connection to remove with id %q not found!", connectionID) +} + +func (conns *ActiveConnections) checkIdles() { + conns.RLock() + + for _, sshConn := range conns.sshConnections { + idleTime := time.Since(sshConn.GetLastActivity()) + if idleTime > Config.idleTimeoutAsDuration { + // we close an SSH connection if it has no active connections associated + idToMatch := fmt.Sprintf("_%s_", sshConn.GetID()) + toClose := true + for _, conn := range conns.connections { + if strings.Contains(conn.GetID(), idToMatch) { + if time.Since(conn.GetLastActivity()) <= Config.idleTimeoutAsDuration { + toClose = false + break + } + } + } + if toClose { + defer func(c *SSHConnection) { + err := c.Close() + logger.Debug(logSender, c.GetID(), "close idle SSH connection, idle time: %v, close err: %v", + time.Since(c.GetLastActivity()), err) + }(sshConn) + } + } + } + + for _, c := range conns.connections { + idleTime := time.Since(c.GetLastActivity()) + isUnauthenticatedFTPUser := (c.GetProtocol() == ProtocolFTP && c.GetUsername() == "") + + if idleTime > Config.idleTimeoutAsDuration || (isUnauthenticatedFTPUser && idleTime > Config.idleLoginTimeout) { + defer func(conn ActiveConnection) { + err := conn.Disconnect() + logger.Debug(conn.GetProtocol(), conn.GetID(), "close idle connection, idle time: %s, username: %q close err: %v", + time.Since(conn.GetLastActivity()), conn.GetUsername(), err) + }(c) + } else if !isUnauthenticatedFTPUser && !c.isAccessAllowed() { + defer func(conn ActiveConnection) { + err := conn.Disconnect() + 
logger.Info(conn.GetProtocol(), conn.GetID(), "access conditions not met for user: %q close connection err: %v", + conn.GetUsername(), err) + }(c) + } + } + + conns.RUnlock() +} + +func (conns *ActiveConnections) checkTransfers() { + if conns.transfersCheckStatus.Load() { + logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution") + return + } + conns.transfersCheckStatus.Store(true) + defer conns.transfersCheckStatus.Store(false) + + conns.RLock() + + if len(conns.connections) < 2 { + conns.RUnlock() + return + } + var wg sync.WaitGroup + logger.Debug(logSender, "", "start concurrent transfers check") + + // update the current size for transfers to monitors + for _, c := range conns.connections { + for _, t := range c.GetTransfers() { + if t.HasSizeLimit { + wg.Add(1) + + go func(transfer ConnectionTransfer, connID string) { + defer wg.Done() + transfersChecker.UpdateTransferCurrentSizes(transfer.ULSize, transfer.DLSize, transfer.ID, connID) + }(t, c.GetID()) + } + } + } + + conns.RUnlock() + logger.Debug(logSender, "", "waiting for the update of the transfers current size") + wg.Wait() + + logger.Debug(logSender, "", "getting overquota transfers") + overquotaTransfers := transfersChecker.GetOverquotaTransfers() + logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers)) + if len(overquotaTransfers) == 0 { + return + } + + conns.RLock() + defer conns.RUnlock() + + for _, c := range conns.connections { + for _, overquotaTransfer := range overquotaTransfers { + if c.GetID() == overquotaTransfer.ConnID { + logger.Info(logSender, c.GetID(), "user %q is overquota, try to close transfer id %v", + c.GetUsername(), overquotaTransfer.TransferID) + var err error + if overquotaTransfer.TransferType == TransferDownload { + err = getReadQuotaExceededError(c.GetProtocol()) + } else { + err = getQuotaExceededError(c.GetProtocol()) + } + c.SignalTransferClose(overquotaTransfer.TransferID, err) + } + } + } + 
logger.Debug(logSender, "", "transfers check completed") +} + +// AddClientConnection stores a new client connection +func (conns *ActiveConnections) AddClientConnection(ipAddr string) { + conns.clients.add(ipAddr) +} + +// RemoveClientConnection removes a disconnected client from the tracked ones +func (conns *ActiveConnections) RemoveClientConnection(ipAddr string) { + conns.clients.remove(ipAddr) +} + +// GetClientConnections returns the total number of client connections +func (conns *ActiveConnections) GetClientConnections() int32 { + return conns.clients.getTotal() +} + +// GetTotalTransfers returns the total number of active transfers +func (conns *ActiveConnections) GetTotalTransfers() int32 { + return conns.transfers.getTotal() +} + +// IsNewTransferAllowed returns an error if the maximum number of concurrent allowed +// transfers is exceeded +func (conns *ActiveConnections) IsNewTransferAllowed(username string) error { + if isShuttingDown.Load() { + return ErrShuttingDown + } + if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { + return nil + } + if Config.MaxPerHostConnections > 0 { + if transfers := conns.transfers.getTotalFrom(username); transfers >= Config.MaxPerHostConnections { + logger.Info(logSender, "", "active transfers from user %q: %d/%d", username, transfers, Config.MaxPerHostConnections) + return ErrConnectionDenied + } + } + if Config.MaxTotalConnections > 0 { + if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) { + logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections) + return ErrConnectionDenied + } + } + return nil +} + +// IsNewConnectionAllowed returns an error if the maximum number of concurrent allowed +// connections is exceeded or a whitelist is defined and the specified ipAddr is not listed +// or the service is shutting down +func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr, protocol string) error { + if 
isShuttingDown.Load() { + return ErrShuttingDown + } + if Config.allowList != nil { + isListed, _, err := Config.allowList.IsListed(ipAddr, protocol) + if err != nil { + logger.Error(logSender, "", "unable to query allow list, connection denied, ip %q, protocol %s, err: %v", + ipAddr, protocol, err) + return ErrConnectionDenied + } + if !isListed { + return ErrConnectionDenied + } + } + if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { + return nil + } + + if Config.MaxPerHostConnections > 0 { + if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections { + if !AddDefenderEvent(ipAddr, protocol, HostEventLimitExceeded) { + logger.Warn(logSender, "", "connection denied, active connections from IP %q: %d/%d", + ipAddr, total, Config.MaxPerHostConnections) + return ErrConnectionDenied + } + logger.Info(logSender, "", "active connections from safe IP %q: %d", ipAddr, total) + } + } + + if Config.MaxTotalConnections > 0 { + if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) { + logger.Info(logSender, "", "active client connections %d/%d", total, Config.MaxTotalConnections) + return ErrConnectionDenied + } + + // on a single SFTP connection we could have multiple SFTP channels or commands + // so we check the established connections and active uploads too + if transfers := conns.transfers.getTotal(); transfers >= int32(Config.MaxTotalConnections) { + logger.Info(logSender, "", "active transfers %d/%d", transfers, Config.MaxTotalConnections) + return ErrConnectionDenied + } + + conns.RLock() + defer conns.RUnlock() + + if sess := len(conns.connections); sess >= Config.MaxTotalConnections { + logger.Info(logSender, "", "active client sessions %d/%d", sess, Config.MaxTotalConnections) + return ErrConnectionDenied + } + } + + return nil +} + +// GetStats returns stats for active connections +func (conns *ActiveConnections) GetStats(role string) []ConnectionStatus { + conns.RLock() + defer 
conns.RUnlock() + + stats := make([]ConnectionStatus, 0, len(conns.connections)) + node := dataprovider.GetNodeName() + for _, c := range conns.connections { + if role == "" || c.GetRole() == role { + stat := ConnectionStatus{ + Username: c.GetUsername(), + ConnectionID: c.GetID(), + ClientVersion: c.GetClientVersion(), + RemoteAddress: c.GetRemoteAddress(), + ConnectionTime: util.GetTimeAsMsSinceEpoch(c.GetConnectionTime()), + LastActivity: util.GetTimeAsMsSinceEpoch(c.GetLastActivity()), + CurrentTime: util.GetTimeAsMsSinceEpoch(time.Now()), + Protocol: c.GetProtocol(), + Command: c.GetCommand(), + Transfers: c.GetTransfers(), + Node: node, + } + stats = append(stats, stat) + } + } + return stats +} + +// ConnectionStatus returns the status for an active connection +type ConnectionStatus struct { + // Logged in username + Username string `json:"username"` + // Unique identifier for the connection + ConnectionID string `json:"connection_id"` + // client's version string + ClientVersion string `json:"client_version,omitempty"` + // Remote address for this connection + RemoteAddress string `json:"remote_address"` + // Connection time as unix timestamp in milliseconds + ConnectionTime int64 `json:"connection_time"` + // Last activity as unix timestamp in milliseconds + LastActivity int64 `json:"last_activity"` + // Current time as unix timestamp in milliseconds + CurrentTime int64 `json:"current_time"` + // Protocol for this connection + Protocol string `json:"protocol"` + // active uploads/downloads + Transfers []ConnectionTransfer `json:"active_transfers,omitempty"` + // SSH command or WebDAV method + Command string `json:"command,omitempty"` + // Node identifier, omitted for single node installations + Node string `json:"node,omitempty"` +} + +// ActiveQuotaScan defines an active quota scan for a user +type ActiveQuotaScan struct { + // Username to which the quota scan refers + Username string `json:"username"` + // quota scan start time as unix timestamp in 
milliseconds + StartTime int64 `json:"start_time"` + Role string `json:"-"` +} + +// ActiveVirtualFolderQuotaScan defines an active quota scan for a virtual folder +type ActiveVirtualFolderQuotaScan struct { + // folder name to which the quota scan refers + Name string `json:"name"` + // quota scan start time as unix timestamp in milliseconds + StartTime int64 `json:"start_time"` +} + +// ActiveScans holds the active quota scans +type ActiveScans struct { + sync.RWMutex + UserScans []ActiveQuotaScan + FolderScans []ActiveVirtualFolderQuotaScan +} + +// GetUsersQuotaScans returns the active users quota scans +func (s *ActiveScans) GetUsersQuotaScans(role string) []ActiveQuotaScan { + s.RLock() + defer s.RUnlock() + + scans := make([]ActiveQuotaScan, 0, len(s.UserScans)) + for _, scan := range s.UserScans { + if role == "" || role == scan.Role { + scans = append(scans, ActiveQuotaScan{ + Username: scan.Username, + StartTime: scan.StartTime, + }) + } + } + + return scans +} + +// AddUserQuotaScan adds a user to the ones with active quota scans. +// Returns false if the user has a quota scan already running +func (s *ActiveScans) AddUserQuotaScan(username, role string) bool { + s.Lock() + defer s.Unlock() + + for _, scan := range s.UserScans { + if scan.Username == username { + return false + } + } + s.UserScans = append(s.UserScans, ActiveQuotaScan{ + Username: username, + StartTime: util.GetTimeAsMsSinceEpoch(time.Now()), + Role: role, + }) + return true +} + +// RemoveUserQuotaScan removes a user from the ones with active quota scans. 
+// Returns false if the user has no active quota scans +func (s *ActiveScans) RemoveUserQuotaScan(username string) bool { + s.Lock() + defer s.Unlock() + + for idx, scan := range s.UserScans { + if scan.Username == username { + lastIdx := len(s.UserScans) - 1 + s.UserScans[idx] = s.UserScans[lastIdx] + s.UserScans = s.UserScans[:lastIdx] + return true + } + } + + return false +} + +// GetVFoldersQuotaScans returns the active quota scans for virtual folders +func (s *ActiveScans) GetVFoldersQuotaScans() []ActiveVirtualFolderQuotaScan { + s.RLock() + defer s.RUnlock() + scans := make([]ActiveVirtualFolderQuotaScan, len(s.FolderScans)) + copy(scans, s.FolderScans) + return scans +} + +// AddVFolderQuotaScan adds a virtual folder to the ones with active quota scans. +// Returns false if the folder has a quota scan already running +func (s *ActiveScans) AddVFolderQuotaScan(folderName string) bool { + s.Lock() + defer s.Unlock() + + for _, scan := range s.FolderScans { + if scan.Name == folderName { + return false + } + } + s.FolderScans = append(s.FolderScans, ActiveVirtualFolderQuotaScan{ + Name: folderName, + StartTime: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + return true +} + +// RemoveVFolderQuotaScan removes a folder from the ones with active quota scans. 
+// Returns false if the folder has no active quota scans +func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool { + s.Lock() + defer s.Unlock() + + for idx, scan := range s.FolderScans { + if scan.Name == folderName { + lastIdx := len(s.FolderScans) - 1 + s.FolderScans[idx] = s.FolderScans[lastIdx] + s.FolderScans = s.FolderScans[:lastIdx] + return true + } + } + + return false +} diff --git a/internal/common/common_test.go b/internal/common/common_test.go new file mode 100644 index 00000000..8566d773 --- /dev/null +++ b/internal/common/common_test.go @@ -0,0 +1,1957 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package common + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "runtime" + "slices" + "sync" + "testing" + "time" + + "github.com/alexedwards/argon2id" + "github.com/pires/go-proxyproto" + "github.com/sftpgo/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + logSenderTest = "common_test" + httpAddr = "127.0.0.1:9999" + osWindows = "windows" + userTestUsername = "common_test_username" +) + +var ( + configDir = filepath.Join(".", "..", "..") +) + +type fakeConnection struct { + *BaseConnection + command string +} + +func (c *fakeConnection) AddUser(user dataprovider.User) error { + _, err := user.GetFilesystem(c.GetID()) + if err != nil { + return err + } + c.User = user + return nil +} + +func (c *fakeConnection) Disconnect() error { + Connections.Remove(c.GetID()) + return nil +} + +func (c *fakeConnection) GetClientVersion() string { + return "" +} + +func (c *fakeConnection) GetCommand() string { + return c.command +} + +func (c *fakeConnection) GetLocalAddress() string { + return "" +} + +func (c *fakeConnection) GetRemoteAddress() string { + return "" +} + +type customNetConn struct { + net.Conn + id string + isClosed bool +} + +func (c *customNetConn) Close() error { + Connections.RemoveSSHConnection(c.id) + c.isClosed = true + return c.Conn.Close() +} + +func TestConnections(t *testing.T) { + c1 := &fakeConnection{ + BaseConnection: NewBaseConnection("id1", ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername, + }, + }), + } + c2 := &fakeConnection{ + BaseConnection: NewBaseConnection("id2", 
ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername, + }, + }), + } + c3 := &fakeConnection{ + BaseConnection: NewBaseConnection("id3", ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername, + }, + }), + } + c4 := &fakeConnection{ + BaseConnection: NewBaseConnection("id4", ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername, + }, + }), + } + assert.Equal(t, "SFTP_id1", c1.GetID()) + assert.Equal(t, "SFTP_id2", c2.GetID()) + assert.Equal(t, "SFTP_id3", c3.GetID()) + assert.Equal(t, "SFTP_id4", c4.GetID()) + err := Connections.Add(c1) + assert.NoError(t, err) + err = Connections.Add(c2) + assert.NoError(t, err) + err = Connections.Add(c3) + assert.NoError(t, err) + err = Connections.Add(c4) + assert.NoError(t, err) + + Connections.RLock() + assert.Len(t, Connections.connections, 4) + assert.Len(t, Connections.mapping, 4) + _, ok := Connections.mapping[c1.GetID()] + assert.True(t, ok) + assert.Equal(t, 0, Connections.mapping[c1.GetID()]) + assert.Equal(t, 1, Connections.mapping[c2.GetID()]) + assert.Equal(t, 2, Connections.mapping[c3.GetID()]) + assert.Equal(t, 3, Connections.mapping[c4.GetID()]) + Connections.RUnlock() + + c2 = &fakeConnection{ + BaseConnection: NewBaseConnection("id2", ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername + "_mod", + }, + }), + } + err = Connections.Swap(c2) + assert.NoError(t, err) + + Connections.RLock() + assert.Len(t, Connections.connections, 4) + assert.Len(t, Connections.mapping, 4) + _, ok = Connections.mapping[c1.GetID()] + assert.True(t, ok) + assert.Equal(t, 0, Connections.mapping[c1.GetID()]) + assert.Equal(t, 1, Connections.mapping[c2.GetID()]) + assert.Equal(t, 2, Connections.mapping[c3.GetID()]) + assert.Equal(t, 3, Connections.mapping[c4.GetID()]) + assert.Equal(t, userTestUsername+"_mod", Connections.connections[1].GetUsername()) + 
Connections.RUnlock() + + Connections.Remove(c2.GetID()) + + Connections.RLock() + assert.Len(t, Connections.connections, 3) + assert.Len(t, Connections.mapping, 3) + _, ok = Connections.mapping[c1.GetID()] + assert.True(t, ok) + assert.Equal(t, 0, Connections.mapping[c1.GetID()]) + assert.Equal(t, 1, Connections.mapping[c4.GetID()]) + assert.Equal(t, 2, Connections.mapping[c3.GetID()]) + Connections.RUnlock() + + Connections.Remove(c3.GetID()) + + Connections.RLock() + assert.Len(t, Connections.connections, 2) + assert.Len(t, Connections.mapping, 2) + _, ok = Connections.mapping[c1.GetID()] + assert.True(t, ok) + assert.Equal(t, 0, Connections.mapping[c1.GetID()]) + assert.Equal(t, 1, Connections.mapping[c4.GetID()]) + Connections.RUnlock() + + Connections.Remove(c1.GetID()) + + Connections.RLock() + assert.Len(t, Connections.connections, 1) + assert.Len(t, Connections.mapping, 1) + _, ok = Connections.mapping[c4.GetID()] + assert.True(t, ok) + assert.Equal(t, 0, Connections.mapping[c4.GetID()]) + Connections.RUnlock() + + Connections.Remove(c4.GetID()) + + Connections.RLock() + assert.Len(t, Connections.connections, 0) + assert.Len(t, Connections.mapping, 0) + Connections.RUnlock() +} + +func TestEventManagerCommandsInitialization(t *testing.T) { + configCopy := Config + + c := Configuration{ + EventManager: EventManagerConfig{ + EnabledCommands: []string{"ls"}, // not an absolute path + }, + } + err := Initialize(c, 0) + assert.ErrorContains(t, err, "invalid command") + + var commands []string + if runtime.GOOS == osWindows { + commands = []string{"C:\\command"} + } else { + commands = []string{"/bin/ls"} + } + + c.EventManager.EnabledCommands = commands + err = Initialize(c, 0) + assert.NoError(t, err) + assert.Equal(t, commands, dataprovider.EnabledActionCommands) + + dataprovider.EnabledActionCommands = configCopy.EventManager.EnabledCommands + Config = configCopy +} + +func TestInitializationProxyErrors(t *testing.T) { + configCopy := Config + + c := 
Configuration{ + ProxyProtocol: 1, + ProxyAllowed: []string{"1.1.1.1111"}, + } + err := Initialize(c, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid proxy allowed") + } + c.ProxyAllowed = nil + c.ProxySkipped = []string{"invalid"} + err = Initialize(c, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid proxy skipped") + } + c.ProxyAllowed = []string{"1.1.1.1"} + c.ProxySkipped = []string{"2.2.2.2", "10.8.0.0/24"} + err = Initialize(c, 0) + assert.NoError(t, err) + assert.Len(t, Config.proxyAllowed, 1) + assert.Len(t, Config.proxySkipped, 2) + + Config = configCopy + assert.Equal(t, 0, Config.ProxyProtocol) + assert.Len(t, Config.proxyAllowed, 0) + assert.Len(t, Config.proxySkipped, 0) +} + +func TestInitializationClosedProvider(t *testing.T) { + configCopy := Config + + providerConf := dataprovider.GetProviderConfig() + err := dataprovider.Close() + assert.NoError(t, err) + + config := Configuration{ + AllowListStatus: 1, + } + err = Initialize(config, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to initialize the allow list") + } + + config.AllowListStatus = 0 + config.RateLimitersConfig = []RateLimiterConfig{ + { + Average: 100, + Period: 1000, + Burst: 5, + Type: int(rateLimiterTypeGlobal), + Protocols: rateLimiterProtocolValues, + }, + } + err = Initialize(config, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to initialize ratelimiters list") + } + + config.RateLimitersConfig = nil + config.DefenderConfig = DefenderConfig{ + Enabled: true, + Driver: DefenderDriverProvider, + BanTime: 10, + BanTimeIncrement: 50, + Threshold: 10, + ScoreInvalid: 2, + ScoreValid: 1, + ScoreNoAuth: 2, + ObservationTime: 15, + EntriesSoftLimit: 100, + EntriesHardLimit: 150, + } + err = Initialize(config, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "defender initialization error") + } + config.DefenderConfig.Driver = DefenderDriverMemory + err = 
Initialize(config, 0)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "defender initialization error")
	}

	// restore the provider so the other tests can use it
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	Config = configCopy
}

// TestSSHConnections checks add/remove bookkeeping for SSH connections and
// their index mapping.
func TestSSHConnections(t *testing.T) {
	conn1, conn2 := net.Pipe()
	now := time.Now()
	sshConn1 := NewSSHConnection("id1", conn1)
	sshConn2 := NewSSHConnection("id2", conn2)
	sshConn3 := NewSSHConnection("id3", conn2)
	assert.Equal(t, "id1", sshConn1.GetID())
	assert.Equal(t, "id2", sshConn2.GetID())
	assert.Equal(t, "id3", sshConn3.GetID())
	sshConn1.UpdateLastActivity()
	assert.GreaterOrEqual(t, sshConn1.GetLastActivity().UnixNano(), now.UnixNano())
	Connections.AddSSHConnection(sshConn1)
	Connections.AddSSHConnection(sshConn2)
	Connections.AddSSHConnection(sshConn3)
	Connections.RLock()
	assert.Len(t, Connections.sshConnections, 3)
	_, ok := Connections.sshMapping[sshConn1.GetID()]
	assert.True(t, ok)
	assert.Equal(t, 0, Connections.sshMapping[sshConn1.GetID()])
	assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()])
	assert.Equal(t, 2, Connections.sshMapping[sshConn3.GetID()])
	Connections.RUnlock()
	Connections.RemoveSSHConnection(sshConn1.id)
	Connections.RLock()
	assert.Len(t, Connections.sshConnections, 2)
	assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
	assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id)
	_, ok = Connections.sshMapping[sshConn3.GetID()]
	assert.True(t, ok)
	assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()])
	assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()])
	Connections.RUnlock()
	// removing an already removed connection must be a no-op
	Connections.RemoveSSHConnection(sshConn1.id)
	Connections.RLock()
	assert.Len(t, Connections.sshConnections, 2)
	assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
	assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id)
	_, ok = Connections.sshMapping[sshConn3.GetID()]
	assert.True(t, ok)
	assert.Equal(t, 0,
Connections.sshMapping[sshConn3.GetID()])
	assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()])
	Connections.RUnlock()
	Connections.RemoveSSHConnection(sshConn2.id)
	Connections.RLock()
	assert.Len(t, Connections.sshConnections, 1)
	assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
	_, ok = Connections.sshMapping[sshConn3.GetID()]
	assert.True(t, ok)
	assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()])
	Connections.RUnlock()
	Connections.RemoveSSHConnection(sshConn3.id)
	Connections.RLock()
	assert.Len(t, Connections.sshConnections, 0)
	assert.Len(t, Connections.sshMapping, 0)
	Connections.RUnlock()
	assert.NoError(t, sshConn1.Close())
	assert.NoError(t, sshConn2.Close())
	assert.NoError(t, sshConn3.Close())
}

// TestDefenderIntegration exercises the defender hooks (ban, score, delete,
// login delay) including the ipfilter plugin integration.
func TestDefenderIntegration(t *testing.T) {
	// by default defender is nil
	configCopy := Config

	wdPath, err := os.Getwd()
	require.NoError(t, err)
	pluginsConfig := []plugin.Config{
		{
			Type:     "ipfilter",
			Cmd:      filepath.Join(wdPath, "..", "..", "tests", "ipfilter", "ipfilter"),
			AutoMTLS: true,
		},
	}
	if runtime.GOOS == osWindows {
		pluginsConfig[0].Cmd += ".exe"
	}
	err = plugin.Initialize(pluginsConfig, "debug")
	require.NoError(t, err)

	ip := "127.1.1.1"

	assert.Nil(t, Reload())
	// 192.168.1.12 is banned from the ipfilter plugin
	assert.True(t, IsBanned("192.168.1.12", ProtocolFTP))

	AddDefenderEvent(ip, ProtocolFTP, HostEventNoLoginTried)
	assert.False(t, IsBanned(ip, ProtocolFTP))

	// with a nil defender all queries return zero values and no error
	banTime, err := GetDefenderBanTime(ip)
	assert.NoError(t, err)
	assert.Nil(t, banTime)
	assert.False(t, DeleteDefenderHost(ip))
	score, err := GetDefenderScore(ip)
	assert.NoError(t, err)
	assert.Equal(t, 0, score)
	_, err = GetDefenderHost(ip)
	assert.Error(t, err)
	hosts, err := GetDefenderHosts()
	assert.NoError(t, err)
	assert.Nil(t, hosts)

	Config.DefenderConfig = DefenderConfig{
		Enabled:          true,
		Driver:           DefenderDriverProvider,
		BanTime:          10,
		BanTimeIncrement: 50,
Threshold:        0,
		ScoreInvalid:     2,
		ScoreValid:       1,
		ScoreNoAuth:      2,
		ObservationTime:  15,
		EntriesSoftLimit: 100,
		EntriesHardLimit: 150,
		LoginDelay: LoginDelay{
			PasswordFailed: 200,
		},
	}
	err = Initialize(Config, 0)
	// ScoreInvalid cannot be greater than threshold
	assert.Error(t, err)
	Config.DefenderConfig.Driver = "unsupported"
	err = Initialize(Config, 0)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unsupported defender driver")
	}
	Config.DefenderConfig.Driver = DefenderDriverMemory
	err = Initialize(Config, 0)
	// ScoreInvalid cannot be greater than threshold
	assert.Error(t, err)
	Config.DefenderConfig.Threshold = 3

	err = Initialize(Config, 0)
	assert.NoError(t, err)
	assert.Nil(t, Reload())

	AddDefenderEvent(ip, ProtocolSSH, HostEventNoLoginTried)
	assert.False(t, IsBanned(ip, ProtocolSSH))
	score, err = GetDefenderScore(ip)
	assert.NoError(t, err)
	assert.Equal(t, 2, score)
	entry, err := GetDefenderHost(ip)
	assert.NoError(t, err)
	asJSON, err := json.Marshal(&entry)
	assert.NoError(t, err)
	// the id is the hex encoding of the IP address
	assert.Equal(t, `{"id":"3132372e312e312e31","ip":"127.1.1.1","score":2}`, string(asJSON), "entry %v", entry)
	assert.True(t, DeleteDefenderHost(ip))
	banTime, err = GetDefenderBanTime(ip)
	assert.NoError(t, err)
	assert.Nil(t, banTime)

	AddDefenderEvent(ip, ProtocolHTTP, HostEventLoginFailed)
	AddDefenderEvent(ip, ProtocolHTTP, HostEventNoLoginTried)
	assert.True(t, IsBanned(ip, ProtocolHTTP))
	// once banned the score is reset
	score, err = GetDefenderScore(ip)
	assert.NoError(t, err)
	assert.Equal(t, 0, score)
	banTime, err = GetDefenderBanTime(ip)
	assert.NoError(t, err)
	assert.NotNil(t, banTime)
	hosts, err = GetDefenderHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 1)
	entry, err = GetDefenderHost(ip)
	assert.NoError(t, err)
	assert.False(t, entry.BanTime.IsZero())
	assert.True(t, DeleteDefenderHost(ip))
	hosts, err = GetDefenderHosts()
	assert.NoError(t, err)
	assert.Len(t, hosts, 0)
	banTime, err =
GetDefenderBanTime(ip)
	assert.NoError(t, err)
	assert.Nil(t, banTime)
	assert.False(t, DeleteDefenderHost(ip))

	// no error: no login delay must be applied
	startTime := time.Now()
	DelayLogin(nil)
	elapsed := time.Since(startTime)
	assert.Less(t, elapsed, time.Millisecond*50)

	// failed login: the configured PasswordFailed delay must be applied
	startTime = time.Now()
	DelayLogin(ErrInternalFailure)
	elapsed = time.Since(startTime)
	assert.Greater(t, elapsed, time.Millisecond*150)

	Config = configCopy
}

// TestRateLimitersIntegration checks rate limiters configuration validation,
// safe lists and per-protocol limiting.
func TestRateLimitersIntegration(t *testing.T) {
	configCopy := Config

	enabled, protocols := Config.GetRateLimitersStatus()
	assert.False(t, enabled)
	assert.Len(t, protocols, 0)

	entries := []dataprovider.IPListEntry{
		{
			IPOrNet: "172.16.24.7/32",
			Type:    dataprovider.IPListTypeRateLimiterSafeList,
			Mode:    dataprovider.ListModeAllow,
		},
		{
			IPOrNet: "172.16.0.0/16",
			Type:    dataprovider.IPListTypeRateLimiterSafeList,
			Mode:    dataprovider.ListModeAllow,
		},
	}

	for idx := range entries {
		e := entries[idx]
		err := dataprovider.AddIPListEntry(&e, "", "", "")
		assert.NoError(t, err)
	}

	Config.RateLimitersConfig = []RateLimiterConfig{
		{
			Average:   100,
			Period:    10, // invalid, fixed below
			Burst:     5,
			Type:      int(rateLimiterTypeGlobal),
			Protocols: rateLimiterProtocolValues,
		},
		{
			Average: 1,
			Period:  1000,
			Burst:   1,
			Type:    int(rateLimiterTypeSource),
			// the duplicate protocol must be deduplicated
			Protocols:              []string{ProtocolWebDAV, ProtocolWebDAV, ProtocolFTP},
			GenerateDefenderEvents: true,
			EntriesSoftLimit:       100,
			EntriesHardLimit:       150,
		},
	}
	err := Initialize(Config, 0)
	assert.Error(t, err)
	Config.RateLimitersConfig[0].Period = 1000

	err = Initialize(Config, 0)
	assert.NoError(t, err)
	assert.NotNil(t, Config.rateLimitersList)

	assert.Len(t, rateLimiters, 4)
	assert.Len(t, rateLimiters[ProtocolSSH], 1)
	assert.Len(t, rateLimiters[ProtocolFTP], 2)
	assert.Len(t, rateLimiters[ProtocolWebDAV], 2)
	assert.Len(t, rateLimiters[ProtocolHTTP], 1)

	enabled, protocols = Config.GetRateLimitersStatus()
	assert.True(t, enabled)
	assert.Len(t, protocols, 4)
assert.Contains(t, protocols, ProtocolFTP)
	assert.Contains(t, protocols, ProtocolSSH)
	assert.Contains(t, protocols, ProtocolHTTP)
	assert.Contains(t, protocols, ProtocolWebDAV)

	source1 := "127.1.1.1"
	source2 := "127.1.1.2"
	source3 := "172.16.24.7" // in safelist

	_, err = LimitRate(ProtocolSSH, source1)
	assert.NoError(t, err)
	_, err = LimitRate(ProtocolFTP, source1)
	assert.NoError(t, err)
	// sleep to allow the add configured burst to the token.
	// This sleep is not enough to add the per-source burst
	time.Sleep(20 * time.Millisecond)
	_, err = LimitRate(ProtocolWebDAV, source2)
	assert.NoError(t, err)
	_, err = LimitRate(ProtocolFTP, source1)
	assert.Error(t, err)
	_, err = LimitRate(ProtocolWebDAV, source2)
	assert.Error(t, err)
	_, err = LimitRate(ProtocolSSH, source1)
	assert.NoError(t, err)
	_, err = LimitRate(ProtocolSSH, source2)
	assert.NoError(t, err)
	// the safe listed source must never be limited
	for i := 0; i < 10; i++ {
		_, err = LimitRate(ProtocolWebDAV, source3)
		assert.NoError(t, err)
	}
	for _, e := range entries {
		err := dataprovider.DeleteIPListEntry(e.IPOrNet, e.Type, "", "", "")
		assert.NoError(t, err)
	}

	assert.Nil(t, configCopy.rateLimitersList)
	Config = configCopy
}

// TestUserMaxSessions checks MaxSessions enforcement and connection swap.
func TestUserMaxSessions(t *testing.T) {
	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:    userTestUsername,
			MaxSessions: 1,
		},
	})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	// adding a second session for the same user must fail
	err = Connections.Add(fakeConn)
	assert.Error(t, err)
	err = Connections.Swap(fakeConn)
	assert.NoError(t, err)
	Connections.Remove(fakeConn.GetID())
	Connections.Lock()
	Connections.removeUserConnection(userTestUsername)
	Connections.Unlock()
	assert.Len(t, Connections.GetStats(""), 0)
}

// TestMaxConnections checks total and per-host connection limits.
func TestMaxConnections(t *testing.T) {
	oldValue := Config.MaxTotalConnections
	perHost := Config.MaxPerHostConnections

	Config.MaxPerHostConnections = 0
ipAddr := "192.168.7.8"
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolFTP))
	assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername))

	Config.MaxTotalConnections = 1
	Config.MaxPerHostConnections = perHost

	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolHTTP))
	assert.NoError(t, Connections.IsNewTransferAllowed(userTestUsername))
	// no new transfer is allowed while shutting down
	isShuttingDown.Store(true)
	assert.ErrorIs(t, Connections.IsNewTransferAllowed(userTestUsername), ErrShuttingDown)
	isShuttingDown.Store(false)

	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Len(t, Connections.GetStats(""), 1)
	// the total connections limit is reached
	assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
	Connections.transfers.add(userTestUsername)
	assert.Error(t, Connections.IsNewTransferAllowed(userTestUsername))
	Connections.transfers.remove(userTestUsername)
	assert.Equal(t, int32(0), Connections.GetTotalTransfers())

	res := Connections.Close(fakeConn.GetID(), "")
	assert.True(t, res)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)

	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
	Connections.AddClientConnection(ipAddr)
	Connections.AddClientConnection(ipAddr)
	assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
	Connections.RemoveClientConnection(ipAddr)
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolWebDAV))
	Connections.transfers.add(userTestUsername)
	assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))
	Connections.transfers.remove(userTestUsername)
	Connections.RemoveClientConnection(ipAddr)

	Config.MaxTotalConnections = oldValue
}

// TestConnectionRoles checks role filtering for stats and close.
func TestConnectionRoles(t *testing.T) {
	username := "testUsername"
	role1 :=
"testRole1"
	role2 := "testRole2"
	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Role:     role1,
		},
	})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Len(t, Connections.GetStats(""), 1)
	assert.Len(t, Connections.GetStats(role1), 1)
	assert.Len(t, Connections.GetStats(role2), 0)

	// closing with the wrong role must fail
	res := Connections.Close(fakeConn.GetID(), role2)
	assert.False(t, res)
	assert.Len(t, Connections.GetStats(""), 1)
	res = Connections.Close(fakeConn.GetID(), role1)
	assert.True(t, res)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
}

// TestMaxConnectionPerHost checks the per-host connection limit and the
// defender safe list exemption.
func TestMaxConnectionPerHost(t *testing.T) {
	defender, err := newInMemoryDefender(&DefenderConfig{
		Enabled:            true,
		Driver:             DefenderDriverMemory,
		BanTime:            30,
		BanTimeIncrement:   50,
		Threshold:          15,
		ScoreInvalid:       2,
		ScoreValid:         1,
		ScoreLimitExceeded: 3,
		ObservationTime:    30,
		EntriesSoftLimit:   100,
		EntriesHardLimit:   150,
	})
	require.NoError(t, err)

	oldMaxPerHostConn := Config.MaxPerHostConnections
	oldDefender := Config.defender

	Config.MaxPerHostConnections = 2
	Config.defender = defender

	ipAddr := "192.168.9.9"
	Connections.AddClientConnection(ipAddr)
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))

	Connections.AddClientConnection(ipAddr)
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolWebDAV))

	// the per-host limit is now exceeded
	Connections.AddClientConnection(ipAddr)
	assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolFTP))
	assert.Equal(t, int32(3), Connections.GetClientConnections())
	// Add the IP to the defender safe list
	entry := dataprovider.IPListEntry{
		IPOrNet: ipAddr,
		Type:    dataprovider.IPListTypeDefender,
		Mode:    dataprovider.ListModeAllow,
	}
	err = dataprovider.AddIPListEntry(&entry, "", "", "")
	assert.NoError(t,
err)

	// the safe listed IP is now exempt from the per-host limit
	Connections.AddClientConnection(ipAddr)
	assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr, ProtocolSSH))

	err = dataprovider.DeleteIPListEntry(entry.IPOrNet, dataprovider.IPListTypeDefender, "", "", "")
	assert.NoError(t, err)

	Connections.RemoveClientConnection(ipAddr)
	Connections.RemoveClientConnection(ipAddr)
	Connections.RemoveClientConnection(ipAddr)
	Connections.RemoveClientConnection(ipAddr)

	assert.Equal(t, int32(0), Connections.GetClientConnections())

	Config.MaxPerHostConnections = oldMaxPerHostConn
	Config.defender = oldDefender
}

// TestIdleConnections checks that expired connections and SSH connections
// are removed by the periodic checks.
func TestIdleConnections(t *testing.T) {
	configCopy := Config

	Config.IdleTimeout = 1
	err := Initialize(Config, 0)
	assert.NoError(t, err)

	conn1, conn2 := net.Pipe()
	customConn1 := &customNetConn{
		Conn: conn1,
		id:   "id1",
	}
	customConn2 := &customNetConn{
		Conn: conn2,
		id:   "id2",
	}
	sshConn1 := NewSSHConnection(customConn1.id, customConn1)
	sshConn2 := NewSSHConnection(customConn2.id, customConn2)

	username := "test_user"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Status:   1,
		},
	}
	c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, "", "", user)
	c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	// both ssh connections are expired but they should get removed only
	// if there is no associated connection
	sshConn1.lastActivity.Store(c.lastActivity.Load())
	sshConn2.lastActivity.Store(c.lastActivity.Load())
	Connections.AddSSHConnection(sshConn1)
	err = Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Equal(t, Connections.GetActiveSessions(username), 1)
	c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, "", "", user)
	fakeConn = &fakeConnection{
		BaseConnection: c,
	}
	Connections.AddSSHConnection(sshConn2)
	err = Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Equal(t,
Connections.GetActiveSessions(username), 2)

	cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
	cFTP.lastActivity.Store(time.Now().UnixNano())
	fakeConn = &fakeConnection{
		BaseConnection: cFTP,
	}
	err = Connections.Add(fakeConn)
	assert.NoError(t, err)
	// the user is expired, this connection will be removed
	cDAV := NewBaseConnection("id3", ProtocolWebDAV, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:       username + "_2",
			Status:         1,
			ExpirationDate: util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)),
		},
	})
	cDAV.lastActivity.Store(time.Now().UnixNano())
	fakeConn = &fakeConnection{
		BaseConnection: cDAV,
	}
	err = Connections.Add(fakeConn)
	assert.NoError(t, err)

	assert.Equal(t, 2, Connections.GetActiveSessions(username))
	assert.Len(t, Connections.GetStats(""), 4)
	Connections.RLock()
	assert.Len(t, Connections.sshConnections, 2)
	Connections.RUnlock()

	// the first round of periodic checks must remove only the expired sessions
	startPeriodicChecks(100*time.Millisecond, 0)
	assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 2*time.Second, 200*time.Millisecond)
	assert.Eventually(t, func() bool {
		Connections.RLock()
		defer Connections.RUnlock()
		return len(Connections.sshConnections) == 1
	}, 1*time.Second, 200*time.Millisecond)
	stopEventScheduler()
	assert.Len(t, Connections.GetStats(""), 2)
	// make the remaining connections idle and check they are all removed
	c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
	cFTP.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
	sshConn2.lastActivity.Store(c.lastActivity.Load())
	startPeriodicChecks(100*time.Millisecond, 1)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 2*time.Second, 200*time.Millisecond)
	assert.Eventually(t, func() bool {
		Connections.RLock()
		defer Connections.RUnlock()
		return len(Connections.sshConnections) == 0
	}, 1*time.Second, 200*time.Millisecond)
	assert.Equal(t, int32(0), Connections.GetClientConnections())
	stopEventScheduler()
assert.True(t, customConn1.isClosed)
	assert.True(t, customConn2.isClosed)

	Config = configCopy
}

// TestCloseConnection checks closing an active connection and that closing
// an unknown connection returns false.
func TestCloseConnection(t *testing.T) {
	c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	assert.NoError(t, Connections.IsNewConnectionAllowed("127.0.0.1", ProtocolHTTP))
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Len(t, Connections.GetStats(""), 1)
	res := Connections.Close(fakeConn.GetID(), "")
	assert.True(t, res)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
	res = Connections.Close(fakeConn.GetID(), "")
	assert.False(t, res)
	Connections.Remove(fakeConn.GetID())
}

// TestSwapConnection checks replacing an existing connection, including
// MaxSessions enforcement during the swap.
func TestSwapConnection(t *testing.T) {
	c := NewBaseConnection("id", ProtocolFTP, "", "", dataprovider.User{})
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err := Connections.Add(fakeConn)
	assert.NoError(t, err)
	if assert.Len(t, Connections.GetStats(""), 1) {
		assert.Equal(t, "", Connections.GetStats("")[0].Username)
	}
	c = NewBaseConnection("id", ProtocolFTP, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:    userTestUsername,
			MaxSessions: 1,
		},
	})
	fakeConn = &fakeConnection{
		BaseConnection: c,
	}
	c1 := NewBaseConnection("id1", ProtocolFTP, "", "", dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: userTestUsername,
		},
	})
	fakeConn1 := &fakeConnection{
		BaseConnection: c1,
	}
	err = Connections.Add(fakeConn1)
	assert.NoError(t, err)
	// the swap would exceed MaxSessions while fakeConn1 is active
	err = Connections.Swap(fakeConn)
	assert.Error(t, err)
	Connections.Remove(fakeConn1.ID)
	err = Connections.Swap(fakeConn)
	assert.NoError(t, err)
	if assert.Len(t, Connections.GetStats(""), 1) {
		assert.Equal(t, userTestUsername, Connections.GetStats("")[0].Username)
	}
	res := Connections.Close(fakeConn.GetID(), "")
	assert.True(t, res)
	assert.Eventually(t, func() bool { return len(Connections.GetStats("")) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
	// swapping a closed connection must fail
	err = Connections.Swap(fakeConn)
	assert.Error(t, err)
}

// TestAtomicUpload checks which upload modes enable atomic uploads.
func TestAtomicUpload(t *testing.T) {
	configCopy := Config

	Config.UploadMode = UploadModeStandard
	assert.False(t, Config.IsAtomicUploadEnabled())
	Config.UploadMode = UploadModeAtomic
	assert.True(t, Config.IsAtomicUploadEnabled())
	Config.UploadMode = UploadModeAtomicWithResume
	assert.True(t, Config.IsAtomicUploadEnabled())

	Config = configCopy
}

// TestConnectionStatus checks connection stats and transfer reporting.
func TestConnectionStatus(t *testing.T) {
	username := "test_user"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
		},
	}
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	c1 := NewBaseConnection("id1", ProtocolSFTP, "", "", user)
	fakeConn1 := &fakeConnection{
		BaseConnection: c1,
	}
	t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	t1.BytesReceived.Store(123)
	t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	t2.BytesSent.Store(456)
	c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
	fakeConn2 := &fakeConnection{
		BaseConnection: c2,
		command:        "md5sum",
	}
	c3 := NewBaseConnection("id3", ProtocolWebDAV, "", "", user)
	fakeConn3 := &fakeConnection{
		BaseConnection: c3,
		command:        "PROPFIND",
	}
	t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	err := Connections.Add(fakeConn1)
	assert.NoError(t, err)
	err = Connections.Add(fakeConn2)
	assert.NoError(t, err)
	err = Connections.Add(fakeConn3)
	assert.NoError(t, err)

	stats := Connections.GetStats("")
	assert.Len(t, stats, 3)
	for _, stat := range stats {
		assert.Equal(t, stat.Username, username)
		switch stat.ConnectionID {
		case "SFTP_id1":
			assert.Len(t, stat.Transfers, 2)
		case "DAV_id3":
			assert.Len(t,
stat.Transfers, 1)
		}
	}

	err = t1.Close()
	assert.NoError(t, err)
	err = t2.Close()
	assert.NoError(t, err)

	err = fakeConn3.SignalTransfersAbort()
	assert.NoError(t, err)
	assert.True(t, t3.AbortTransfer.Load())
	err = t3.Close()
	assert.NoError(t, err)
	// no transfer left to abort
	err = fakeConn3.SignalTransfersAbort()
	assert.Error(t, err)

	Connections.Remove(fakeConn1.GetID())
	stats = Connections.GetStats("")
	assert.Len(t, stats, 2)
	assert.Equal(t, fakeConn3.GetID(), stats[0].ConnectionID)
	assert.Equal(t, fakeConn2.GetID(), stats[1].ConnectionID)
	Connections.Remove(fakeConn2.GetID())
	stats = Connections.GetStats("")
	assert.Len(t, stats, 1)
	assert.Equal(t, fakeConn3.GetID(), stats[0].ConnectionID)
	Connections.Remove(fakeConn3.GetID())
	stats = Connections.GetStats("")
	assert.Len(t, stats, 0)
}

// TestQuotaScans checks add/remove bookkeeping for user and folder quota
// scans.
func TestQuotaScans(t *testing.T) {
	username := "username"
	assert.True(t, QuotaScans.AddUserQuotaScan(username, ""))
	// a duplicate scan for the same user must be rejected
	assert.False(t, QuotaScans.AddUserQuotaScan(username, ""))
	usersScans := QuotaScans.GetUsersQuotaScans("")
	if assert.Len(t, usersScans, 1) {
		assert.Equal(t, usersScans[0].Username, username)
		assert.Equal(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime)
		// the returned scans must be a copy, not an alias
		QuotaScans.UserScans[0].StartTime = 0
		assert.NotEqual(t, QuotaScans.UserScans[0].StartTime, usersScans[0].StartTime)
	}

	assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
	assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
	assert.Len(t, QuotaScans.GetUsersQuotaScans(""), 0)
	assert.Len(t, usersScans, 1)

	folderName := "folder"
	assert.True(t, QuotaScans.AddVFolderQuotaScan(folderName))
	assert.False(t, QuotaScans.AddVFolderQuotaScan(folderName))
	if assert.Len(t, QuotaScans.GetVFoldersQuotaScans(), 1) {
		assert.Equal(t, QuotaScans.GetVFoldersQuotaScans()[0].Name, folderName)
	}

	assert.True(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
	assert.False(t, QuotaScans.RemoveVFolderQuotaScan(folderName))
	assert.Len(t,
QuotaScans.GetVFoldersQuotaScans(), 0)
}

// TestQuotaScansRole checks role filtering for user quota scans.
func TestQuotaScansRole(t *testing.T) {
	username := "u"
	role1 := "r1"
	role2 := "r2"
	assert.True(t, QuotaScans.AddUserQuotaScan(username, role1))
	assert.False(t, QuotaScans.AddUserQuotaScan(username, ""))
	usersScans := QuotaScans.GetUsersQuotaScans("")
	assert.Len(t, usersScans, 1)
	assert.Empty(t, usersScans[0].Role)
	usersScans = QuotaScans.GetUsersQuotaScans(role1)
	assert.Len(t, usersScans, 1)
	usersScans = QuotaScans.GetUsersQuotaScans(role2)
	assert.Len(t, usersScans, 0)
	assert.True(t, QuotaScans.RemoveUserQuotaScan(username))
	assert.False(t, QuotaScans.RemoveUserQuotaScan(username))
	assert.Len(t, QuotaScans.GetUsersQuotaScans(""), 0)
}

// TestProxyPolicy checks the proxy protocol connection policy for allowed,
// skipped and unknown upstream addresses.
func TestProxyPolicy(t *testing.T) {
	addr := net.TCPAddr{}
	downstream := net.TCPAddr{IP: net.ParseIP("1.1.1.1")}
	p := getProxyPolicy(nil, nil, proxyproto.IGNORE)
	// an upstream without IP must be rejected
	policy, err := p(proxyproto.ConnPolicyOptions{
		Upstream:   &addr,
		Downstream: &downstream,
	})
	assert.ErrorIs(t, err, proxyproto.ErrInvalidUpstream)
	assert.Equal(t, proxyproto.REJECT, policy)
	ip1 := net.ParseIP("10.8.1.1")
	ip2 := net.ParseIP("10.8.1.2")
	ip3 := net.ParseIP("10.8.1.3")
	allowed, err := util.ParseAllowedIPAndRanges([]string{ip1.String()})
	assert.NoError(t, err)
	skipped, err := util.ParseAllowedIPAndRanges([]string{ip2.String(), ip3.String()})
	assert.NoError(t, err)
	p = getProxyPolicy(allowed, skipped, proxyproto.IGNORE)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip1},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.USE, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip2},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.SKIP, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip3},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.SKIP,
policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: net.ParseIP("10.8.1.4")},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.IGNORE, policy)
	p = getProxyPolicy(allowed, skipped, proxyproto.REQUIRE)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip1},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.REQUIRE, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip2},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.SKIP, policy)
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: ip3},
		Downstream: &downstream,
	})
	assert.NoError(t, err)
	assert.Equal(t, proxyproto.SKIP, policy)
	// with REQUIRE an unknown upstream must be rejected
	policy, err = p(proxyproto.ConnPolicyOptions{
		Upstream:   &net.TCPAddr{IP: net.ParseIP("10.8.1.5")},
		Downstream: &downstream,
	})
	assert.ErrorIs(t, err, proxyproto.ErrInvalidUpstream)
	assert.Equal(t, proxyproto.REJECT, policy)
}

// TestProxyProtocolVersion checks proxy listener creation for the supported
// proxy protocol configurations.
func TestProxyProtocolVersion(t *testing.T) {
	c := Configuration{
		ProxyProtocol: 0,
	}
	_, err := c.GetProxyListener(nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "proxy protocol not configured")
	}
	c.ProxyProtocol = 1
	listener, err := c.GetProxyListener(nil)
	assert.NoError(t, err)
	proxyListener, ok := listener.(*proxyproto.Listener)
	require.True(t, ok)
	assert.NotNil(t, proxyListener.ConnPolicy)

	c.ProxyProtocol = 2
	listener, err = c.GetProxyListener(nil)
	assert.NoError(t, err)
	proxyListener, ok = listener.(*proxyproto.Listener)
	require.True(t, ok)
	assert.NotNil(t, proxyListener.ConnPolicy)
}

// TestStartupHook checks HTTP and command based startup hooks.
func TestStartupHook(t *testing.T) {
	Config.StartupHook = ""

	assert.NoError(t, Config.ExecuteStartupHook())

	Config.StartupHook = "http://foo\x7f.com/startup"
	assert.Error(t, Config.ExecuteStartupHook())

	Config.StartupHook = "http://invalid:5678/"
assert.Error(t, Config.ExecuteStartupHook())

	Config.StartupHook = fmt.Sprintf("http://%v", httpAddr)
	assert.NoError(t, Config.ExecuteStartupHook())

	// not an HTTP URL and not an absolute path
	Config.StartupHook = "invalidhook"
	assert.Error(t, Config.ExecuteStartupHook())

	if runtime.GOOS != osWindows {
		hookCmd, err := exec.LookPath("true")
		assert.NoError(t, err)
		Config.StartupHook = hookCmd
		assert.NoError(t, Config.ExecuteStartupHook())
	}

	Config.StartupHook = ""
}

// TestPostDisconnectHook checks HTTP and command based post-disconnect hooks.
func TestPostDisconnectHook(t *testing.T) {
	Config.PostDisconnectHook = "http://127.0.0.1/"

	remoteAddr := "127.0.0.1:80"
	Config.checkPostDisconnectHook(remoteAddr, ProtocolHTTP, "", "", time.Now())
	Config.checkPostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

	Config.PostDisconnectHook = "http://bar\x7f.com/"
	Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

	Config.PostDisconnectHook = fmt.Sprintf("http://%v", httpAddr)
	Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

	Config.PostDisconnectHook = "relativePath"
	Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

	if runtime.GOOS == osWindows {
		Config.PostDisconnectHook = "C:\\a\\bad\\command"
		Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())
	} else {
		Config.PostDisconnectHook = "/invalid/path"
		Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())

		hookCmd, err := exec.LookPath("true")
		assert.NoError(t, err)
		Config.PostDisconnectHook = hookCmd
		Config.executePostDisconnectHook(remoteAddr, ProtocolSFTP, "", "", time.Now())
	}
	Config.PostDisconnectHook = ""
}

// TestPostConnectHook checks HTTP and command based post-connect hooks.
func TestPostConnectHook(t *testing.T) {
	Config.PostConnectHook = ""

	ipAddr := "127.0.0.1"

	assert.NoError(t, Config.ExecutePostConnectHook(ipAddr, ProtocolFTP))

	Config.PostConnectHook = "http://foo\x7f.com/"
	assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))
Config.PostConnectHook = "http://invalid:1234/"
	assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))

	// a non-200 response must be treated as an error
	Config.PostConnectHook = fmt.Sprintf("http://%v/404", httpAddr)
	assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolFTP))

	Config.PostConnectHook = fmt.Sprintf("http://%v", httpAddr)
	assert.NoError(t, Config.ExecutePostConnectHook(ipAddr, ProtocolFTP))

	Config.PostConnectHook = "invalid"
	assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolFTP))

	if runtime.GOOS == osWindows {
		Config.PostConnectHook = "C:\\bad\\command"
		assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))
	} else {
		Config.PostConnectHook = "/invalid/path"
		assert.Error(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))

		hookCmd, err := exec.LookPath("true")
		assert.NoError(t, err)
		Config.PostConnectHook = hookCmd
		assert.NoError(t, Config.ExecutePostConnectHook(ipAddr, ProtocolSFTP))
	}

	Config.PostConnectHook = ""
}

// TestCryptoConvertFileInfo checks size conversion for encrypted files.
func TestCryptoConvertFileInfo(t *testing.T) {
	name := "name"
	fs, err := vfs.NewCryptFs("connID1", os.TempDir(), "", vfs.CryptFsConfig{
		Passphrase: kms.NewPlainSecret("secret"),
	})
	require.NoError(t, err)
	cryptFs := fs.(*vfs.CryptFs)
	// directories are returned unchanged
	info := vfs.NewFileInfo(name, true, 48, time.Now(), false)
	assert.Equal(t, info, cryptFs.ConvertFileInfo(info))
	info = vfs.NewFileInfo(name, false, 48, time.Now(), false)
	assert.NotEqual(t, info.Size(), cryptFs.ConvertFileInfo(info).Size())
	info = vfs.NewFileInfo(name, false, 33, time.Now(), false)
	assert.Equal(t, int64(0), cryptFs.ConvertFileInfo(info).Size())
	info = vfs.NewFileInfo(name, false, 1, time.Now(), false)
	assert.Equal(t, int64(0), cryptFs.ConvertFileInfo(info).Size())
}

// TestFolderCopy checks that GetACopy returns a deep copy of a folder.
func TestFolderCopy(t *testing.T) {
	folder := vfs.BaseVirtualFolder{
		ID:              1,
		Name:            "name",
		MappedPath:      filepath.Clean(os.TempDir()),
		UsedQuotaSize:   4096,
		UsedQuotaFiles:  2,
		LastQuotaUpdate: util.GetTimeAsMsSinceEpoch(time.Now()),
Users: []string{"user1", "user2"},
	}
	folderCopy := folder.GetACopy()
	// mutating the original must not affect the copy
	folder.ID = 2
	folder.Users = []string{"user3"}
	require.Len(t, folderCopy.Users, 2)
	require.True(t, slices.Contains(folderCopy.Users, "user1"))
	require.True(t, slices.Contains(folderCopy.Users, "user2"))
	require.Equal(t, int64(1), folderCopy.ID)
	require.Equal(t, folder.Name, folderCopy.Name)
	require.Equal(t, folder.MappedPath, folderCopy.MappedPath)
	require.Equal(t, folder.UsedQuotaSize, folderCopy.UsedQuotaSize)
	require.Equal(t, folder.UsedQuotaFiles, folderCopy.UsedQuotaFiles)
	require.Equal(t, folder.LastQuotaUpdate, folderCopy.LastQuotaUpdate)

	folder.FsConfig = vfs.Filesystem{
		CryptConfig: vfs.CryptFsConfig{
			Passphrase: kms.NewPlainSecret("crypto secret"),
		},
	}
	folderCopy = folder.GetACopy()
	folder.FsConfig.CryptConfig.Passphrase = kms.NewEmptySecret()
	require.Len(t, folderCopy.Users, 1)
	require.True(t, slices.Contains(folderCopy.Users, "user3"))
	require.Equal(t, int64(2), folderCopy.ID)
	require.Equal(t, folder.Name, folderCopy.Name)
	require.Equal(t, folder.MappedPath, folderCopy.MappedPath)
	require.Equal(t, folder.UsedQuotaSize, folderCopy.UsedQuotaSize)
	require.Equal(t, folder.UsedQuotaFiles, folderCopy.UsedQuotaFiles)
	require.Equal(t, folder.LastQuotaUpdate, folderCopy.LastQuotaUpdate)
	require.Equal(t, "crypto secret", folderCopy.FsConfig.CryptConfig.Passphrase.GetPayload())
}

// TestCachedFs checks that a connection caches its filesystem.
func TestCachedFs(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			HomeDir: filepath.Clean(os.TempDir()),
		},
	}
	conn := NewBaseConnection("id", ProtocolSFTP, "", "", user)
	// changing the user should not affect the connection
	user.HomeDir = filepath.Join(os.TempDir(), "temp")
	err := os.Mkdir(user.HomeDir, os.ModePerm)
	assert.NoError(t, err)
	fs, err := user.GetFilesystem("")
	assert.NoError(t, err)
	p, err := fs.ResolvePath("/")
	assert.NoError(t, err)
	assert.Equal(t, user.GetHomeDir(), p)

	_, p, err =
conn.GetFsAndResolvedPath("/")
	assert.NoError(t, err)
	assert.Equal(t, filepath.Clean(os.TempDir()), p)
	// the filesystem is cached changing the provider will not affect the connection
	conn.User.FsConfig.Provider = sdk.S3FilesystemProvider
	_, p, err = conn.GetFsAndResolvedPath("/")
	assert.NoError(t, err)
	assert.Equal(t, filepath.Clean(os.TempDir()), p)
	// a fresh user with an S3 provider and no valid config must fail
	user = dataprovider.User{}
	user.HomeDir = filepath.Join(os.TempDir(), "temp")
	user.FsConfig.Provider = sdk.S3FilesystemProvider
	_, err = user.GetFilesystem("")
	assert.Error(t, err)

	err = os.Remove(user.HomeDir)
	assert.NoError(t, err)
}

// TestParseAllowedIPAndRanges checks parsing of IPs and CIDR ranges into
// matcher functions and the rejection of invalid inputs.
func TestParseAllowedIPAndRanges(t *testing.T) {
	_, err := util.ParseAllowedIPAndRanges([]string{"1.1.1.1", "not an ip"})
	assert.Error(t, err)
	_, err = util.ParseAllowedIPAndRanges([]string{"1.1.1.5", "192.168.1.0/240"})
	assert.Error(t, err)
	allow, err := util.ParseAllowedIPAndRanges([]string{"192.168.1.2", "172.16.0.0/24"})
	assert.NoError(t, err)
	assert.True(t, allow[0](net.ParseIP("192.168.1.2")))
	assert.False(t, allow[0](net.ParseIP("192.168.2.2")))
	assert.True(t, allow[1](net.ParseIP("172.16.0.1")))
	assert.False(t, allow[1](net.ParseIP("172.16.1.1")))
}

// TestHideConfidentialData only checks that PrepareForRendering and
// HideConfidentialData do not panic for every filesystem provider.
func TestHideConfidentialData(_ *testing.T) {
	for _, provider := range []sdk.FilesystemProvider{sdk.LocalFilesystemProvider,
		sdk.CryptedFilesystemProvider, sdk.S3FilesystemProvider, sdk.GCSFilesystemProvider,
		sdk.AzureBlobFilesystemProvider, sdk.SFTPFilesystemProvider,
	} {
		u := dataprovider.User{
			FsConfig: vfs.Filesystem{
				Provider: provider,
			},
		}
		u.PrepareForRendering()
		f := vfs.BaseVirtualFolder{
			FsConfig: vfs.Filesystem{
				Provider: provider,
			},
		}
		f.PrepareForRendering()
	}
	a := dataprovider.Admin{}
	a.HideConfidentialData()
}

// TestUserPerms exercises HasAnyPerm, HasPermsDeleteAll and HasPermsRenameAll
// with different permission combinations on "/".
func TestUserPerms(t *testing.T) {
	u := dataprovider.User{}
	u.Permissions = make(map[string][]string)
	u.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermDelete}
	assert.True(t,
u.HasAnyPerm([]string{dataprovider.PermRename, dataprovider.PermDelete}, "/"))
	assert.False(t, u.HasAnyPerm([]string{dataprovider.PermRename, dataprovider.PermCreateDirs}, "/"))
	u.Permissions["/"] = []string{dataprovider.PermDelete, dataprovider.PermCreateDirs}
	assert.True(t, u.HasPermsDeleteAll("/"))
	assert.False(t, u.HasPermsRenameAll("/"))
	u.Permissions["/"] = []string{dataprovider.PermDeleteDirs, dataprovider.PermDeleteFiles, dataprovider.PermRenameDirs}
	assert.True(t, u.HasPermsDeleteAll("/"))
	assert.False(t, u.HasPermsRenameAll("/"))
	u.Permissions["/"] = []string{dataprovider.PermDeleteDirs, dataprovider.PermRenameFiles, dataprovider.PermRenameDirs}
	assert.False(t, u.HasPermsDeleteAll("/"))
	assert.True(t, u.HasPermsRenameAll("/"))
}

// TestGetTLSVersion checks the mapping of config values to TLS versions:
// 13 selects TLS 1.3, any other value falls back to TLS 1.2.
func TestGetTLSVersion(t *testing.T) {
	tlsVer := util.GetTLSVersion(0)
	assert.Equal(t, uint16(tls.VersionTLS12), tlsVer)
	tlsVer = util.GetTLSVersion(12)
	assert.Equal(t, uint16(tls.VersionTLS12), tlsVer)
	tlsVer = util.GetTLSVersion(2)
	assert.Equal(t, uint16(tls.VersionTLS12), tlsVer)
	tlsVer = util.GetTLSVersion(13)
	assert.Equal(t, uint16(tls.VersionTLS13), tlsVer)
}

// TestCleanPath verifies normalization of virtual paths: empty/dot inputs map
// to "/", trailing separators are dropped and backslashes are handled.
func TestCleanPath(t *testing.T) {
	assert.Equal(t, "/", util.CleanPath("/"))
	assert.Equal(t, "/", util.CleanPath("."))
	assert.Equal(t, "/", util.CleanPath(""))
	assert.Equal(t, "/", util.CleanPath("/."))
	assert.Equal(t, "/", util.CleanPath("/a/.."))
	assert.Equal(t, "/a", util.CleanPath("/a/"))
	assert.Equal(t, "/a", util.CleanPath("a/"))
	// filepath.ToSlash does not touch \ as char on unix systems
	// so os.PathSeparator is used for windows compatible tests
	bslash := string(os.PathSeparator)
	assert.Equal(t, "/", util.CleanPath(bslash))
	assert.Equal(t, "/", util.CleanPath(bslash+bslash))
	assert.Equal(t, "/a", util.CleanPath(bslash+"a"+bslash))
	assert.Equal(t, "/a", util.CleanPath("a"+bslash))
	assert.Equal(t, "/a/b/c", util.CleanPath(bslash+"a"+bslash+bslash+"b"+bslash+bslash+"c"+bslash))
assert.Equal(t, "/C:/a", util.CleanPath("C:"+bslash+"a"))
}

// TestUserRecentActivity checks HasRecentActivity: a zero LastLogin is not
// recent, "now" and one second in the future are recent, one minute in the
// future is not.
func TestUserRecentActivity(t *testing.T) {
	u := dataprovider.User{}
	res := u.HasRecentActivity()
	assert.False(t, res)
	u.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now())
	res = u.HasRecentActivity()
	assert.True(t, res)
	u.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute))
	res = u.HasRecentActivity()
	assert.False(t, res)
	u.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Second))
	res = u.HasRecentActivity()
	assert.True(t, res)
}

// TestVfsSameResource checks Filesystem.IsSameResource for the S3, GCS,
// Azure Blob and HTTP providers: configurations pointing at the same remote
// resource compare equal, differing bucket/region/account/SAS URL do not.
func TestVfsSameResource(t *testing.T) {
	fs := vfs.Filesystem{}
	other := vfs.Filesystem{}
	res := fs.IsSameResource(other)
	assert.True(t, res)
	fs = vfs.Filesystem{
		Provider: sdk.S3FilesystemProvider,
		S3Config: vfs.S3FsConfig{
			BaseS3FsConfig: sdk.BaseS3FsConfig{
				Bucket: "a",
				Region: "b",
			},
		},
	}
	other = vfs.Filesystem{
		Provider: sdk.S3FilesystemProvider,
		S3Config: vfs.S3FsConfig{
			BaseS3FsConfig: sdk.BaseS3FsConfig{
				Bucket: "a",
				Region: "c",
			},
		},
	}
	res = fs.IsSameResource(other)
	assert.False(t, res)
	other = vfs.Filesystem{
		Provider: sdk.S3FilesystemProvider,
		S3Config: vfs.S3FsConfig{
			BaseS3FsConfig: sdk.BaseS3FsConfig{
				Bucket: "a",
				Region: "b",
			},
		},
	}
	res = fs.IsSameResource(other)
	assert.True(t, res)
	fs = vfs.Filesystem{
		Provider: sdk.GCSFilesystemProvider,
		GCSConfig: vfs.GCSFsConfig{
			BaseGCSFsConfig: sdk.BaseGCSFsConfig{
				Bucket: "b",
			},
		},
	}
	other = vfs.Filesystem{
		Provider: sdk.GCSFilesystemProvider,
		GCSConfig: vfs.GCSFsConfig{
			BaseGCSFsConfig: sdk.BaseGCSFsConfig{
				Bucket: "c",
			},
		},
	}
	res = fs.IsSameResource(other)
	assert.False(t, res)
	other = vfs.Filesystem{
		Provider: sdk.GCSFilesystemProvider,
		GCSConfig: vfs.GCSFsConfig{
			BaseGCSFsConfig: sdk.BaseGCSFsConfig{
				Bucket: "b",
			},
		},
	}
	res = fs.IsSameResource(other)
	assert.True(t, res)
	sasURL := kms.NewPlainSecret("http://127.0.0.1/sasurl")
fs = vfs.Filesystem{ + Provider: sdk.AzureBlobFilesystemProvider, + AzBlobConfig: vfs.AzBlobFsConfig{ + BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{ + AccountName: "a", + }, + SASURL: sasURL, + }, + } + err := fs.Validate("data1") + assert.NoError(t, err) + other = vfs.Filesystem{ + Provider: sdk.AzureBlobFilesystemProvider, + AzBlobConfig: vfs.AzBlobFsConfig{ + BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{ + AccountName: "a", + }, + SASURL: sasURL, + }, + } + err = other.Validate("data2") + assert.NoError(t, err) + err = fs.AzBlobConfig.SASURL.TryDecrypt() + assert.NoError(t, err) + err = other.AzBlobConfig.SASURL.TryDecrypt() + assert.NoError(t, err) + res = fs.IsSameResource(other) + assert.True(t, res) + fs.AzBlobConfig.AccountName = "b" + res = fs.IsSameResource(other) + assert.False(t, res) + fs.AzBlobConfig.AccountName = "a" + other.AzBlobConfig.SASURL = kms.NewPlainSecret("http://127.1.1.1/sasurl") + err = other.Validate("data2") + assert.NoError(t, err) + err = other.AzBlobConfig.SASURL.TryDecrypt() + assert.NoError(t, err) + res = fs.IsSameResource(other) + assert.False(t, res) + fs = vfs.Filesystem{ + Provider: sdk.HTTPFilesystemProvider, + HTTPConfig: vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: "http://127.0.0.1/httpfs", + Username: "a", + }, + }, + } + other = vfs.Filesystem{ + Provider: sdk.HTTPFilesystemProvider, + HTTPConfig: vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: "http://127.0.0.1/httpfs", + Username: "b", + }, + }, + } + res = fs.IsSameResource(other) + assert.True(t, res) + fs.HTTPConfig.EqualityCheckMode = 1 + res = fs.IsSameResource(other) + assert.False(t, res) +} + +func TestUpdateTransferTimestamps(t *testing.T) { + username := "user_test_timestamps" + user := &dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + err := 
dataprovider.AddUser(user, "", "", "") + assert.NoError(t, err) + assert.Equal(t, int64(0), user.FirstUpload) + assert.Equal(t, int64(0), user.FirstDownload) + + err = dataprovider.UpdateUserTransferTimestamps(username, true) + assert.NoError(t, err) + userGet, err := dataprovider.UserExists(username, "") + assert.NoError(t, err) + assert.Greater(t, userGet.FirstUpload, int64(0)) + assert.Equal(t, int64(0), user.FirstDownload) + err = dataprovider.UpdateUserTransferTimestamps(username, false) + assert.NoError(t, err) + userGet, err = dataprovider.UserExists(username, "") + assert.NoError(t, err) + assert.Greater(t, userGet.FirstUpload, int64(0)) + assert.Greater(t, userGet.FirstDownload, int64(0)) + // updating again must fail + err = dataprovider.UpdateUserTransferTimestamps(username, true) + assert.Error(t, err) + err = dataprovider.UpdateUserTransferTimestamps(username, false) + assert.Error(t, err) + // cleanup + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) +} + +func TestIPList(t *testing.T) { + type test struct { + ip string + protocol string + expectedMatch bool + expectedMode int + expectedErr bool + } + + entries := []dataprovider.IPListEntry{ + { + IPOrNet: "192.168.0.0/25", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + { + IPOrNet: "192.168.0.128/25", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + Protocols: 3, + }, + { + IPOrNet: "192.168.2.128/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + Protocols: 5, + }, + { + IPOrNet: "::/0", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + Protocols: 4, + }, + { + IPOrNet: "2001:4860:4860::8888/120", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + Protocols: 1, + }, + { + IPOrNet: "2001:4860:4860::8988/120", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + Protocols: 3, + }, + { + IPOrNet: 
"::1/128",
			Type:      dataprovider.IPListTypeDefender,
			Mode:      dataprovider.ListModeAllow,
			Protocols: 0, // 0 applies to every protocol, per the expectations below
		},
	}
	ipList, err := dataprovider.NewIPList(dataprovider.IPListTypeDefender)
	require.NoError(t, err)
	for idx := range entries {
		e := entries[idx]
		err := dataprovider.AddIPListEntry(&e, "", "", "")
		assert.NoError(t, err)
	}
	tests := []test{
		{ip: "1.1.1.1", protocol: ProtocolSSH, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "invalid ip", protocol: ProtocolSSH, expectedMatch: false, expectedMode: 0, expectedErr: true},
		{ip: "192.168.0.1", protocol: ProtocolFTP, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "192.168.0.2", protocol: ProtocolHTTP, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "192.168.0.3", protocol: ProtocolWebDAV, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "192.168.0.4", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "192.168.0.156", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeDeny, expectedErr: false},
		{ip: "192.168.0.158", protocol: ProtocolFTP, expectedMatch: true, expectedMode: dataprovider.ListModeDeny, expectedErr: false},
		{ip: "192.168.0.158", protocol: ProtocolHTTP, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "192.168.2.128", protocol: ProtocolHTTP, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "192.168.2.128", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "::2", protocol: ProtocolSSH, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "::2", protocol: ProtocolWebDAV, expectedMatch: true, expectedMode: dataprovider.ListModeDeny, expectedErr: false},
		{ip: "::1", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "::1", protocol: ProtocolHTTP, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:8889", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeDeny, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:8889", protocol: ProtocolFTP, expectedMatch: false, expectedMode: 0, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:8989", protocol: ProtocolFTP, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:89F1", protocol: ProtocolSSH, expectedMatch: true, expectedMode: dataprovider.ListModeAllow, expectedErr: false},
		{ip: "2001:4860:4860:0000:0000:0000:0000:89F1", protocol: ProtocolHTTP, expectedMatch: false, expectedMode: 0, expectedErr: false},
	}

	for _, tc := range tests {
		match, mode, err := ipList.IsListed(tc.ip, tc.protocol)
		if tc.expectedErr {
			assert.Error(t, err, "ip %s, protocol %s", tc.ip, tc.protocol)
		} else {
			assert.NoError(t, err, "ip %s, protocol %s", tc.ip, tc.protocol)
		}
		assert.Equal(t, tc.expectedMatch, match, "ip %s, protocol %s", tc.ip, tc.protocol)
		assert.Equal(t, tc.expectedMode, mode, "ip %s, protocol %s", tc.ip, tc.protocol)
	}

	// switch from the in-memory cache to provider lookups and re-run the
	// same expectations: results must not depend on the lookup mode
	ipList.DisableMemoryMode()

	for _, tc := range tests {
		match, mode, err := ipList.IsListed(tc.ip, tc.protocol)
		if tc.expectedErr {
			assert.Error(t, err, "ip %s, protocol %s", tc.ip, tc.protocol)
		} else {
			assert.NoError(t, err, "ip %s, protocol %s", tc.ip, tc.protocol)
		}
		assert.Equal(t, tc.expectedMatch, match, "ip %s, protocol %s", tc.ip, tc.protocol)
		assert.Equal(t, tc.expectedMode, mode, "ip %s, protocol %s", tc.ip, tc.protocol)
	}

	for _, e := range entries {
		err := dataprovider.DeleteIPListEntry(e.IPOrNet, e.Type, "", "", "")
		assert.NoError(t, err)
	}
}

// TestSQLPlaceholderLimits creates many groups and users to exercise queries
// whose parameter count grows with the number of related objects.
func TestSQLPlaceholderLimits(t *testing.T) {
	numGroups
:= 120
	numUsers := 120
	var groupMapping []sdk.GroupMapping

	folder := vfs.BaseVirtualFolder{
		Name:       "testfolder",
		MappedPath: filepath.Join(os.TempDir(), "folder"),
	}
	err := dataprovider.AddFolder(&folder, "", "", "")
	assert.NoError(t, err)

	// each group gets its own permission dir and references the shared folder
	for i := 0; i < numGroups; i++ {
		group := dataprovider.Group{
			BaseGroup: sdk.BaseGroup{
				Name: fmt.Sprintf("testgroup%d", i),
			},
			UserSettings: dataprovider.GroupUserSettings{
				BaseGroupUserSettings: sdk.BaseGroupUserSettings{
					Permissions: map[string][]string{
						fmt.Sprintf("/dir%d", i): {dataprovider.PermAny},
					},
				},
			},
		}
		group.VirtualFolders = append(group.VirtualFolders, vfs.VirtualFolder{
			BaseVirtualFolder: folder,
			VirtualPath:       "/vdir",
		})
		err := dataprovider.AddGroup(&group, "", "", "")
		assert.NoError(t, err)

		groupMapping = append(groupMapping, sdk.GroupMapping{
			Name: group.Name,
			Type: sdk.GroupTypeSecondary,
		})
	}

	// one user member of all the groups: the quota check query must still work
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "testusername",
			HomeDir:  filepath.Join(os.TempDir(), "testhome"),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		Groups: groupMapping,
	}
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)

	users, err := dataprovider.GetUsersForQuotaCheck(map[string]bool{user.Username: true})
	assert.NoError(t, err)
	if assert.Len(t, users, 1) {
		for i := 0; i < numGroups; i++ {
			_, ok := users[0].Permissions[fmt.Sprintf("/dir%d", i)]
			assert.True(t, ok)
		}
	}

	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)

	for i := 0; i < numUsers; i++ {
		user := dataprovider.User{
			BaseUser: sdk.BaseUser{
				Username: fmt.Sprintf("testusername%d", i),
				HomeDir:  filepath.Join(os.TempDir()),
				Status:   1,
				Permissions: map[string][]string{
					"/": {dataprovider.PermAny},
				},
			},
			Groups: []sdk.GroupMapping{
				{
					Name: "testgroup0",
					Type: sdk.GroupTypePrimary,
				},
			},
		}
		err := dataprovider.AddUser(&user, "", "", "")
		assert.NoError(t, err)
	}

	time.Sleep(100 * time.Millisecond)

	// deleting the folder touches all the groups and, indirectly, the users
	err = dataprovider.DeleteFolder(folder.Name, "", "", "")
	assert.NoError(t, err)

	for i := 0; i < numUsers; i++ {
		username := fmt.Sprintf("testusername%d", i)
		user, err := dataprovider.UserExists(username, "")
		assert.NoError(t, err)
		assert.Greater(t, user.UpdatedAt, user.CreatedAt)
		err = dataprovider.DeleteUser(username, "", "", "")
		assert.NoError(t, err)
	}

	for i := 0; i < numGroups; i++ {
		groupName := fmt.Sprintf("testgroup%d", i)
		err = dataprovider.DeleteGroup(groupName, "", "", "")
		assert.NoError(t, err)
	}
}

// TestALPNProtocols checks that unknown ALPN values are filtered out and that
// an empty/invalid-only list falls back to http/1.1 + h2.
func TestALPNProtocols(t *testing.T) {
	protocols := util.GetALPNProtocols(nil)
	assert.Equal(t, []string{"http/1.1", "h2"}, protocols)
	protocols = util.GetALPNProtocols([]string{"invalid1", "invalid2"})
	assert.Equal(t, []string{"http/1.1", "h2"}, protocols)
	protocols = util.GetALPNProtocols([]string{"invalid1", "h2", "invalid2"})
	assert.Equal(t, []string{"h2"}, protocols)
	protocols = util.GetALPNProtocols([]string{"h2", "http/1.1"})
	assert.Equal(t, []string{"h2", "http/1.1"}, protocols)
}

// TestServerVersion checks the advertised server version string with and
// without the "short" version config.
func TestServerVersion(t *testing.T) {
	appName := "SFTPGo"
	version.SetConfig("")
	v := version.GetServerVersion("_", false)
	assert.Equal(t, fmt.Sprintf("%s_%s", appName, version.Get().Version), v)
	v = version.GetServerVersion("-", true)
	assert.Equal(t, fmt.Sprintf("%s-%s-", appName, version.Get().Version), v)
	version.SetConfig("short")
	v = version.GetServerVersion("_", false)
	assert.Equal(t, appName, v)
	v = version.GetServerVersion("_", true)
	assert.Equal(t, appName+"_", v)
	version.SetConfig("")
}

// BenchmarkBcryptHashing measures bcrypt hashing with cost 10.
func BenchmarkBcryptHashing(b *testing.B) {
	bcryptPassword := "bcryptpassword"
	for i := 0; i < b.N; i++ {
		_, err := bcrypt.GenerateFromPassword([]byte(bcryptPassword), 10)
		if err != nil {
			panic(err)
		}
	}
}

// BenchmarkCompareBcryptPassword measures bcrypt hash verification.
func BenchmarkCompareBcryptPassword(b *testing.B) {
	bcryptPassword
:= "$2a$10$lPDdnDimJZ7d5/GwL6xDuOqoZVRXok6OHHhivCnanWUtcgN0Zafki"
	for i := 0; i < b.N; i++ {
		err := bcrypt.CompareHashAndPassword([]byte(bcryptPassword), []byte("password"))
		if err != nil {
			panic(err)
		}
	}
}

// BenchmarkArgon2Hashing measures argon2id hashing with the default params.
func BenchmarkArgon2Hashing(b *testing.B) {
	argonPassword := "argon2password"
	for i := 0; i < b.N; i++ {
		_, err := argon2id.CreateHash(argonPassword, argon2id.DefaultParams)
		if err != nil {
			panic(err)
		}
	}
}

// BenchmarkCompareArgon2Password measures argon2id hash verification.
func BenchmarkCompareArgon2Password(b *testing.B) {
	argon2Password := "$argon2id$v=19$m=65536,t=1,p=2$aOoAOdAwvzhOgi7wUFjXlw$wn/y37dBWdKHtPXHR03nNaKHWKPXyNuVXOknaU+YZ+s"
	for i := 0; i < b.N; i++ {
		_, err := argon2id.ComparePasswordAndHash("password", argon2Password)
		if err != nil {
			panic(err)
		}
	}
}

// BenchmarkAddRemoveConnections measures adding 100 connections and removing
// them concurrently from the shared Connections registry.
func BenchmarkAddRemoveConnections(b *testing.B) {
	var conns []ActiveConnection
	for i := 0; i < 100; i++ {
		conns = append(conns, &fakeConnection{
			BaseConnection: NewBaseConnection(fmt.Sprintf("id%d", i), ProtocolSFTP, "", "", dataprovider.User{
				BaseUser: sdk.BaseUser{
					Username: userTestUsername,
				},
			}),
		})
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, c := range conns {
			if err := Connections.Add(c); err != nil {
				panic(err)
			}
		}
		var wg sync.WaitGroup
		for idx := len(conns) - 1; idx >= 0; idx-- {
			wg.Add(1)
			go func(index int) {
				defer wg.Done()
				Connections.Remove(conns[index].GetID())
			}(idx)
		}
		wg.Wait()
	}
}

// BenchmarkAddRemoveSSHConnections measures sequential add/remove of 2000 SSH
// connections sharing a single in-memory pipe.
func BenchmarkAddRemoveSSHConnections(b *testing.B) {
	conn1, conn2 := net.Pipe()
	var conns []*SSHConnection
	for i := 0; i < 2000; i++ {
		conns = append(conns, NewSSHConnection(fmt.Sprintf("id%d", i), conn1))
	}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, c := range conns {
			Connections.AddSSHConnection(c)
		}
		for idx := len(conns) - 1; idx >= 0; idx-- {
			Connections.RemoveSSHConnection(conns[idx].GetID())
		}
	}
	conn1.Close()
	conn2.Close()
}
diff --git a/internal/common/connection.go b/internal/common/connection.go
new file
mode 100644
index 00000000..4bbc89f0
--- /dev/null
+++ b/internal/common/connection.go
@@ -0,0 +1,1914 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package common

import (
	"errors"
	"fmt"
	"io"
	"io/fs"
	"os"
	"path"
	"slices"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	ftpserver "github.com/fclairamb/ftpserverlib"
	"github.com/pkg/sftp"
	"github.com/sftpgo/sdk"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// BaseConnection defines common fields for a connection using any supported protocol
type BaseConnection struct {
	// last activity for this connection.
	// Since this field is accessed atomically we put it as first element of the struct to achieve 64 bit alignment
	lastActivity atomic.Int64
	uploadDone   atomic.Bool
	downloadDone atomic.Bool
	// unique ID for a transfer.
	// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
	transferID atomic.Int64
	// Unique identifier for the connection
	ID string
	// user associated with this connection if any
	User dataprovider.User
	// start time for this connection
	startTime  time.Time
	protocol   string
	remoteAddr string
	localAddr  string
	sync.RWMutex
	activeTransfers []ActiveTransfer
}

// NewBaseConnection returns a new BaseConnection
func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprovider.User) *BaseConnection {
	connID := id
	// prefix the ID with the protocol only for known protocols
	if slices.Contains(supportedProtocols, protocol) {
		connID = fmt.Sprintf("%s_%s", protocol, id)
	}
	// apply per-IP bandwidth overrides before storing the user
	user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID)
	c := &BaseConnection{
		ID:         connID,
		User:       user,
		startTime:  time.Now(),
		protocol:   protocol,
		localAddr:  localAddr,
		remoteAddr: remoteAddr,
	}
	c.transferID.Store(0)
	c.lastActivity.Store(time.Now().UnixNano())

	return c
}

// Log outputs a log entry to the configured logger
func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...any) {
	logger.Log(level, c.protocol, c.ID, format, v...)
}

// GetTransferID returns an unique transfer ID for this connection
func (c *BaseConnection) GetTransferID() int64 {
	// atomic increment: safe to call from concurrent transfers
	return c.transferID.Add(1)
}

// GetID returns the connection ID
func (c *BaseConnection) GetID() string {
	return c.ID
}

// GetUsername returns the authenticated username associated with this connection if any
func (c *BaseConnection) GetUsername() string {
	return c.User.Username
}

// GetRole returns the role for the user associated with this connection
func (c *BaseConnection) GetRole() string {
	return c.User.Role
}

// GetMaxSessions returns the maximum number of concurrent sessions allowed
func (c *BaseConnection) GetMaxSessions() int {
	return c.User.MaxSessions
}

// isAccessAllowed returns true if the user's access conditions are met
func (c *BaseConnection) isAccessAllowed() bool {
	if err := c.User.CheckLoginConditions(); err != nil {
		return false
	}
	return true
}

// GetProtocol returns the protocol for the connection
func (c *BaseConnection) GetProtocol() string {
	return c.protocol
}

// GetRemoteIP returns the remote ip address
func (c *BaseConnection) GetRemoteIP() string {
	return util.GetIPFromRemoteAddress(c.remoteAddr)
}

// SetProtocol sets the protocol for this connection
func (c *BaseConnection) SetProtocol(protocol string) {
	c.protocol = protocol
	// keep the ID prefixed with the protocol, as in NewBaseConnection
	if slices.Contains(supportedProtocols, c.protocol) {
		c.ID = fmt.Sprintf("%v_%v", c.protocol, c.ID)
	}
}

// GetConnectionTime returns the initial connection time
func (c *BaseConnection) GetConnectionTime() time.Time {
	return c.startTime
}

// UpdateLastActivity updates last activity for this connection
func (c *BaseConnection) UpdateLastActivity() {
	c.lastActivity.Store(time.Now().UnixNano())
}

// GetLastActivity returns the last connection activity
func (c *BaseConnection) GetLastActivity() time.Time {
	return time.Unix(0, c.lastActivity.Load())
}

// CloseFS closes the underlying fs
func (c
*BaseConnection) CloseFS() error {
	return c.User.CloseFs()
}

// AddTransfer associates a new transfer to this connection
func (c *BaseConnection) AddTransfer(t ActiveTransfer) {
	Connections.transfers.add(c.User.Username)

	c.Lock()
	defer c.Unlock()

	c.activeTransfers = append(c.activeTransfers, t)
	c.Log(logger.LevelDebug, "transfer added, id: %v, active transfers: %v", t.GetID(), len(c.activeTransfers))
	if t.HasSizeLimit() {
		folderName := ""
		if t.GetType() == TransferUpload {
			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(t.GetVirtualPath()))
			if err == nil {
				// track the folder only when its quota is separate from the user quota
				if !vfolder.IsIncludedInUserQuota() {
					folderName = vfolder.Name
				}
			}
		}
		// registered asynchronously, the transfers checker runs out of band
		go transfersChecker.AddTransfer(dataprovider.ActiveTransfer{
			ID:            t.GetID(),
			Type:          t.GetType(),
			ConnID:        c.ID,
			Username:      c.GetUsername(),
			FolderName:    folderName,
			IP:            c.GetRemoteIP(),
			TruncatedSize: t.GetTruncatedSize(),
			CreatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
			UpdatedAt:     util.GetTimeAsMsSinceEpoch(time.Now()),
		})
	}
}

// RemoveTransfer removes the specified transfer from the active ones
func (c *BaseConnection) RemoveTransfer(t ActiveTransfer) {
	Connections.transfers.remove(c.User.Username)

	c.Lock()
	defer c.Unlock()

	if t.HasSizeLimit() {
		go transfersChecker.RemoveTransfer(t.GetID(), c.ID)
	}

	for idx, transfer := range c.activeTransfers {
		if transfer.GetID() == t.GetID() {
			// order is not relevant: swap with the last element and shrink
			lastIdx := len(c.activeTransfers) - 1
			c.activeTransfers[idx] = c.activeTransfers[lastIdx]
			c.activeTransfers[lastIdx] = nil
			c.activeTransfers = c.activeTransfers[:lastIdx]
			c.Log(logger.LevelDebug, "transfer removed, id: %v active transfers: %v", t.GetID(), len(c.activeTransfers))
			return
		}
	}
	c.Log(logger.LevelWarn, "transfer to remove with id %v not found!", t.GetID())
}

// SignalTransferClose makes the transfer fail on the next read/write with the
// specified error
func (c *BaseConnection) SignalTransferClose(transferID int64, err error) {
c.RLock()
	defer c.RUnlock()

	for _, t := range c.activeTransfers {
		if t.GetID() == transferID {
			c.Log(logger.LevelInfo, "signal transfer close for transfer id %v", transferID)
			t.SignalClose(err)
		}
	}
}

// GetTransfers returns the active transfers
func (c *BaseConnection) GetTransfers() []ConnectionTransfer {
	c.RLock()
	defer c.RUnlock()

	transfers := make([]ConnectionTransfer, 0, len(c.activeTransfers))
	for _, t := range c.activeTransfers {
		var operationType string
		switch t.GetType() {
		case TransferDownload:
			operationType = operationDownload
		case TransferUpload:
			operationType = operationUpload
		}
		transfers = append(transfers, ConnectionTransfer{
			ID:            t.GetID(),
			OperationType: operationType,
			StartTime:     util.GetTimeAsMsSinceEpoch(t.GetStartTime()),
			Size:          t.GetSize(),
			VirtualPath:   t.GetVirtualPath(),
			HasSizeLimit:  t.HasSizeLimit(),
			ULSize:        t.GetUploadedSize(),
			DLSize:        t.GetDownloadedSize(),
		})
	}

	return transfers
}

// SignalTransfersAbort signals to the active transfers to exit as soon as possible
func (c *BaseConnection) SignalTransfersAbort() error {
	c.RLock()
	defer c.RUnlock()

	if len(c.activeTransfers) == 0 {
		return errors.New("no active transfer found")
	}

	for _, t := range c.activeTransfers {
		t.SignalClose(ErrTransferAborted)
	}
	return nil
}

// getRealFsPath returns the real fs path for an ongoing transfer matching
// fsPath, or fsPath itself if no transfer claims it.
func (c *BaseConnection) getRealFsPath(fsPath string) string {
	c.RLock()
	defer c.RUnlock()

	for _, t := range c.activeTransfers {
		if p := t.GetRealFsPath(fsPath); p != "" {
			return p
		}
	}
	return fsPath
}

// setTimes asks the active transfers to apply atime/mtime for fsPath and
// reports whether one of them handled the request.
func (c *BaseConnection) setTimes(fsPath string, atime time.Time, mtime time.Time) bool {
	c.RLock()
	defer c.RUnlock()

	for _, t := range c.activeTransfers {
		if t.SetTimes(fsPath, atime, mtime) {
			return true
		}
	}
	return false
}

// getInfoForOngoingUpload returns upload statistics for an upload currently in
// progress on this connection.
func (c *BaseConnection) getInfoForOngoingUpload(fsPath string) (os.FileInfo, error) {
	c.RLock()
	defer c.RUnlock()

	for _, t := range c.activeTransfers {
		if t.GetType() == TransferUpload && t.GetFsPath() == fsPath {
			return vfs.NewFileInfo(t.GetVirtualPath(), false, t.GetSize(), t.GetStartTime(), false), nil
		}
	}
	// no matching upload in progress
	return nil, os.ErrNotExist
}

// truncateOpenHandle truncates fsPath through the transfer that owns the open
// handle; errNoTransfer is returned when no active transfer matches.
func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) (int64, error) {
	c.RLock()
	defer c.RUnlock()

	for _, t := range c.activeTransfers {
		initialSize, err := t.Truncate(fsPath, size)
		if err != errTransferMismatch {
			return initialSize, err
		}
	}

	return 0, errNoTransfer
}

// ListDir reads the directory matching virtualPath and returns a list of directory entries
func (c *BaseConnection) ListDir(virtualPath string) (*DirListerAt, error) {
	if !c.User.HasPerm(dataprovider.PermListItems, virtualPath) {
		return nil, c.GetPermissionDeniedError()
	}
	fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return nil, err
	}
	lister, err := fs.ReadDir(fsPath)
	if err != nil {
		c.Log(logger.LevelDebug, "error listing directory: %+v", err)
		return nil, c.GetFsError(fs, err)
	}
	return &DirListerAt{
		virtualPath: virtualPath,
		conn:        c,
		fs:          fs,
		info:        c.User.GetVirtualFoldersInfo(virtualPath),
		lister:      lister,
	}, nil
}

// CheckParentDirs tries to create the specified directory and any missing parent dirs
func (c *BaseConnection) CheckParentDirs(virtualPath string) error {
	fs, err := c.User.GetFilesystemForPath(virtualPath, c.GetID())
	if err != nil {
		return err
	}
	// filesystems with virtual folders handle missing parents themselves
	if fs.HasVirtualFolders() {
		return nil
	}
	if _, err := c.DoStat(virtualPath, 0, false); !c.IsNotExistError(err) {
		return err
	}
	// walk the missing parents from the deepest ancestor down
	dirs := util.GetDirsForVirtualPath(virtualPath)
	for idx := len(dirs) - 1; idx >= 0; idx-- {
		fs, err = c.User.GetFilesystemForPath(dirs[idx], c.GetID())
		if err != nil {
			return err
		}
		if fs.HasVirtualFolders() {
			continue
		}
if err = c.createDirIfMissing(dirs[idx]); err != nil {
			return fmt.Errorf("unable to check/create missing parent dir %q for virtual path %q: %w",
				dirs[idx], virtualPath, err)
		}
	}
	return nil
}

// GetCreateChecks returns the checks for creating new files
func (c *BaseConnection) GetCreateChecks(virtualPath string, isNewFile bool, isResume bool) int {
	result := 0
	if !isNewFile {
		if isResume {
			result += vfs.CheckResume
		}
		return result
	}
	// a new file in a dir the user cannot create requires a parent dir check
	if !c.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(virtualPath)) {
		result += vfs.CheckParentDir
		return result
	}
	return result
}

// CreateDir creates a new directory at the specified fsPath
func (c *BaseConnection) CreateDir(virtualPath string, checkFilePatterns bool) error {
	if !c.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(virtualPath)) {
		return c.GetPermissionDeniedError()
	}
	if checkFilePatterns {
		if ok, _ := c.User.IsFileAllowed(virtualPath); !ok {
			return c.GetPermissionDeniedError()
		}
	}
	// virtual folder mount points cannot be shadowed by a real directory
	if c.User.IsVirtualFolder(virtualPath) {
		c.Log(logger.LevelWarn, "mkdir not allowed %q is a virtual folder", virtualPath)
		return c.GetPermissionDeniedError()
	}
	fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return err
	}
	startTime := time.Now()
	if err := fs.Mkdir(fsPath); err != nil {
		c.Log(logger.LevelError, "error creating dir: %q error: %+v", fsPath, err)
		return c.GetFsError(fs, err)
	}
	vfs.SetPathPermissions(fs, fsPath, c.User.GetUID(), c.User.GetGID())
	elapsed := time.Since(startTime).Nanoseconds() / 1000000

	logger.CommandLog(mkdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1,
		c.localAddr, c.remoteAddr, elapsed)
	ExecuteActionNotification(c, operationMkdir, fsPath, virtualPath, "", "", "", 0, nil, elapsed, nil) //nolint:errcheck
	return nil
}

// IsRemoveFileAllowed returns an error if removing this file is not allowed
func (c *BaseConnection)
// RemoveFile removes a file at the specified fsPath
func (c *BaseConnection) RemoveFile(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo) error {
	if err := c.IsRemoveFileAllowed(virtualPath); err != nil {
		return err
	}

	size := info.Size()
	status, err := ExecutePreAction(c, operationPreDelete, fsPath, virtualPath, size, 0)
	if err != nil {
		c.Log(logger.LevelDebug, "delete for file %q denied by pre action: %v", virtualPath, err)
		return c.GetPermissionDeniedError()
	}
	updateQuota := true
	startTime := time.Now()
	if err := fs.Remove(fsPath, false); err != nil {
		if status > 0 && fs.IsNotExist(err) {
			// file removed in the pre-action, if the file was deleted from the EventManager the quota is already updated
			c.Log(logger.LevelDebug, "file deleted from the hook, status: %d", status)
			updateQuota = (status == 1)
		} else {
			c.Log(logger.LevelError, "failed to remove file/symlink %q: %+v", fsPath, err)
			return c.GetFsError(fs, err)
		}
	}
	elapsed := time.Since(startTime).Nanoseconds() / 1000000

	logger.CommandLog(removeLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1,
		c.localAddr, c.remoteAddr, elapsed)
	// symlinks do not count towards quota, so skip the update for them
	if updateQuota && info.Mode()&os.ModeSymlink == 0 {
		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
		if err == nil {
			dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, -1, -size, false)
		} else {
			dataprovider.UpdateUserQuota(&c.User, -1, -size, false) //nolint:errcheck
		}
	}
	ExecuteActionNotification(c, operationDelete, fsPath, virtualPath, "", "", "", size, nil, elapsed, nil) //nolint:errcheck
	return nil
}

// IsRemoveDirAllowed returns an error if removing this directory is not allowed
func (c *BaseConnection) IsRemoveDirAllowed(fs vfs.Fs, fsPath, virtualPath string) error {
	if virtualPath == "/" || fs.GetRelativePath(fsPath) == "/" {
		c.Log(logger.LevelWarn, "removing root dir is not allowed")
		return c.GetPermissionDeniedError()
	}
	if c.User.IsVirtualFolder(virtualPath) {
		c.Log(logger.LevelWarn, "removing a virtual folder is not allowed: %q", virtualPath)
		return fmt.Errorf("removing virtual folders is not allowed: %w", c.GetPermissionDeniedError())
	}
	if c.User.HasVirtualFoldersInside(virtualPath) {
		c.Log(logger.LevelWarn, "removing a directory with a virtual folder inside is not allowed: %q", virtualPath)
		return fmt.Errorf("cannot remove directory %q with virtual folders inside: %w", virtualPath, c.GetOpUnsupportedError())
	}
	if c.User.IsMappedPath(fsPath) {
		c.Log(logger.LevelWarn, "removing a directory mapped as virtual folder is not allowed: %q", fsPath)
		return fmt.Errorf("removing the directory %q mapped as virtual folder is not allowed: %w",
			virtualPath, c.GetPermissionDeniedError())
	}
	if !c.User.HasAnyPerm([]string{dataprovider.PermDeleteDirs, dataprovider.PermDelete}, path.Dir(virtualPath)) {
		return c.GetPermissionDeniedError()
	}
	if ok, policy := c.User.IsFileAllowed(virtualPath); !ok {
		c.Log(logger.LevelDebug, "removing directory %q is not allowed", virtualPath)
		return c.GetErrorForDeniedFile(policy)
	}
	return nil
}

// RemoveDir removes a directory at the specified fsPath
func (c *BaseConnection) RemoveDir(virtualPath string) error {
	fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return err
	}
	if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil {
		return err
	}

	var fi os.FileInfo
	if fi, err = fs.Lstat(fsPath); err != nil {
		// see #149
		if fs.IsNotExist(err) && fs.HasVirtualFolders() {
			return nil
		}
		c.Log(logger.LevelError, "failed to remove a dir %q: stat error: %+v", fsPath, err)
		return c.GetFsError(fs, err)
	}
	if !fi.IsDir() || fi.Mode()&os.ModeSymlink != 0 {
		c.Log(logger.LevelError, "cannot remove %q is not a directory", fsPath)
		return c.GetGenericError(nil)
	}

	startTime := time.Now()
	if err := fs.Remove(fsPath, true); err != nil {
		c.Log(logger.LevelError, "failed to remove directory %q: %+v", fsPath, err)
		return c.GetFsError(fs, err)
	}
	elapsed := time.Since(startTime).Nanoseconds() / 1000000

	logger.CommandLog(rmdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1,
		c.localAddr, c.remoteAddr, elapsed)
	ExecuteActionNotification(c, operationRmdir, fsPath, virtualPath, "", "", "", 0, nil, elapsed, nil) //nolint:errcheck
	return nil
}
nil + } + c.Log(logger.LevelError, "failed to remove a dir %q: stat error: %+v", fsPath, err) + return c.GetFsError(fs, err) + } + if !fi.IsDir() || fi.Mode()&os.ModeSymlink != 0 { + c.Log(logger.LevelError, "cannot remove %q is not a directory", fsPath) + return c.GetGenericError(nil) + } + + startTime := time.Now() + if err := fs.Remove(fsPath, true); err != nil { + c.Log(logger.LevelError, "failed to remove directory %q: %+v", fsPath, err) + return c.GetFsError(fs, err) + } + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + + logger.CommandLog(rmdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1, + c.localAddr, c.remoteAddr, elapsed) + ExecuteActionNotification(c, operationRmdir, fsPath, virtualPath, "", "", "", 0, nil, elapsed, nil) //nolint:errcheck + return nil +} + +func (c *BaseConnection) doRecursiveRemoveDirEntry(virtualPath string, info os.FileInfo, recursion int) error { + fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) + if err != nil { + return err + } + return c.doRecursiveRemove(fs, fsPath, virtualPath, info, recursion) +} + +func (c *BaseConnection) doRecursiveRemove(fs vfs.Fs, fsPath, virtualPath string, info os.FileInfo, recursion int) error { + if info.IsDir() { + if recursion >= util.MaxRecursion { + c.Log(logger.LevelError, "recursive rename failed, recursion too depth: %d", recursion) + return util.ErrRecursionTooDeep + } + recursion++ + lister, err := c.ListDir(virtualPath) + if err != nil { + return fmt.Errorf("unable to get lister for dir %q: %w", virtualPath, err) + } + defer lister.Close() + + for { + entries, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + return fmt.Errorf("unable to get content for dir %q: %w", virtualPath, err) + } + for _, fi := range entries { + targetPath := path.Join(virtualPath, fi.Name()) + if err := c.doRecursiveRemoveDirEntry(targetPath, fi, recursion); err != nil { + return err + } + } + if 
finished { + lister.Close() + break + } + } + return c.RemoveDir(virtualPath) + } + return c.RemoveFile(fs, fsPath, virtualPath, info) +} + +// RemoveAll removes the specified path and any children it contains +func (c *BaseConnection) RemoveAll(virtualPath string) error { + fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) + if err != nil { + return err + } + + fi, err := fs.Lstat(fsPath) + if err != nil { + c.Log(logger.LevelDebug, "failed to remove path %q: stat error: %+v", fsPath, err) + return c.GetFsError(fs, err) + } + if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 { + if err := c.IsRemoveDirAllowed(fs, fsPath, virtualPath); err != nil { + return err + } + return c.doRecursiveRemove(fs, fsPath, virtualPath, fi, 0) + } + return c.RemoveFile(fs, fsPath, virtualPath, fi) +} + +func (c *BaseConnection) checkCopy(srcInfo, dstInfo os.FileInfo, virtualSource, virtualTarget string) error { + _, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSource) + if err != nil { + return err + } + _, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTarget) + if err != nil { + return err + } + if srcInfo.IsDir() { + if dstInfo != nil && !dstInfo.IsDir() { + return fmt.Errorf("cannot overwrite file %q with dir %q: %w", virtualTarget, virtualSource, c.GetOpUnsupportedError()) + } + if util.IsDirOverlapped(virtualSource, virtualTarget, true, "/") { + return fmt.Errorf("nested copy %q => %q is not supported: %w", virtualSource, virtualTarget, c.GetOpUnsupportedError()) + } + if util.IsDirOverlapped(fsSourcePath, fsTargetPath, true, c.User.FsConfig.GetPathSeparator()) { + c.Log(logger.LevelWarn, "nested fs copy %q => %q not allowed", fsSourcePath, fsTargetPath) + return fmt.Errorf("nested fs copy is not supported: %w", c.GetOpUnsupportedError()) + } + return nil + } + if dstInfo != nil && dstInfo.IsDir() { + return fmt.Errorf("cannot overwrite file %q with dir %q: %w", virtualSource, virtualTarget, c.GetOpUnsupportedError()) + } + if c.IsSameResource(virtualSource, 
virtualTarget) { + if fsSourcePath == fsTargetPath { + return fmt.Errorf("the copy source and target cannot be the same: %w", c.GetOpUnsupportedError()) + } + } + return nil +} + +func (c *BaseConnection) copyFile(virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo) error { + if !c.User.HasPerm(dataprovider.PermCopy, virtualSourcePath) || !c.User.HasPerm(dataprovider.PermCopy, virtualTargetPath) { + return c.GetPermissionDeniedError() + } + if ok, _ := c.User.IsFileAllowed(virtualTargetPath); !ok { + return fmt.Errorf("file %q is not allowed: %w", virtualTargetPath, c.GetPermissionDeniedError()) + } + if c.IsSameResource(virtualSourcePath, virtualTargetPath) { + fs, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTargetPath) + if err != nil { + return err + } + if copier, ok := fs.(vfs.FsFileCopier); ok { + _, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath) + if err != nil { + return err + } + startTime := time.Now() + numFiles, sizeDiff, err := copier.CopyFile(fsSourcePath, fsTargetPath, srcInfo) + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + updateUserQuotaAfterFileWrite(c, virtualTargetPath, numFiles, sizeDiff) + logger.CommandLog(copyLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1, + "", "", "", srcInfo.Size(), c.localAddr, c.remoteAddr, elapsed) + ExecuteActionNotification(c, operationCopy, fsSourcePath, virtualSourcePath, fsTargetPath, virtualTargetPath, "", srcInfo.Size(), err, elapsed, nil) //nolint:errcheck + return err + } + } + + reader, rCancelFn, err := getFileReader(c, virtualSourcePath) + if err != nil { + return fmt.Errorf("unable to get reader for path %q: %w", virtualSourcePath, err) + } + defer rCancelFn() + defer reader.Close() + + writer, numFiles, truncatedSize, wCancelFn, err := getFileWriter(c, virtualTargetPath, srcInfo.Size()) + if err != nil { + return fmt.Errorf("unable to get writer for path %q: %w", virtualTargetPath, err) + } + defer wCancelFn() + + 
startTime := time.Now() + _, err = io.Copy(writer, reader) + return closeWriterAndUpdateQuota(writer, c, virtualSourcePath, virtualTargetPath, numFiles, truncatedSize, + err, operationCopy, startTime) +} + +func (c *BaseConnection) doRecursiveCopy(virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo, + createTargetDir bool, recursion int, +) error { + if srcInfo.IsDir() { + if recursion >= util.MaxRecursion { + c.Log(logger.LevelError, "recursive copy failed, recursion too depth: %d", recursion) + return util.ErrRecursionTooDeep + } + recursion++ + if createTargetDir { + if err := c.CreateDir(virtualTargetPath, false); err != nil { + return fmt.Errorf("unable to create directory %q: %w", virtualTargetPath, err) + } + } + lister, err := c.ListDir(virtualSourcePath) + if err != nil { + return fmt.Errorf("unable to get lister for dir %q: %w", virtualSourcePath, err) + } + defer lister.Close() + + for { + entries, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + return fmt.Errorf("unable to get contents for dir %q: %w", virtualSourcePath, err) + } + if err := c.recursiveCopyEntries(virtualSourcePath, virtualTargetPath, entries, recursion); err != nil { + return err + } + if finished { + return nil + } + } + } + if !srcInfo.Mode().IsRegular() { + c.Log(logger.LevelInfo, "skipping copy for non regular file %q", virtualSourcePath) + return nil + } + + return c.copyFile(virtualSourcePath, virtualTargetPath, srcInfo) +} + +func (c *BaseConnection) recursiveCopyEntries(virtualSourcePath, virtualTargetPath string, entries []os.FileInfo, recursion int) error { + for _, info := range entries { + sourcePath := path.Join(virtualSourcePath, info.Name()) + targetPath := path.Join(virtualTargetPath, info.Name()) + targetInfo, err := c.DoStat(targetPath, 1, false) + if err == nil { + if info.IsDir() && targetInfo.IsDir() { + c.Log(logger.LevelDebug, "target copy dir %q already exists", targetPath) + continue + } 
+ } + if err != nil && !c.IsNotExistError(err) { + return err + } + if err := c.checkCopy(info, targetInfo, sourcePath, targetPath); err != nil { + return err + } + if err := c.doRecursiveCopy(sourcePath, targetPath, info, true, recursion); err != nil { + if c.IsNotExistError(err) { + c.Log(logger.LevelInfo, "skipping copy for source path %q: %v", sourcePath, err) + continue + } + return err + } + } + return nil +} + +// Copy virtualSourcePath to virtualTargetPath +func (c *BaseConnection) Copy(virtualSourcePath, virtualTargetPath string) error { + copyFromSource := strings.HasSuffix(virtualSourcePath, "/") + copyInTarget := strings.HasSuffix(virtualTargetPath, "/") + virtualSourcePath = path.Clean(virtualSourcePath) + virtualTargetPath = path.Clean(virtualTargetPath) + if virtualSourcePath == virtualTargetPath { + return fmt.Errorf("the copy source and target cannot be the same: %w", c.GetOpUnsupportedError()) + } + srcInfo, err := c.DoStat(virtualSourcePath, 1, false) + if err != nil { + return err + } + if srcInfo.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("copying symlinks is not supported: %w", c.GetOpUnsupportedError()) + } + dstInfo, err := c.DoStat(virtualTargetPath, 1, false) + if err == nil && !copyFromSource { + copyInTarget = dstInfo.IsDir() + } + if err != nil && !c.IsNotExistError(err) { + return err + } + destPath := virtualTargetPath + if copyInTarget { + destPath = path.Join(virtualTargetPath, path.Base(virtualSourcePath)) + dstInfo, err = c.DoStat(destPath, 1, false) + if err != nil && !c.IsNotExistError(err) { + return err + } + } + createTargetDir := dstInfo == nil || !dstInfo.IsDir() + if err := c.checkCopy(srcInfo, dstInfo, virtualSourcePath, destPath); err != nil { + return err + } + if err := c.CheckParentDirs(path.Dir(destPath)); err != nil { + return err + } + stopKeepAlive := keepConnectionAlive(c, 2*time.Minute) + defer stopKeepAlive() + + return c.doRecursiveCopy(virtualSourcePath, destPath, srcInfo, createTargetDir, 0) +} + +// 
Rename renames (moves) virtualSourcePath to virtualTargetPath +func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) error { + return c.renameInternal(virtualSourcePath, virtualTargetPath, false, vfs.CheckParentDir) +} + +func (c *BaseConnection) renameInternal(virtualSourcePath, virtualTargetPath string, //nolint:gocyclo + checkParentDestination bool, checks int, +) error { + if virtualSourcePath == virtualTargetPath { + return fmt.Errorf("the rename source and target cannot be the same: %w", c.GetOpUnsupportedError()) + } + fsSrc, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath) + if err != nil { + return err + } + fsDst, fsTargetPath, err := c.GetFsAndResolvedPath(virtualTargetPath) + if err != nil { + return err + } + startTime := time.Now() + srcInfo, err := fsSrc.Lstat(fsSourcePath) + if err != nil { + return c.GetFsError(fsSrc, err) + } + if !c.isRenamePermitted(fsSrc, fsDst, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath, srcInfo) { + return c.GetPermissionDeniedError() + } + initialSize := int64(-1) + dstInfo, err := fsDst.Lstat(fsTargetPath) + if err != nil && !fsDst.IsNotExist(err) { + return err + } + if err == nil { + checkParentDestination = false + if dstInfo.IsDir() { + c.Log(logger.LevelWarn, "attempted to rename %q overwriting an existing directory %q", + fsSourcePath, fsTargetPath) + return c.GetOpUnsupportedError() + } + // we are overwriting an existing file/symlink + if dstInfo.Mode().IsRegular() { + initialSize = dstInfo.Size() + } + if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(virtualTargetPath)) { + c.Log(logger.LevelDebug, "renaming %q -> %q is not allowed. 
Target exists but the user %q"+ + "has no overwrite permission", virtualSourcePath, virtualTargetPath, c.User.Username) + return c.GetPermissionDeniedError() + } + } + if srcInfo.IsDir() { + if err := c.checkFolderRename(fsSrc, fsDst, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath, srcInfo); err != nil { + return err + } + } + if !c.hasSpaceForRename(fsSrc, virtualSourcePath, virtualTargetPath, initialSize, fsSourcePath, srcInfo) { + c.Log(logger.LevelInfo, "denying cross rename due to space limit") + return c.GetGenericError(ErrQuotaExceeded) + } + if checkParentDestination { + c.CheckParentDirs(path.Dir(virtualTargetPath)) //nolint:errcheck + } + stopKeepAlive := keepConnectionAlive(c, 2*time.Minute) + defer stopKeepAlive() + + files, size, err := fsDst.Rename(fsSourcePath, fsTargetPath, checks) + if err != nil { + c.Log(logger.LevelError, "failed to rename %q -> %q: %+v", fsSourcePath, fsTargetPath, err) + return c.GetFsError(fsSrc, err) + } + vfs.SetPathPermissions(fsDst, fsTargetPath, c.User.GetUID(), c.User.GetGID()) + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + c.updateQuotaAfterRename(fsDst, virtualSourcePath, virtualTargetPath, fsTargetPath, initialSize, files, size) //nolint:errcheck + logger.CommandLog(renameLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1, + "", "", "", -1, c.localAddr, c.remoteAddr, elapsed) + ExecuteActionNotification(c, operationRename, fsSourcePath, virtualSourcePath, fsTargetPath, //nolint:errcheck + virtualTargetPath, "", 0, nil, elapsed, nil) + + return nil +} + +// CreateSymlink creates fsTargetPath as a symbolic link to fsSourcePath +func (c *BaseConnection) CreateSymlink(virtualSourcePath, virtualTargetPath string) error { + var relativePath string + if !path.IsAbs(virtualSourcePath) { + relativePath = virtualSourcePath + virtualSourcePath = path.Join(path.Dir(virtualTargetPath), relativePath) + c.Log(logger.LevelDebug, "link relative path %q resolved as %q, 
target path %q", + relativePath, virtualSourcePath, virtualTargetPath) + } + if c.isCrossFoldersRequest(virtualSourcePath, virtualTargetPath) { + c.Log(logger.LevelWarn, "cross folder symlink is not supported, src: %v dst: %v", virtualSourcePath, virtualTargetPath) + return c.GetOpUnsupportedError() + } + // we cannot have a cross folder request here so only one fs is enough + fs, fsSourcePath, err := c.GetFsAndResolvedPath(virtualSourcePath) + if err != nil { + return err + } + fsTargetPath, err := fs.ResolvePath(virtualTargetPath) + if err != nil { + return c.GetFsError(fs, err) + } + if fs.GetRelativePath(fsSourcePath) == "/" { + c.Log(logger.LevelError, "symlinking root dir is not allowed") + return c.GetPermissionDeniedError() + } + if fs.GetRelativePath(fsTargetPath) == "/" { + c.Log(logger.LevelError, "symlinking to root dir is not allowed") + return c.GetPermissionDeniedError() + } + if !c.User.HasPerm(dataprovider.PermCreateSymlinks, path.Dir(virtualTargetPath)) { + return c.GetPermissionDeniedError() + } + ok, policy := c.User.IsFileAllowed(virtualSourcePath) + if !ok && policy == sdk.DenyPolicyHide { + c.Log(logger.LevelError, "symlink source path %q is not allowed", virtualSourcePath) + return c.GetNotExistError() + } + if ok, _ = c.User.IsFileAllowed(virtualTargetPath); !ok { + c.Log(logger.LevelError, "symlink target path %q is not allowed", virtualTargetPath) + return c.GetPermissionDeniedError() + } + if relativePath != "" { + fsSourcePath = relativePath + } + startTime := time.Now() + if err := fs.Symlink(fsSourcePath, fsTargetPath); err != nil { + c.Log(logger.LevelError, "failed to create symlink %q -> %q: %+v", fsSourcePath, fsTargetPath, err) + return c.GetFsError(fs, err) + } + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + logger.CommandLog(symlinkLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1, "", + "", "", -1, c.localAddr, c.remoteAddr, elapsed) + return nil +} + +func (c *BaseConnection) 
// doStatInternal stats virtualPath. mode 1 uses Lstat, any other mode uses
// Stat. checkFilePatterns applies the user's file pattern filters;
// convertResult converts crypt-fs metadata to the plaintext view.
func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFilePatterns,
	convertResult bool,
) (os.FileInfo, error) {
	// for some vfs we don't create intermediary folders so we cannot simply check
	// if virtualPath is a virtual folder. Allowing stat for hidden virtual folders
	// is intentional.
	vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath))
	if _, ok := vfolders[virtualPath]; ok {
		return vfs.NewFileInfo(virtualPath, true, 0, time.Unix(0, 0), false), nil
	}
	if checkFilePatterns && virtualPath != "/" {
		ok, policy := c.User.IsFileAllowed(virtualPath)
		if !ok && policy == sdk.DenyPolicyHide {
			return nil, c.GetNotExistError()
		}
	}

	var info os.FileInfo

	fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return nil, err
	}

	if mode == 1 {
		info, err = fs.Lstat(c.getRealFsPath(fsPath))
	} else {
		info, err = fs.Stat(c.getRealFsPath(fsPath))
	}
	if err != nil {
		isNotExist := fs.IsNotExist(err)
		if isNotExist {
			// This is primarily useful for atomic storage backends, where files
			// become visible only after they are closed. However, since we may
			// be proxying (for example) an SFTP server backed by atomic
			// storage, and this search only inspects transfers active on the
			// current connection (typically just one), the check is inexpensive
			// and safe to perform unconditionally.
			if info, err := c.getInfoForOngoingUpload(fsPath); err == nil {
				return info, nil
			}
		}
		if !isNotExist {
			c.Log(logger.LevelWarn, "stat error for path %q: %+v", virtualPath, err)
		}
		return nil, c.GetFsError(fs, err)
	}
	if convertResult && vfs.IsCryptOsFs(fs) {
		info = fs.(*vfs.CryptFs).ConvertFileInfo(info)
	}
	return info, nil
}

// DoStat execute a Stat if mode = 0, Lstat if mode = 1
func (c *BaseConnection) DoStat(virtualPath string, mode int, checkFilePatterns bool) (os.FileInfo, error) {
	return c.doStatInternal(virtualPath, mode, checkFilePatterns, true)
}

// createDirIfMissing creates the named directory if it does not already exist.
func (c *BaseConnection) createDirIfMissing(name string) error {
	_, err := c.DoStat(name, 0, false)
	if c.IsNotExistError(err) {
		return c.CreateDir(name, false)
	}
	return err
}

// ignoreSetStat reports whether setstat requests must be silently ignored,
// based on the configured SetstatMode and the filesystem type.
func (c *BaseConnection) ignoreSetStat(fs vfs.Fs) bool {
	if Config.SetstatMode == 1 {
		return true
	}
	if Config.SetstatMode == 2 && !vfs.IsLocalOrSFTPFs(fs) && !vfs.IsCryptOsFs(fs) {
		return true
	}
	return false
}

// handleChmod applies the requested mode change after checking the chmod
// permission and the setstat mode.
func (c *BaseConnection) handleChmod(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error {
	if !c.User.HasPerm(dataprovider.PermChmod, pathForPerms) {
		return c.GetPermissionDeniedError()
	}
	if c.ignoreSetStat(fs) {
		return nil
	}
	startTime := time.Now()
	if err := fs.Chmod(c.getRealFsPath(fsPath), attributes.Mode); err != nil {
		c.Log(logger.LevelError, "failed to chmod path %q, mode: %v, err: %+v", fsPath, attributes.Mode.String(), err)
		return c.GetFsError(fs, err)
	}
	elapsed := time.Since(startTime).Nanoseconds() / 1000000
	logger.CommandLog(chmodLogSender, fsPath, "", c.User.Username, attributes.Mode.String(), c.ID, c.protocol,
		-1, -1, "", "", "", -1, c.localAddr, c.remoteAddr, elapsed)
	return nil
}

// handleChown applies the requested ownership change after checking the chown
// permission and the setstat mode.
func (c *BaseConnection) handleChown(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error {
	if !c.User.HasPerm(dataprovider.PermChown, pathForPerms) {
		return c.GetPermissionDeniedError()
	}
	if c.ignoreSetStat(fs) {
		return nil
	}
	startTime := time.Now()
	if err := fs.Chown(c.getRealFsPath(fsPath), attributes.UID, attributes.GID); err != nil {
		c.Log(logger.LevelError, "failed to chown path %q, uid: %v, gid: %v, err: %+v", fsPath, attributes.UID,
			attributes.GID, err)
		return c.GetFsError(fs, err)
	}
	elapsed := time.Since(startTime).Nanoseconds() / 1000000
	logger.CommandLog(chownLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, attributes.UID, attributes.GID,
		"", "", "", -1, c.localAddr, c.remoteAddr, elapsed)
	return nil
}
c.ignoreSetStat(fs) { + return nil + } + startTime := time.Now() + if err := fs.Chown(c.getRealFsPath(fsPath), attributes.UID, attributes.GID); err != nil { + c.Log(logger.LevelError, "failed to chown path %q, uid: %v, gid: %v, err: %+v", fsPath, attributes.UID, + attributes.GID, err) + return c.GetFsError(fs, err) + } + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + logger.CommandLog(chownLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, attributes.UID, attributes.GID, + "", "", "", -1, c.localAddr, c.remoteAddr, elapsed) + return nil +} + +func (c *BaseConnection) handleChtimes(fs vfs.Fs, fsPath, pathForPerms string, attributes *StatAttributes) error { + if !c.User.HasPerm(dataprovider.PermChtimes, pathForPerms) { + return c.GetPermissionDeniedError() + } + if Config.SetstatMode == 1 { + return nil + } + startTime := time.Now() + isUploading := c.setTimes(fsPath, attributes.Atime, attributes.Mtime) + if err := fs.Chtimes(c.getRealFsPath(fsPath), attributes.Atime, attributes.Mtime, isUploading); err != nil { + c.setTimes(fsPath, time.Time{}, time.Time{}) + if errors.Is(err, vfs.ErrVfsUnsupported) && Config.SetstatMode == 2 { + return nil + } + c.Log(logger.LevelError, "failed to chtimes for path %q, access time: %v, modification time: %v, err: %+v", + fsPath, attributes.Atime, attributes.Mtime, err) + return c.GetFsError(fs, err) + } + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + accessTimeString := attributes.Atime.Format(chtimesFormat) + modificationTimeString := attributes.Mtime.Format(chtimesFormat) + logger.CommandLog(chtimesLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, + accessTimeString, modificationTimeString, "", -1, c.localAddr, c.remoteAddr, elapsed) + return nil +} + +// SetStat set StatAttributes for the specified fsPath +func (c *BaseConnection) SetStat(virtualPath string, attributes *StatAttributes) error { + if ok, policy := c.User.IsFileAllowed(virtualPath); !ok { + return 
c.GetErrorForDeniedFile(policy) + } + fs, fsPath, err := c.GetFsAndResolvedPath(virtualPath) + if err != nil { + return err + } + pathForPerms := path.Dir(virtualPath) + + if attributes.Flags&StatAttrTimes != 0 { + if err = c.handleChtimes(fs, fsPath, pathForPerms, attributes); err != nil { + return err + } + } + + if attributes.Flags&StatAttrPerms != 0 { + if err = c.handleChmod(fs, fsPath, pathForPerms, attributes); err != nil { + return err + } + } + + if attributes.Flags&StatAttrUIDGID != 0 { + if err = c.handleChown(fs, fsPath, pathForPerms, attributes); err != nil { + return err + } + } + + if attributes.Flags&StatAttrSize != 0 { + if !c.User.HasPerm(dataprovider.PermOverwrite, pathForPerms) { + return c.GetPermissionDeniedError() + } + startTime := time.Now() + if err = c.truncateFile(fs, fsPath, virtualPath, attributes.Size); err != nil { + c.Log(logger.LevelError, "failed to truncate path %q, size: %v, err: %+v", fsPath, attributes.Size, err) + return c.GetFsError(fs, err) + } + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + logger.CommandLog(truncateLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", + "", attributes.Size, c.localAddr, c.remoteAddr, elapsed) + } + + return nil +} + +func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, size int64) error { + // check first if we have an open transfer for the given path and try to truncate the file already opened + // if we found no transfer we truncate by path. 
// checkRecursiveRenameDirPermissions verifies that renaming the directory tree
// rooted at sourcePath to targetPath is allowed. If the user has no per-subdir
// permission overrides and full rename permissions, the check is O(1);
// otherwise the source tree is walked and each entry is checked.
func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs, sourcePath, targetPath,
	virtualSourcePath, virtualTargetPath string, srcInfo os.FileInfo,
) error {
	if !c.User.HasPermissionsInside(virtualSourcePath) &&
		!c.User.HasPermissionsInside(virtualTargetPath) {
		if !c.isRenamePermitted(fsSrc, fsDst, sourcePath, targetPath, virtualSourcePath, virtualTargetPath, srcInfo) {
			c.Log(logger.LevelInfo, "rename %q -> %q is not allowed, virtual destination path: %q",
				sourcePath, targetPath, virtualTargetPath)
			return c.GetPermissionDeniedError()
		}
		// if all rename permissions are granted we have finished, otherwise we have to walk
		// because we could have the rename dir permission but not the rename file and the dir to
		// rename could contain files
		if c.User.HasPermsRenameAll(path.Dir(virtualSourcePath)) && c.User.HasPermsRenameAll(path.Dir(virtualTargetPath)) {
			return nil
		}
	}

	return fsSrc.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error {
		if err != nil {
			return c.GetFsError(fsSrc, err)
		}
		// non-atomic backends in rename mode 0 cannot rename non-empty dirs
		if walkedPath != sourcePath && !vfs.IsRenameAtomic(fsSrc) && Config.RenameMode == 0 {
			c.Log(logger.LevelInfo, "cannot rename non empty directory %q on this filesystem", virtualSourcePath)
			return c.GetOpUnsupportedError()
		}
		dstPath := strings.Replace(walkedPath, sourcePath, targetPath, 1)
		virtualSrcPath := fsSrc.GetRelativePath(walkedPath)
		virtualDstPath := fsDst.GetRelativePath(dstPath)
		if !c.isRenamePermitted(fsSrc, fsDst, walkedPath, dstPath, virtualSrcPath, virtualDstPath, info) {
			c.Log(logger.LevelInfo, "rename %q -> %q is not allowed, virtual destination path: %q",
				walkedPath, dstPath, virtualDstPath)
			return c.GetPermissionDeniedError()
		}
		return nil
	})
}

// hasRenamePerms reports whether the user has the permissions required to
// rename fi from virtualSourcePath to virtualTargetPath. fi may be nil when
// the source type is unknown; in that case full rename perms are required.
func (c *BaseConnection) hasRenamePerms(virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool {
	if c.User.HasPermsRenameAll(path.Dir(virtualSourcePath)) &&
		c.User.HasPermsRenameAll(path.Dir(virtualTargetPath)) {
		return true
	}
	if fi == nil {
		// we don't know if this is a file or a directory and we don't have all the rename perms, return false
		return false
	}
	if fi.IsDir() {
		perms := []string{
			dataprovider.PermRenameDirs,
			dataprovider.PermRename,
		}
		return c.User.HasAnyPerm(perms, path.Dir(virtualSourcePath)) &&
			c.User.HasAnyPerm(perms, path.Dir(virtualTargetPath))
	}
	// file or symlink
	perms := []string{
		dataprovider.PermRenameFiles,
		dataprovider.PermRename,
	}
	return c.User.HasAnyPerm(perms, path.Dir(virtualSourcePath)) &&
		c.User.HasAnyPerm(perms, path.Dir(virtualTargetPath))
}

// checkFolderRename validates a directory rename: no overlapping (nested)
// paths, no virtual folders inside either tree, and recursive permission
// checks on the source tree.
func (c *BaseConnection) checkFolderRename(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
	virtualTargetPath string, srcInfo os.FileInfo) error {
	if util.IsDirOverlapped(virtualSourcePath, virtualTargetPath, true, "/") {
		c.Log(logger.LevelDebug, "renaming the folder %q->%q is not supported: nested folders",
			virtualSourcePath, virtualTargetPath)
		return fmt.Errorf("nested rename %q => %q is not supported: %w",
			virtualSourcePath, virtualTargetPath, c.GetOpUnsupportedError())
	}
	if util.IsDirOverlapped(fsSourcePath, fsTargetPath, true, c.User.FsConfig.GetPathSeparator()) {
		c.Log(logger.LevelDebug, "renaming the folder %q->%q is not supported: nested fs folders",
			fsSourcePath, fsTargetPath)
		return fmt.Errorf("nested fs rename %q => %q is not supported: %w",
			fsSourcePath, fsTargetPath, c.GetOpUnsupportedError())
	}
	if c.User.HasVirtualFoldersInside(virtualSourcePath) {
		c.Log(logger.LevelDebug, "renaming the folder %q is not supported: it has virtual folders inside it",
			virtualSourcePath)
		return fmt.Errorf("folder %q has virtual folders inside it: %w", virtualSourcePath, c.GetOpUnsupportedError())
	}
	if c.User.HasVirtualFoldersInside(virtualTargetPath) {
		c.Log(logger.LevelDebug, "renaming the folder %q is not supported, the target %q has virtual folders inside it",
			virtualSourcePath, virtualTargetPath)
		return fmt.Errorf("folder %q has virtual folders inside it: %w", virtualTargetPath, c.GetOpUnsupportedError())
	}
	if err := c.checkRecursiveRenameDirPermissions(fsSrc, fsDst, fsSourcePath, fsTargetPath,
		virtualSourcePath, virtualTargetPath, srcInfo); err != nil {
		c.Log(logger.LevelDebug, "error checking recursive permissions before renaming %q: %+v", fsSourcePath, err)
		return err
	}
	return nil
}
// isRenamePermitted reports whether renaming virtualSourcePath to
// virtualTargetPath is allowed: same resource, not the root dir, not a
// virtual folder or mapped path, file patterns satisfied and rename
// permissions granted.
func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
	virtualTargetPath string, srcInfo os.FileInfo,
) bool {
	if !c.IsSameResource(virtualSourcePath, virtualTargetPath) {
		c.Log(logger.LevelInfo, "rename %q->%q is not allowed: the paths must be on the same resource",
			virtualSourcePath, virtualTargetPath)
		return false
	}
	if c.User.IsMappedPath(fsSourcePath) && vfs.IsLocalOrCryptoFs(fsSrc) {
		c.Log(logger.LevelWarn, "renaming a directory mapped as virtual folder is not allowed: %q", fsSourcePath)
		return false
	}
	if c.User.IsMappedPath(fsTargetPath) && vfs.IsLocalOrCryptoFs(fsDst) {
		c.Log(logger.LevelWarn, "renaming to a directory mapped as virtual folder is not allowed: %q", fsTargetPath)
		return false
	}
	if virtualSourcePath == "/" || virtualTargetPath == "/" || fsSrc.GetRelativePath(fsSourcePath) == "/" {
		c.Log(logger.LevelWarn, "renaming root dir is not allowed")
		return false
	}
	if c.User.IsVirtualFolder(virtualSourcePath) || c.User.IsVirtualFolder(virtualTargetPath) {
		c.Log(logger.LevelWarn, "renaming a virtual folder is not allowed")
		return false
	}
	isSrcAllowed, _ := c.User.IsFileAllowed(virtualSourcePath)
	isDstAllowed, _ := c.User.IsFileAllowed(virtualTargetPath)
	if !isSrcAllowed || !isDstAllowed {
		c.Log(logger.LevelDebug, "renaming source: %q to target: %q not allowed", virtualSourcePath,
			virtualTargetPath)
		return false
	}
	return c.hasRenamePerms(virtualSourcePath, virtualTargetPath, srcInfo)
}

// hasSpaceForRename reports whether quota limits allow the rename. Renames
// within the same quota scope (same folder, or folders included in the user
// quota) always succeed; cross-scope renames are checked against the target
// quota.
func (c *BaseConnection) hasSpaceForRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath string, initialSize int64,
	sourcePath string, srcInfo os.FileInfo) bool {
	if dataprovider.GetQuotaTracking() == 0 {
		return true
	}
	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(path.Dir(virtualSourcePath))
	dstFolder, errDst := c.User.GetVirtualFolderForPath(path.Dir(virtualTargetPath))
	if errSrc != nil && errDst != nil {
		// rename inside the user home dir
		return true
	}
	if errSrc == nil && errDst == nil {
		// rename between virtual folders
		if sourceFolder.Name == dstFolder.Name {
			// rename inside the same virtual folder
			return true
		}
	}
	if errSrc != nil && dstFolder.IsIncludedInUserQuota() {
		// rename between user root dir and a virtual folder included in user quota
		return true
	}
	if errDst != nil && sourceFolder.IsIncludedInUserQuota() {
		// rename between a virtual folder included in user quota and the user root dir
		return true
	}
	quotaResult, _ := c.HasSpace(true, false, virtualTargetPath)
	if quotaResult.HasSpace && quotaResult.QuotaSize == 0 && quotaResult.QuotaFiles == 0 {
		// no quota restrictions
		return true
	}
	return c.hasSpaceForCrossRename(fs, quotaResult, initialSize, sourcePath, srcInfo)
}
// hasSpaceForCrossRename checks the quota after a rename between different folders.
// initialSize is the size of an overwritten target file, or -1 if the target
// does not exist (this is never a directory: directories cannot be overwritten).
func (c *BaseConnection) hasSpaceForCrossRename(fs vfs.Fs, quotaResult vfs.QuotaCheckResult, initialSize int64,
	sourcePath string, srcInfo os.FileInfo,
) bool {
	if !quotaResult.HasSpace && initialSize == -1 {
		// we are over quota and this is not a file replace
		return false
	}
	// Compute the number of files and bytes the rename adds to the target area.
	var sizeDiff int64
	var filesDiff int
	var err error
	if srcInfo.Mode().IsRegular() {
		sizeDiff = srcInfo.Size()
		filesDiff = 1
		if initialSize != -1 {
			// replacing an existing file: only the size delta counts
			sizeDiff -= initialSize
			filesDiff = 0
		}
	} else if srcInfo.IsDir() {
		filesDiff, sizeDiff, err = fs.GetDirSize(sourcePath)
		if err != nil {
			c.Log(logger.LevelError, "cross rename denied, error getting size for directory %q: %v", sourcePath, err)
			return false
		}
	}
	if !quotaResult.HasSpace && initialSize != -1 {
		// we are over quota but we are overwriting an existing file so we check if the quota size after the rename is ok
		if quotaResult.QuotaSize == 0 {
			return true
		}
		c.Log(logger.LevelDebug, "cross rename overwrite, source %q, used size %d, size to add %d",
			sourcePath, quotaResult.UsedSize, sizeDiff)
		quotaResult.UsedSize += sizeDiff
		return quotaResult.GetRemainingSize() >= 0
	}
	if quotaResult.QuotaFiles > 0 {
		remainingFiles := quotaResult.GetRemainingFiles()
		c.Log(logger.LevelDebug, "cross rename, source %q remaining file %d to add %d", sourcePath,
			remainingFiles, filesDiff)
		if remainingFiles < filesDiff {
			return false
		}
	}
	if quotaResult.QuotaSize > 0 {
		remainingSize := quotaResult.GetRemainingSize()
		c.Log(logger.LevelDebug, "cross rename, source %q remaining size %d to add %d", srcInfo.Name(),
			remainingSize, sizeDiff)
		if remainingSize < sizeDiff {
			return false
		}
	}
	return true
}

// GetMaxWriteSize returns the allowed size for an upload or an error
// if no enough size is available for a resume/append.
// fileSize is the current size of the target file; quotaResult carries the
// remaining quota for the upload destination.
func (c *BaseConnection) GetMaxWriteSize(quotaResult vfs.QuotaCheckResult, isResume bool, fileSize int64,
	isUploadResumeSupported bool,
) (int64, error) {
	maxWriteSize := quotaResult.GetRemainingSize()

	if isResume {
		if !isUploadResumeSupported {
			return 0, c.GetOpUnsupportedError()
		}
		// the existing bytes already consume the per-file limit
		if c.User.Filters.MaxUploadFileSize > 0 && c.User.Filters.MaxUploadFileSize <= fileSize {
			return 0, c.GetQuotaExceededError()
		}
		if c.User.Filters.MaxUploadFileSize > 0 {
			maxUploadSize := c.User.Filters.MaxUploadFileSize - fileSize
			// maxWriteSize == 0 means "unlimited", so any finite cap wins
			if maxUploadSize < maxWriteSize || maxWriteSize == 0 {
				maxWriteSize = maxUploadSize
			}
		}
	} else {
		if maxWriteSize > 0 {
			// the file is truncated/replaced, so its current size is freed
			maxWriteSize += fileSize
		}
		if c.User.Filters.MaxUploadFileSize > 0 && (c.User.Filters.MaxUploadFileSize < maxWriteSize || maxWriteSize == 0) {
			maxWriteSize = c.User.Filters.MaxUploadFileSize
		}
	}

	return maxWriteSize, nil
}

// GetTransferQuota returns the data transfers quota
func (c *BaseConnection) GetTransferQuota() dataprovider.TransferQuota {
	result, _, _ := c.checkUserQuota()
	return result
}

// checkUserQuota returns the user's transfer quota along with the cached used
// files/size counters (-1/-1 when the user has no transfer quota restrictions
// or the usage could not be read).
func (c *BaseConnection) checkUserQuota() (dataprovider.TransferQuota, int, int64) {
	ul, dl, total := c.User.GetDataTransferLimits()
	result := dataprovider.TransferQuota{
		ULSize:           ul,
		DLSize:           dl,
		TotalSize:        total,
		AllowedULSize:    0,
		AllowedDLSize:    0,
		AllowedTotalSize: 0,
	}
	if !c.User.HasTransferQuotaRestrictions() {
		return result, -1, -1
	}
	usedFiles, usedSize, usedULSize, usedDLSize, err := dataprovider.GetUsedQuota(c.User.Username)
	if err != nil {
		c.Log(logger.LevelError, "error getting used quota for %q: %v", c.User.Username, err)
		// AllowedTotalSize = -1 signals the read failure to callers
		result.AllowedTotalSize = -1
		return result, -1, -1
	}
	if result.TotalSize > 0 {
		result.AllowedTotalSize = result.TotalSize - (usedULSize + usedDLSize)
	}
	if result.ULSize > 0 {
		result.AllowedULSize = result.ULSize - usedULSize
	}
	if result.DLSize > 0 {
		result.AllowedDLSize = result.DLSize - usedDLSize
	}

	return result, usedFiles, usedSize
}
// HasSpace checks user's quota usage for requestPath. It resolves the quota
// scope (virtual folder vs user home) and returns the storage quota result
// together with the transfer quota. checkFiles enables the file-count check,
// getUsage forces loading usage even when no restrictions apply.
func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string) (vfs.QuotaCheckResult,
	dataprovider.TransferQuota,
) {
	result := vfs.QuotaCheckResult{
		HasSpace:     true,
		AllowedSize:  0,
		AllowedFiles: 0,
		UsedSize:     0,
		UsedFiles:    0,
		QuotaSize:    0,
		QuotaFiles:   0,
	}
	if dataprovider.GetQuotaTracking() == 0 {
		return result, dataprovider.TransferQuota{}
	}
	transferQuota, usedFiles, usedSize := c.checkUserQuota()

	var err error
	var vfolder vfs.VirtualFolder
	vfolder, err = c.User.GetVirtualFolderForPath(path.Dir(requestPath))
	if err == nil && !vfolder.IsIncludedInUserQuota() {
		// the path lives in a virtual folder with its own quota accounting
		if vfolder.HasNoQuotaRestrictions(checkFiles) && !getUsage {
			return result, transferQuota
		}
		result.QuotaSize = vfolder.QuotaSize
		result.QuotaFiles = vfolder.QuotaFiles
		result.UsedFiles, result.UsedSize, err = dataprovider.GetUsedVirtualFolderQuota(vfolder.Name)
	} else {
		if c.User.HasNoQuotaRestrictions(checkFiles) && !getUsage {
			return result, transferQuota
		}
		result.QuotaSize = c.User.QuotaSize
		result.QuotaFiles = c.User.QuotaFiles
		if usedSize == -1 {
			// usage was not cached by checkUserQuota, load it now
			result.UsedFiles, result.UsedSize, _, _, err = dataprovider.GetUsedQuota(c.User.Username)
		} else {
			err = nil
			result.UsedFiles = usedFiles
			result.UsedSize = usedSize
		}
	}
	if err != nil {
		// fail closed: if usage cannot be read, report no space
		c.Log(logger.LevelError, "error getting used quota for %q request path %q: %v", c.User.Username, requestPath, err)
		result.HasSpace = false
		return result, transferQuota
	}
	result.AllowedFiles = result.QuotaFiles - result.UsedFiles
	result.AllowedSize = result.QuotaSize - result.UsedSize
	if (checkFiles && result.QuotaFiles > 0 && result.UsedFiles >= result.QuotaFiles) ||
		(result.QuotaSize > 0 && result.UsedSize >= result.QuotaSize) {
		c.Log(logger.LevelDebug, "quota exceed for user %q, request path %q, num files: %d/%d, size: %d/%d check files: %t",
			c.User.Username, requestPath, result.UsedFiles, result.QuotaFiles, result.UsedSize, result.QuotaSize, checkFiles)
		result.HasSpace = false
		return result, transferQuota
	}
	return result, transferQuota
}

// IsSameResource returns true if source and target paths are on the same resource
func (c *BaseConnection) IsSameResource(virtualSourcePath, virtualTargetPath string) bool {
	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
	dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
	if errSrc != nil && errDst != nil {
		// both paths are inside the user home dir
		return true
	}
	if errSrc == nil && errDst == nil {
		if sourceFolder.Name == dstFolder.Name {
			return true
		}
		// we have different folders, check if they point to the same resource
		return sourceFolder.FsConfig.IsSameResource(dstFolder.FsConfig)
	}
	if errSrc == nil {
		return sourceFolder.FsConfig.IsSameResource(c.User.FsConfig)
	}
	return dstFolder.FsConfig.IsSameResource(c.User.FsConfig)
}

// isCrossFoldersRequest reports whether the two virtual paths belong to
// differently-accounted areas (different virtual folders, or one in a virtual
// folder and one in the user home dir).
func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetPath string) bool {
	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
	dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
	if errSrc != nil && errDst != nil {
		return false
	}
	if errSrc == nil && errDst == nil {
		return sourceFolder.Name != dstFolder.Name
	}
	return true
}

// updateQuotaMoveBetweenVFolders adjusts quota usage after moving
// numFiles/filesSize between two virtual folders. initialSize is the size of
// an overwritten target file, or -1 when nothing was overwritten.
func (c *BaseConnection) updateQuotaMoveBetweenVFolders(sourceFolder, dstFolder *vfs.VirtualFolder, initialSize,
	filesSize int64, numFiles int) {
	if sourceFolder.Name == dstFolder.Name {
		// both files are inside the same virtual folder
		if initialSize != -1 {
			// only the overwritten file must be subtracted
			dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, -numFiles, -initialSize, false)
		}
		return
	}
	// files are inside different virtual folders
	dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
	if initialSize == -1 {
		dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
		return
	}
	// we cannot have a directory here, initialSize != -1 only for files
	dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
}

// updateQuotaMoveFromVFolder adjusts quota usage after moving data from a
// virtual folder into the user home dir.
func (c *BaseConnection) updateQuotaMoveFromVFolder(sourceFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
	// move between a virtual folder and the user home dir
	dataprovider.UpdateUserFolderQuota(sourceFolder, &c.User, -numFiles, -filesSize, false)
	if initialSize == -1 {
		dataprovider.UpdateUserQuota(&c.User, numFiles, filesSize, false) //nolint:errcheck
		return
	}
	// we cannot have a directory here, initialSize != -1 only for files
	dataprovider.UpdateUserQuota(&c.User, 0, filesSize-initialSize, false) //nolint:errcheck
}

// updateQuotaMoveToVFolder adjusts quota usage after moving data from the
// user home dir into a virtual folder.
func (c *BaseConnection) updateQuotaMoveToVFolder(dstFolder *vfs.VirtualFolder, initialSize, filesSize int64, numFiles int) {
	// move between the user home dir and a virtual folder
	dataprovider.UpdateUserQuota(&c.User, -numFiles, -filesSize, false) //nolint:errcheck
	if initialSize == -1 {
		dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, numFiles, filesSize, false)
		return
	}
	// we cannot have a directory here, initialSize != -1 only for files
	dataprovider.UpdateUserFolderQuota(dstFolder, &c.User, 0, filesSize-initialSize, false)
}

// updateQuotaAfterRename updates quota usage after a successful rename.
// numFiles/filesSize may be -1 when the filesystem did not report the moved
// entries, in which case they are computed by stat-ing targetPath.
func (c *BaseConnection) updateQuotaAfterRename(fs vfs.Fs, virtualSourcePath, virtualTargetPath, targetPath string,
	initialSize int64, numFiles int, filesSize int64,
) error {
	if dataprovider.GetQuotaTracking() == 0 {
		return nil
	}
	// we don't allow to overwrite an existing directory so targetPath can be:
	// - a new file, a symlink is as a new file here
	// - a file overwriting an existing one
	// - a new directory
	// initialSize != -1 only when overwriting files
	sourceFolder, errSrc := c.User.GetVirtualFolderForPath(path.Dir(virtualSourcePath))
	dstFolder, errDst := c.User.GetVirtualFolderForPath(path.Dir(virtualTargetPath))
	if errSrc != nil && errDst != nil {
		// both files are contained inside the user home dir
		if initialSize != -1 {
			// we cannot have a directory here, we are overwriting an existing file
			// we need to subtract the size of the overwritten file from the user quota
			dataprovider.UpdateUserQuota(&c.User, -1, -initialSize, false) //nolint:errcheck
		}
		return nil
	}

	if filesSize == -1 {
		// fs.Rename didn't return the affected files/sizes, we need to calculate them
		numFiles = 1
		if fi, err := fs.Stat(targetPath); err == nil {
			if fi.Mode().IsDir() {
				numFiles, filesSize, err = fs.GetDirSize(targetPath)
				if err != nil {
					c.Log(logger.LevelError, "failed to update quota after rename, error scanning moved folder %q: %+v",
						targetPath, err)
					return err
				}
			} else {
				filesSize = fi.Size()
			}
		} else {
			c.Log(logger.LevelError, "failed to update quota after renaming, file %q stat error: %+v", targetPath, err)
			return err
		}
		c.Log(logger.LevelDebug, "calculated renamed files: %d, size: %d bytes", numFiles, filesSize)
	} else {
		c.Log(logger.LevelDebug, "returned renamed files: %d, size: %d bytes", numFiles, filesSize)
	}
	// dispatch to the appropriate quota-move helper based on which endpoints
	// are inside virtual folders
	if errSrc == nil && errDst == nil {
		c.updateQuotaMoveBetweenVFolders(&sourceFolder, &dstFolder, initialSize, filesSize, numFiles)
	}
	if errSrc == nil && errDst != nil {
		c.updateQuotaMoveFromVFolder(&sourceFolder, initialSize, filesSize, numFiles)
	}
	if errSrc != nil && errDst == nil {
		c.updateQuotaMoveToVFolder(&dstFolder, initialSize, filesSize, numFiles)
	}
	return nil
}

// IsNotExistError returns true if the specified fs error is not exist for the connection protocol
func (c *BaseConnection) IsNotExistError(err error) bool {
	switch c.protocol {
	case ProtocolSFTP:
		return errors.Is(err, sftp.ErrSSHFxNoSuchFile)
	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolOIDC, ProtocolHTTPShare, ProtocolDataRetention:
		return errors.Is(err, os.ErrNotExist)
	default:
		return errors.Is(err, ErrNotExist)
	}
}
// GetErrorForDeniedFile return permission denied or not exist error based on the specified policy
func (c *BaseConnection) GetErrorForDeniedFile(policy int) error {
	switch policy {
	case sdk.DenyPolicyHide:
		// hidden files must look like they don't exist
		return c.GetNotExistError()
	default:
		return c.GetPermissionDeniedError()
	}
}

// GetPermissionDeniedError returns an appropriate permission denied error for the connection protocol
func (c *BaseConnection) GetPermissionDeniedError() error {
	return getPermissionDeniedError(c.protocol)
}

// GetNotExistError returns an appropriate not exist error for the connection protocol
func (c *BaseConnection) GetNotExistError() error {
	switch c.protocol {
	case ProtocolSFTP:
		return sftp.ErrSSHFxNoSuchFile
	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolOIDC, ProtocolHTTPShare, ProtocolDataRetention:
		return os.ErrNotExist
	default:
		return ErrNotExist
	}
}

// GetOpUnsupportedError returns an appropriate operation not supported error for the connection protocol
func (c *BaseConnection) GetOpUnsupportedError() error {
	switch c.protocol {
	case ProtocolSFTP:
		return sftp.ErrSSHFxOpUnsupported
	default:
		return ErrOpUnsupported
	}
}

// getQuotaExceededError maps the quota exceeded condition to a
// protocol-specific error; for SFTP the generic error wraps ErrQuotaExceeded
// so errors.Is still matches it.
func getQuotaExceededError(protocol string) error {
	switch protocol {
	case ProtocolSFTP:
		return fmt.Errorf("%w: %w", sftp.ErrSSHFxFailure, ErrQuotaExceeded)
	case ProtocolFTP:
		return ftpserver.ErrStorageExceeded
	default:
		return ErrQuotaExceeded
	}
}

// getReadQuotaExceededError maps the read/download quota exceeded condition
// to a protocol-specific error.
func getReadQuotaExceededError(protocol string) error {
	switch protocol {
	case ProtocolSFTP:
		return fmt.Errorf("%w: %w", sftp.ErrSSHFxFailure, ErrReadQuotaExceeded)
	default:
		return ErrReadQuotaExceeded
	}
}

// GetQuotaExceededError returns an appropriate storage limit exceeded error for the connection protocol
func (c *BaseConnection) GetQuotaExceededError() error {
	return getQuotaExceededError(c.protocol)
}

// GetReadQuotaExceededError returns an appropriate read quota limit exceeded error for the connection protocol
func (c *BaseConnection) GetReadQuotaExceededError() error {
	return getReadQuotaExceededError(c.protocol)
}

// IsQuotaExceededError returns true if the given error is a quota exceeded error
func (c *BaseConnection) IsQuotaExceededError(err error) bool {
	switch c.protocol {
	case ProtocolSFTP:
		if err == nil {
			return false
		}
		if errors.Is(err, ErrQuotaExceeded) {
			return true
		}
		// also match the wrapped form produced by getQuotaExceededError
		return errors.Is(err, sftp.ErrSSHFxFailure) && strings.Contains(err.Error(), ErrQuotaExceeded.Error())
	case ProtocolFTP:
		return errors.Is(err, ftpserver.ErrStorageExceeded) || errors.Is(err, ErrQuotaExceeded)
	default:
		return errors.Is(err, ErrQuotaExceeded)
	}
}

// isSFTPGoError reports whether err is one of the package-defined sentinel
// errors that must be forwarded to clients as-is.
func isSFTPGoError(err error) bool {
	return errors.Is(err, ErrPermissionDenied) || errors.Is(err, ErrNotExist) || errors.Is(err, ErrOpUnsupported) ||
		errors.Is(err, ErrQuotaExceeded) || errors.Is(err, ErrReadQuotaExceeded) ||
		errors.Is(err, vfs.ErrStorageSizeUnavailable) || errors.Is(err, ErrShuttingDown)
}

// GetGenericError returns an appropriate generic error for the connection protocol
func (c *BaseConnection) GetGenericError(err error) error {
	switch c.protocol {
	case ProtocolSFTP:
		if errors.Is(err, vfs.ErrStorageSizeUnavailable) || errors.Is(err, ErrOpUnsupported) || errors.Is(err, sftp.ErrSSHFxOpUnsupported) {
			return fmt.Errorf("%w: %w", sftp.ErrSSHFxOpUnsupported, err)
		}
		if isSFTPGoError(err) {
			return fmt.Errorf("%w: %w", sftp.ErrSSHFxFailure, err)
		}
		if err != nil {
			var pathError *fs.PathError
			if errors.As(err, &pathError) {
				// strip the local path from the message sent to the client
				c.Log(logger.LevelError, "generic path error: %+v", pathError)
				return fmt.Errorf("%w: %v %v", sftp.ErrSSHFxFailure, pathError.Op, pathError.Err.Error())
			}
			c.Log(logger.LevelError, "generic error: %+v", err)
		}
		return sftp.ErrSSHFxFailure
	default:
		if isSFTPGoError(err) {
			return err
		}
		c.Log(logger.LevelError, "generic error: %+v", err)
		return ErrGenericFailure
	}
}

// GetFsError converts a filesystem error to a protocol error
func (c *BaseConnection) GetFsError(fs vfs.Fs, err error) error {
	if fs.IsNotExist(err) {
		return c.GetNotExistError()
	} else if fs.IsPermission(err) {
		return c.GetPermissionDeniedError()
	} else if fs.IsNotSupported(err) {
		return c.GetOpUnsupportedError()
	} else if err != nil {
		return c.GetGenericError(err)
	}
	return nil
}

// getNotificationStatus maps an operation result to the numeric status used
// in event notifications: 1 = OK, 2 = generic failure, 3 = quota exceeded.
func (c *BaseConnection) getNotificationStatus(err error) int {
	if err == nil {
		return 1
	}
	if c.IsQuotaExceededError(err) {
		return 3
	}
	return 2
}

// GetFsAndResolvedPath returns the fs and the fs path matching virtualPath
func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, string, error) {
	fs, err := c.User.GetFilesystemForPath(virtualPath, c.ID)
	if err != nil {
		if c.protocol == ProtocolWebDAV && strings.Contains(err.Error(), vfs.ErrSFTPLoop.Error()) {
			// if there is an SFTP loop we return a permission error, for WebDAV, so the problematic folder
			// will not be listed
			return nil, "", util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message)
		}
		return nil, "", c.GetGenericError(err)
	}

	if isShuttingDown.Load() {
		return nil, "", c.GetFsError(fs, ErrShuttingDown)
	}

	fsPath, err := fs.ResolvePath(virtualPath)
	if err != nil {
		return nil, "", c.GetFsError(fs, err)
	}

	return fs, fsPath, nil
}
// DirListerAt defines a directory lister implementing the ListAt method.
type DirListerAt struct {
	virtualPath string          // virtual path of the directory being listed
	conn        *BaseConnection // owning connection, used for logging and filters
	fs          vfs.Fs          // filesystem the directory belongs to
	info        []os.FileInfo   // entries queued for the next read (e.g. prepended ".")
	mu          sync.Mutex
	lister      vfs.DirLister
}

// Prepend adds the given os.FileInfo as first element of the internal cache
func (l *DirListerAt) Prepend(fi os.FileInfo) {
	l.mu.Lock()
	defer l.mu.Unlock()

	l.info = slices.Insert(l.info, 0, fi)
}

// ListAt implements sftp.ListerAt
func (l *DirListerAt) ListAt(f []os.FileInfo, _ int64) (int, error) {
	l.mu.Lock()
	defer l.mu.Unlock()

	if len(f) == 0 {
		return 0, errors.New("invalid ListAt destination, zero size")
	}
	// serve the request entirely from the cached entries when possible
	if len(f) <= len(l.info) {
		files := make([]os.FileInfo, 0, len(f))
		for idx := range l.info {
			files = append(files, l.info[idx])
			if len(files) == len(f) {
				l.info = l.info[idx+1:]
				n := copy(f, files)
				return n, nil
			}
		}
	}
	limit := len(f) - len(l.info)
	files, err := l.Next(limit)
	n := copy(f, files)
	return n, err
}

// Next reads the directory and returns a slice of up to n FileInfo values.
// NOTE(review): Next does not take l.mu itself; ListAt calls it while holding
// the lock — confirm external callers are single-goroutine before relying on it.
func (l *DirListerAt) Next(limit int) ([]os.FileInfo, error) {
	for {
		files, err := l.lister.Next(limit)
		if err != nil && !errors.Is(err, io.EOF) {
			l.conn.Log(logger.LevelDebug, "error retrieving directory entries: %+v", err)
			return files, l.conn.GetFsError(l.fs, err)
		}
		// apply the user's listing filters; this may empty the batch, in which
		// case we loop and read more entries instead of returning an empty page
		files = l.conn.User.FilterListDir(files, l.virtualPath)
		if len(l.info) > 0 {
			files = slices.Concat(l.info, files)
			l.info = nil
		}
		if err != nil || len(files) > 0 {
			return files, err
		}
	}
}

// Close closes the DirListerAt
func (l *DirListerAt) Close() error {
	l.mu.Lock()
	defer l.mu.Unlock()

	return l.lister.Close()
}

// convertError normalizes io.EOF to nil, any other error is returned as-is.
func (l *DirListerAt) convertError(err error) error {
	if errors.Is(err, io.EOF) {
		return nil
	}
	return err
}

// getPermissionDeniedError maps the permission denied condition to a
// protocol-specific error value.
func getPermissionDeniedError(protocol string) error {
	switch protocol {
	case ProtocolSFTP:
		return sftp.ErrSSHFxPermissionDenied
	case ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolOIDC, ProtocolHTTPShare, ProtocolDataRetention:
		return os.ErrPermission
	default:
		return ErrPermissionDenied
	}
}

// keepConnectionAlive periodically refreshes the connection's last activity
// timestamp every interval. It returns a stop function that must be called to
// cancel the timer when the operation completes.
func keepConnectionAlive(c *BaseConnection, interval time.Duration) func() {
	var timer *time.Timer
	var closed atomic.Bool

	task := func() {
		c.UpdateLastActivity()

		// reschedule only while not stopped
		if !closed.Load() {
			timer.Reset(interval)
		}
	}

	timer = time.AfterFunc(interval, task)

	return func() {
		closed.Store(true)
		timer.Stop()
	}
}
diff --git a/internal/common/connection_test.go b/internal/common/connection_test.go
new file mode 100644
index 00000000..0f426858
--- /dev/null
+++ b/internal/common/connection_test.go
@@ -0,0 +1,1536 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "errors" + "fmt" + "io" + "os" + "path" + "path/filepath" + "runtime" + "slices" + "strconv" + "testing" + "time" + + "github.com/pkg/sftp" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +var ( + errWalkDir = errors.New("err walk dir") +) + +// MockOsFs mockable OsFs +type MockOsFs struct { + vfs.Fs + hasVirtualFolders bool + name string + err error +} + +// Name returns the name for the Fs implementation +func (fs *MockOsFs) Name() string { + if fs.name != "" { + return fs.name + } + return "mockOsFs" +} + +// HasVirtualFolders returns true if folders are emulated +func (fs *MockOsFs) HasVirtualFolders() bool { + return fs.hasVirtualFolders +} + +func (fs *MockOsFs) IsUploadResumeSupported() bool { + return !fs.hasVirtualFolders +} + +func (fs *MockOsFs) Chtimes(_ string, _, _ time.Time, _ bool) error { + return vfs.ErrVfsUnsupported +} + +func (fs *MockOsFs) Lstat(name string) (os.FileInfo, error) { + if fs.err != nil { + return nil, fs.err + } + return fs.Fs.Lstat(name) +} + +// Walk returns a duplicate path for testing +func (fs *MockOsFs) Walk(_ string, walkFn filepath.WalkFunc) error { + if fs.err == errWalkDir { + walkFn("fsdpath", vfs.NewFileInfo("dpath", true, 0, time.Now(), false), nil) //nolint:errcheck + return walkFn("fsdpath", vfs.NewFileInfo("dpath", true, 0, 
time.Now(), false), nil) //nolint:errcheck + } + walkFn("fsfpath", vfs.NewFileInfo("fpath", false, 0, time.Now(), false), nil) //nolint:errcheck + return fs.err +} + +func newMockOsFs(hasVirtualFolders bool, connectionID, rootDir, name string, err error) vfs.Fs { + return &MockOsFs{ + Fs: vfs.NewOsFs(connectionID, rootDir, "", nil), + name: name, + hasVirtualFolders: hasVirtualFolders, + err: err, + } +} + +func TestRemoveErrors(t *testing.T) { + mappedPath := filepath.Join(os.TempDir(), "map") + homePath := filepath.Join(os.TempDir(), "home") + + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "remove_errors_user", + HomeDir: homePath, + }, + VirtualFolders: []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: filepath.Base(mappedPath), + MappedPath: mappedPath, + }, + VirtualPath: "/virtualpath", + }, + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + conn := NewBaseConnection("", ProtocolFTP, "", "", user) + err := conn.IsRemoveDirAllowed(fs, mappedPath, "/virtualpath1") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission denied") + } + err = conn.RemoveFile(fs, filepath.Join(homePath, "missing_file"), "/missing_file", + vfs.NewFileInfo("info", false, 100, time.Now(), false)) + assert.Error(t, err) +} + +func TestSetStatMode(t *testing.T) { + oldSetStatMode := Config.SetstatMode + Config.SetstatMode = 1 + + fakePath := "fake path" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: os.TempDir(), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + fs := newMockOsFs(true, "", user.GetHomeDir(), "", nil) + conn := NewBaseConnection("", ProtocolWebDAV, "", "", user) + err := conn.handleChmod(fs, fakePath, fakePath, nil) + assert.NoError(t, err) + err = conn.handleChown(fs, fakePath, fakePath, nil) + assert.NoError(t, 
err) + err = conn.handleChtimes(fs, fakePath, fakePath, nil) + assert.NoError(t, err) + + Config.SetstatMode = 2 + err = conn.handleChmod(fs, fakePath, fakePath, nil) + assert.NoError(t, err) + err = conn.handleChtimes(fs, fakePath, fakePath, &StatAttributes{ + Atime: time.Now(), + Mtime: time.Now(), + }) + assert.NoError(t, err) + + Config.SetstatMode = oldSetStatMode +} + +func TestRecursiveRenameWalkError(t *testing.T) { + fs := vfs.NewOsFs("", filepath.Clean(os.TempDir()), "", nil) + conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Permissions: map[string][]string{ + "/": {dataprovider.PermListItems, dataprovider.PermUpload, + dataprovider.PermDownload, dataprovider.PermRenameDirs}, + }, + }, + }) + err := conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"), + filepath.Join(os.TempDir(), "/target"), "/source", "/target", + vfs.NewFileInfo("source", true, 0, time.Now(), false)) + assert.ErrorIs(t, err, os.ErrNotExist) + + fs = newMockOsFs(false, "mockID", filepath.Clean(os.TempDir()), "S3Fs", errWalkDir) + err = conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"), + filepath.Join(os.TempDir(), "/target"), "/source", "/target", + vfs.NewFileInfo("source", true, 0, time.Now(), false)) + if assert.Error(t, err) { + assert.Equal(t, err.Error(), conn.GetOpUnsupportedError().Error()) + } + + conn.User.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, + dataprovider.PermDownload, dataprovider.PermRenameFiles} + // no dir rename permission, the quick check path returns permission error without walking + err = conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"), + filepath.Join(os.TempDir(), "/target"), "/source", "/target", + vfs.NewFileInfo("source", true, 0, time.Now(), false)) + if assert.Error(t, err) { + assert.EqualError(t, err, conn.GetPermissionDeniedError().Error()) + } +} + 
// TestCrossRenameFsErrors verifies that hasSpaceForCrossRename fails when the
// source directory size cannot be read (permissions stripped). Skipped on
// Windows where chmod 0001 does not restrict access the same way.
func TestCrossRenameFsErrors(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{})
	dirPath := filepath.Join(os.TempDir(), "d")
	err := os.Mkdir(dirPath, os.ModePerm)
	assert.NoError(t, err)
	err = os.Chmod(dirPath, 0001)
	assert.NoError(t, err)
	srcInfo := vfs.NewFileInfo(filepath.Base(dirPath), true, 0, time.Now(), false)
	res := conn.hasSpaceForCrossRename(fs, vfs.QuotaCheckResult{}, 1, dirPath, srcInfo)
	assert.False(t, res)

	err = os.Chmod(dirPath, os.ModePerm)
	assert.NoError(t, err)
	err = os.Remove(dirPath)
	assert.NoError(t, err)
}

// TestRenameVirtualFolders verifies that renaming a virtual folder mount
// point is rejected by isRenamePermitted.
func TestRenameVirtualFolders(t *testing.T) {
	vdir := "/avdir"
	u := dataprovider.User{}
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name:       "name",
			MappedPath: "mappedPath",
		},
		VirtualPath: vdir,
	})
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	conn := NewBaseConnection("", ProtocolFTP, "", "", u)
	res := conn.isRenamePermitted(fs, fs, "source", "target", vdir, "vdirtarget", nil)
	assert.False(t, res)
}

// TestRenamePerms checks hasRenamePerms for files and directories with the
// various rename-related permissions, including differing source/target perms.
func TestRenamePerms(t *testing.T) {
	src := "source"
	target := "target"
	sub := "/sub"
	subTarget := sub + "/target"
	u := dataprovider.User{}
	u.Permissions = map[string][]string{}
	u.Permissions["/"] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload, dataprovider.PermCreateSymlinks,
		dataprovider.PermDeleteFiles}
	conn := NewBaseConnection("", ProtocolSFTP, "", "", u)
	assert.False(t, conn.hasRenamePerms(src, target, nil))
	u.Permissions["/"] = []string{dataprovider.PermRename}
	assert.True(t, conn.hasRenamePerms(src, target, nil))
	u.Permissions["/"] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload, dataprovider.PermDeleteFiles,
		dataprovider.PermDeleteDirs}
	assert.False(t, conn.hasRenamePerms(src, target, nil))

	// directories
	info := vfs.NewFileInfo(src, true, 0, time.Now(), false)
	u.Permissions["/"] = []string{dataprovider.PermRenameFiles}
	assert.False(t, conn.hasRenamePerms(src, target, info))
	u.Permissions["/"] = []string{dataprovider.PermRenameDirs}
	assert.True(t, conn.hasRenamePerms(src, target, info))
	u.Permissions["/"] = []string{dataprovider.PermRename}
	assert.True(t, conn.hasRenamePerms(src, target, info))
	u.Permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDeleteDirs}
	assert.False(t, conn.hasRenamePerms(src, target, info))
	// test with different permissions between source and target
	u.Permissions["/"] = []string{dataprovider.PermRename}
	u.Permissions[sub] = []string{dataprovider.PermRenameFiles}
	assert.False(t, conn.hasRenamePerms(src, subTarget, info))
	u.Permissions[sub] = []string{dataprovider.PermRenameDirs}
	assert.True(t, conn.hasRenamePerms(src, subTarget, info))
	// test files
	info = vfs.NewFileInfo(src, false, 0, time.Now(), false)
	u.Permissions["/"] = []string{dataprovider.PermRenameDirs}
	assert.False(t, conn.hasRenamePerms(src, target, info))
	u.Permissions["/"] = []string{dataprovider.PermRenameFiles}
	assert.True(t, conn.hasRenamePerms(src, target, info))
	u.Permissions["/"] = []string{dataprovider.PermRename}
	assert.True(t, conn.hasRenamePerms(src, target, info))
	// test with different permissions between source and target
	u.Permissions["/"] = []string{dataprovider.PermRename}
	u.Permissions[sub] = []string{dataprovider.PermRenameDirs}
	assert.False(t, conn.hasRenamePerms(src, subTarget, info))
	u.Permissions[sub] = []string{dataprovider.PermRenameFiles}
	assert.True(t, conn.hasRenamePerms(src, subTarget, info))
}

// TestRenameNestedFolders verifies that checkFolderRename rejects nested
// source/target paths and directories containing virtual folders.
func TestRenameNestedFolders(t *testing.T) {
	u := dataprovider.User{}
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name:       "vfolder",
			MappedPath: filepath.Join(os.TempDir(), "f"),
		},
		VirtualPath: "/vdirs/f",
	})
	conn := NewBaseConnection("", ProtocolSFTP, "", "", u)
	err := conn.checkFolderRename(nil, nil, filepath.Clean(os.TempDir()), filepath.Join(os.TempDir(), "subdir"), "/src", "/dst", nil)
	assert.Error(t, err)
	err = conn.checkFolderRename(nil, nil, filepath.Join(os.TempDir(), "subdir"), filepath.Clean(os.TempDir()), "/src", "/dst", nil)
	assert.Error(t, err)
	err = conn.checkFolderRename(nil, nil, "", "", "/src/sub", "/src", nil)
	assert.Error(t, err)
	err = conn.checkFolderRename(nil, nil, filepath.Join(os.TempDir(), "src"), filepath.Join(os.TempDir(), "vdirs"), "/src", "/vdirs", nil)
	assert.Error(t, err)
}

// TestUpdateQuotaAfterRename exercises updateQuotaAfterRename across home dir
// and virtual folder combinations, including stat failures on the target.
func TestUpdateQuotaAfterRename(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: userTestUsername,
			HomeDir:  filepath.Join(os.TempDir(), "home"),
		},
	}
	mappedPath := filepath.Join(os.TempDir(), "vdir")
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: mappedPath,
		},
		VirtualPath: "/vdir",
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: mappedPath,
		},
		VirtualPath: "/vdir1",
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	err := os.MkdirAll(user.GetHomeDir(), os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath, os.ModePerm)
	assert.NoError(t, err)
	fs, err := user.GetFilesystem("id")
	assert.NoError(t, err)
	c := NewBaseConnection("", ProtocolSFTP, "", "", user)
	request := sftp.NewRequest("Rename", "/testfile")
	if runtime.GOOS != osWindows {
		// unreadable target dir: GetDirSize must fail and the error propagate
		request.Filepath = "/dir"
		request.Target = path.Join("/vdir", "dir")
		testDirPath := filepath.Join(mappedPath, "dir")
		err := os.MkdirAll(testDirPath, os.ModePerm)
		assert.NoError(t, err)
		err = os.Chmod(testDirPath, 0001)
		assert.NoError(t, err)
		err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, testDirPath, 0, -1, -1)
		assert.Error(t, err)
		err = os.Chmod(testDirPath, os.ModePerm)
		assert.NoError(t, err)
	}
	testFile1 := "/testfile1"
	request.Target = testFile1
	request.Filepath = path.Join("/vdir", "file")
	err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 0, -1, -1)
	assert.Error(t, err)
	err = os.WriteFile(filepath.Join(mappedPath, "file"), []byte("test content"), os.ModePerm)
	assert.NoError(t, err)
	request.Filepath = testFile1
	request.Target = path.Join("/vdir", "file")
	err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12, -1, -1)
	assert.NoError(t, err)
	err = os.WriteFile(filepath.Join(user.GetHomeDir(), "testfile1"), []byte("test content"), os.ModePerm)
	assert.NoError(t, err)
	request.Target = testFile1
	request.Filepath = path.Join("/vdir", "file")
	err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12, -1, -1)
	assert.NoError(t, err)
	request.Target = path.Join("/vdir1", "file")
	request.Filepath = path.Join("/vdir", "file")
	err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12, -1, -1)
	assert.NoError(t, err)
	err = c.updateQuotaAfterRename(fs, request.Filepath, request.Target, filepath.Join(mappedPath, "file"), 12, 1, 100)
	assert.NoError(t, err)

	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestErrorsMapping verifies protocol-specific error mapping.
func TestErrorsMapping(t *testing.T) {
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}})
	osErrorsProtocols := []string{ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolHTTPShare,
ProtocolDataRetention, ProtocolOIDC, protocolEventAction} + for _, protocol := range supportedProtocols { + conn.SetProtocol(protocol) + err := conn.GetFsError(fs, os.ErrNotExist) + if protocol == ProtocolSFTP { + assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile) + } else if slices.Contains(osErrorsProtocols, protocol) { + assert.EqualError(t, err, os.ErrNotExist.Error()) + } else { + assert.EqualError(t, err, ErrNotExist.Error()) + } + err = conn.GetFsError(fs, os.ErrPermission) + if protocol == ProtocolSFTP { + assert.EqualError(t, err, sftp.ErrSSHFxPermissionDenied.Error()) + } else { + assert.EqualError(t, err, ErrPermissionDenied.Error()) + } + err = conn.GetFsError(fs, os.ErrClosed) + if protocol == ProtocolSFTP { + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + } else { + assert.EqualError(t, err, ErrGenericFailure.Error()) + } + err = conn.GetFsError(fs, ErrPermissionDenied) + if protocol == ProtocolSFTP { + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + } else { + assert.EqualError(t, err, ErrPermissionDenied.Error()) + } + err = conn.GetFsError(fs, vfs.ErrVfsUnsupported) + if protocol == ProtocolSFTP { + assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) + } else { + assert.EqualError(t, err, ErrOpUnsupported.Error()) + } + err = conn.GetFsError(fs, vfs.ErrStorageSizeUnavailable) + if protocol == ProtocolSFTP { + assert.ErrorIs(t, err, sftp.ErrSSHFxOpUnsupported) + assert.Contains(t, err.Error(), vfs.ErrStorageSizeUnavailable.Error()) + } else { + assert.EqualError(t, err, vfs.ErrStorageSizeUnavailable.Error()) + } + err = conn.GetQuotaExceededError() + assert.True(t, conn.IsQuotaExceededError(err)) + err = conn.GetReadQuotaExceededError() + if protocol == ProtocolSFTP { + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) + } else { + assert.ErrorIs(t, err, ErrReadQuotaExceeded) + } + err = conn.GetNotExistError() + assert.True(t, conn.IsNotExistError(err)) + err = conn.GetFsError(fs, 
nil) + assert.NoError(t, err) + err = conn.GetOpUnsupportedError() + if protocol == ProtocolSFTP { + assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) + } else { + assert.EqualError(t, err, ErrOpUnsupported.Error()) + } + err = conn.GetFsError(fs, ErrShuttingDown) + if protocol == ProtocolSFTP { + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + assert.Contains(t, err.Error(), ErrShuttingDown.Error()) + } else { + assert.EqualError(t, err, ErrShuttingDown.Error()) + } + } +} + +func TestMaxWriteSize(t *testing.T) { + permissions := make(map[string][]string) + permissions["/"] = []string{dataprovider.PermAny} + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername, + Permissions: permissions, + HomeDir: filepath.Clean(os.TempDir()), + }, + } + fs, err := user.GetFilesystem("123") + assert.NoError(t, err) + conn := NewBaseConnection("", ProtocolFTP, "", "", user) + quotaResult := vfs.QuotaCheckResult{ + HasSpace: true, + } + size, err := conn.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) + assert.NoError(t, err) + assert.Equal(t, int64(0), size) + + conn.User.Filters.MaxUploadFileSize = 100 + size, err = conn.GetMaxWriteSize(quotaResult, false, 0, fs.IsUploadResumeSupported()) + assert.NoError(t, err) + assert.Equal(t, int64(100), size) + + quotaResult.QuotaSize = 1000 + size, err = conn.GetMaxWriteSize(quotaResult, false, 50, fs.IsUploadResumeSupported()) + assert.NoError(t, err) + assert.Equal(t, int64(100), size) + + quotaResult.QuotaSize = 1000 + quotaResult.UsedSize = 990 + size, err = conn.GetMaxWriteSize(quotaResult, false, 50, fs.IsUploadResumeSupported()) + assert.NoError(t, err) + assert.Equal(t, int64(60), size) + + quotaResult.QuotaSize = 0 + quotaResult.UsedSize = 0 + size, err = conn.GetMaxWriteSize(quotaResult, true, 100, fs.IsUploadResumeSupported()) + assert.True(t, conn.IsQuotaExceededError(err)) + assert.Equal(t, int64(0), size) + + size, err = conn.GetMaxWriteSize(quotaResult, true, 10, 
fs.IsUploadResumeSupported()) + assert.NoError(t, err) + assert.Equal(t, int64(90), size) + + fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), "", nil) + size, err = conn.GetMaxWriteSize(quotaResult, true, 100, fs.IsUploadResumeSupported()) + assert.EqualError(t, err, ErrOpUnsupported.Error()) + assert.Equal(t, int64(0), size) +} + +func TestCheckParentDirsErrors(t *testing.T) { + permissions := make(map[string][]string) + permissions["/"] = []string{dataprovider.PermAny} + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername, + Permissions: permissions, + HomeDir: filepath.Clean(os.TempDir()), + }, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + }, + } + c := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) + err := c.CheckParentDirs("/a/dir") + assert.Error(t, err) + + user.FsConfig.Provider = sdk.LocalFilesystemProvider + user.VirtualFolders = nil + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + }, + }, + VirtualPath: "/vdir", + }) + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Clean(os.TempDir()), + }, + VirtualPath: "/vdir/sub", + }) + c = NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) + err = c.CheckParentDirs("/vdir/sub/dir") + assert.Error(t, err) + + user = dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: userTestUsername, + Permissions: permissions, + HomeDir: filepath.Clean(os.TempDir()), + }, + FsConfig: vfs.Filesystem{ + Provider: sdk.S3FilesystemProvider, + S3Config: vfs.S3FsConfig{ + BaseS3FsConfig: sdk.BaseS3FsConfig{ + Bucket: "buck", + Region: "us-east-1", + AccessKey: "key", + }, + AccessSecret: kms.NewPlainSecret("s3secret"), + }, + }, + } + c = NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) + 
err = c.CheckParentDirs("/a/dir") + assert.NoError(t, err) + + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Clean(os.TempDir()), + }, + VirtualPath: "/local/dir", + }) + + c = NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) + err = c.CheckParentDirs("/local/dir/sub-dir") + assert.NoError(t, err) + err = os.RemoveAll(filepath.Join(os.TempDir(), "sub-dir")) + assert.NoError(t, err) +} + +func TestErrorResolvePath(t *testing.T) { + u := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: filepath.Join(os.TempDir(), "u"), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + u.FsConfig.Provider = sdk.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "test" + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") + u.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: "f", + MappedPath: filepath.Join(os.TempDir(), "f"), + }, + VirtualPath: "/f", + }, + } + + conn := NewBaseConnection("", ProtocolSFTP, "", "", u) + err := conn.doRecursiveRemoveDirEntry("/vpath", nil, 0) + assert.Error(t, err) + err = conn.doRecursiveRemove(nil, "/fspath", "/vpath", vfs.NewFileInfo("vpath", true, 0, time.Now(), false), 2000) + assert.Error(t, err, util.ErrRecursionTooDeep) + err = conn.doRecursiveCopy("/src", "/dst", vfs.NewFileInfo("src", true, 0, time.Now(), false), false, 2000) + assert.Error(t, err, util.ErrRecursionTooDeep) + err = conn.checkCopy(vfs.NewFileInfo("name", true, 0, time.Unix(0, 0), false), nil, "/source", "/target") + assert.Error(t, err) + sourceFile := filepath.Join(os.TempDir(), "f", "source") + err = os.MkdirAll(filepath.Dir(sourceFile), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(sourceFile, []byte(""), 0666) + assert.NoError(t, err) + err = conn.checkCopy(vfs.NewFileInfo("name", true, 0, time.Unix(0, 0), false), nil, 
"/f/source", "/target") + assert.Error(t, err) + err = conn.checkCopy(vfs.NewFileInfo("source", false, 0, time.Unix(0, 0), false), vfs.NewFileInfo("target", true, 0, time.Unix(0, 0), false), "/f/source", "/f/target") + assert.Error(t, err) + err = os.RemoveAll(filepath.Dir(sourceFile)) + assert.NoError(t, err) +} + +func TestConnectionKeepAlive(t *testing.T) { + conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{}) + lastActivity := conn.GetLastActivity() + + stop := keepConnectionAlive(conn, 50*time.Millisecond) + defer stop() + + time.Sleep(200 * time.Millisecond) + assert.Greater(t, conn.GetLastActivity(), lastActivity) +} + +func TestFsFileCopier(t *testing.T) { + fs := vfs.Fs(&vfs.AzureBlobFs{}) + _, ok := fs.(vfs.FsFileCopier) + assert.True(t, ok) + fs = vfs.Fs(&vfs.OsFs{}) + _, ok = fs.(vfs.FsFileCopier) + assert.False(t, ok) + fs = vfs.Fs(&vfs.SFTPFs{}) + _, ok = fs.(vfs.FsFileCopier) + assert.False(t, ok) + fs = vfs.Fs(&vfs.GCSFs{}) + _, ok = fs.(vfs.FsFileCopier) + assert.True(t, ok) + fs = vfs.Fs(&vfs.S3Fs{}) + _, ok = fs.(vfs.FsFileCopier) + assert.True(t, ok) +} + +func TestFilePatterns(t *testing.T) { + filters := dataprovider.UserFilters{ + BaseUserFilters: sdk.BaseUserFilters{ + FilePatterns: []sdk.PatternsFilter{ + { + Path: "/dir1", + DenyPolicy: sdk.DenyPolicyDefault, + AllowedPatterns: []string{"*.jpg"}, + }, + { + Path: "/dir2", + DenyPolicy: sdk.DenyPolicyHide, + AllowedPatterns: []string{"*.jpg"}, + }, + { + Path: "/dir3", + DenyPolicy: sdk.DenyPolicyDefault, + DeniedPatterns: []string{"*.jpg"}, + }, + { + Path: "/dir4", + DenyPolicy: sdk.DenyPolicyHide, + DeniedPatterns: []string{"*"}, + }, + }, + }, + } + virtualFolders := []vfs.VirtualFolder{ + { + VirtualPath: "/dir1/vdir1", + }, + { + VirtualPath: "/dir1/vdir2", + }, + { + VirtualPath: "/dir1/vdir3", + }, + { + VirtualPath: "/dir2/vdir1", + }, + { + VirtualPath: "/dir2/vdir2", + }, + { + VirtualPath: "/dir2/vdir3.jpg", + }, + } + user := dataprovider.User{ + Filters: 
filters, + VirtualFolders: virtualFolders, + } + + getFilteredInfo := func(dirContents []os.FileInfo, virtualPath string) []os.FileInfo { + result := user.FilterListDir(dirContents, virtualPath) + result = append(result, user.GetVirtualFoldersInfo(virtualPath)...) + return result + } + + dirContents := []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + } + // dirContents are modified in place, we need to redefine them each time + filtered := getFilteredInfo(dirContents, "/dir1") + assert.Len(t, filtered, 5) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir1/vdir1") + assert.Len(t, filtered, 2) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir2/vdir2") + require.Len(t, filtered, 1) + assert.Equal(t, "file1.jpg", filtered[0].Name()) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir2/vdir2/sub") + require.Len(t, filtered, 1) + assert.Equal(t, "file1.jpg", filtered[0].Name()) + + res, _ := user.IsFileAllowed("/dir1/vdir1/file.txt") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir1/vdir1/sub/file.txt") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir1/vdir1/file.jpg") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir1/vdir1/sub/file.jpg") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/file.jpg") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir3/dir1/file.jpg") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir3/dir1/sub/file.jpg") + assert.False(t, 
res) + res, _ = user.IsFileAllowed("/dir4/file.jpg") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir4/dir1/sub/file.jpg") + assert.False(t, res) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir4") + require.Len(t, filtered, 0) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir4/vdir2/sub") + require.Len(t, filtered, 0) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + } + + filtered = getFilteredInfo(dirContents, "/dir2") + assert.Len(t, filtered, 2) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + } + + filtered = getFilteredInfo(dirContents, "/dir4") + assert.Len(t, filtered, 0) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + } + + filtered = getFilteredInfo(dirContents, "/dir4/sub") + assert.Len(t, filtered, 0) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), + } + + filtered = getFilteredInfo(dirContents, "/dir1") + assert.Len(t, filtered, 5) + + filtered = getFilteredInfo(dirContents, "/dir2") + if assert.Len(t, filtered, 1) { + assert.True(t, filtered[0].IsDir()) + } + + user.VirtualFolders = nil + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, 
"/dir1") + assert.Len(t, filtered, 2) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir2") + if assert.Len(t, filtered, 1) { + assert.False(t, filtered[0].IsDir()) + } + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("vdir3.jpg", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir2") + if assert.Len(t, filtered, 2) { + assert.False(t, filtered[0].IsDir()) + assert.False(t, filtered[1].IsDir()) + } + + user.VirtualFolders = virtualFolders + user.Filters = filters + filtered = getFilteredInfo(nil, "/dir1") + assert.Len(t, filtered, 3) + filtered = getFilteredInfo(nil, "/dir2") + assert.Len(t, filtered, 1) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jPg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir2") + assert.Len(t, filtered, 2) + + user = dataprovider.User{ + Filters: dataprovider.UserFilters{ + BaseUserFilters: sdk.BaseUserFilters{ + FilePatterns: []sdk.PatternsFilter{ + { + Path: "/dir3", + AllowedPatterns: []string{"ic35"}, + DeniedPatterns: []string{"*"}, + DenyPolicy: sdk.DenyPolicyHide, + }, + }, + }, + }, + } + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("vdir3.jpg", false, 456, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3") + 
assert.Len(t, filtered, 0) + + dirContents = nil + for i := 0; i < 100; i++ { + dirContents = append(dirContents, vfs.NewFileInfo(fmt.Sprintf("ic%02d", i), i%2 == 0, int64(i), time.Now(), false)) + } + dirContents = append(dirContents, vfs.NewFileInfo("ic350", false, 123, time.Now(), false)) + dirContents = append(dirContents, vfs.NewFileInfo(".ic35", false, 123, time.Now(), false)) + dirContents = append(dirContents, vfs.NewFileInfo("ic35.", false, 123, time.Now(), false)) + dirContents = append(dirContents, vfs.NewFileInfo("*ic35", false, 123, time.Now(), false)) + dirContents = append(dirContents, vfs.NewFileInfo("ic35*", false, 123, time.Now(), false)) + dirContents = append(dirContents, vfs.NewFileInfo("ic35.*", false, 123, time.Now(), false)) + dirContents = append(dirContents, vfs.NewFileInfo("file.jpg", false, 123, time.Now(), false)) + + filtered = getFilteredInfo(dirContents, "/dir3") + require.Len(t, filtered, 1) + assert.Equal(t, "ic35", filtered[0].Name()) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3/ic36") + require.Len(t, filtered, 0) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3/ic35") + require.Len(t, filtered, 3) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub") + require.Len(t, filtered, 3) + + res, _ = user.IsFileAllowed("/dir3/file.txt") + assert.False(t, res) + 
res, _ = user.IsFileAllowed("/dir3/ic35a") + assert.False(t, res) + res, policy := user.IsFileAllowed("/dir3/ic35a/file") + assert.False(t, res) + assert.Equal(t, sdk.DenyPolicyHide, policy) + res, _ = user.IsFileAllowed("/dir3/ic35") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/file.txt") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub/file.txt") + assert.True(t, res) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub") + require.Len(t, filtered, 3) + + user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ + Path: "/dir3/ic35/sub1", + AllowedPatterns: []string{"*.jpg"}, + DenyPolicy: sdk.DenyPolicyDefault, + }) + user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ + Path: "/dir3/ic35/sub2", + DeniedPatterns: []string{"*.jpg"}, + DenyPolicy: sdk.DenyPolicyHide, + }) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub1") + require.Len(t, filtered, 3) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2") + require.Len(t, filtered, 2) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + 
vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3/ic35/sub2/sub1") + require.Len(t, filtered, 2) + + res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/file.txt") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub/dir/file.txt") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub/dir/file.jpg") + assert.True(t, res) + + res, _ = user.IsFileAllowed("/dir3/ic35/sub1/file.jpg") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub1/file.txt") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub1/sub/file.jpg") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub1/sub2/file.txt") + assert.False(t, res) + + res, _ = user.IsFileAllowed("/dir3/ic35/sub2/file.jpg") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub2/file.txt") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub2/sub/file.jpg") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub2/sub1/file.txt") + assert.True(t, res) + + user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ + Path: "/dir3/ic35", + DeniedPatterns: []string{"*.txt"}, + DenyPolicy: sdk.DenyPolicyHide, + }) + res, _ = user.IsFileAllowed("/dir3/ic35/file.jpg") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/file.txt") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/adir/sub/file.jpg") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/adir/file.txt") + assert.False(t, res) + + res, _ = user.IsFileAllowed("/dir3/ic35/sub2/file.jpg") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub2/file.txt") + assert.True(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub2/sub/file.jpg") + assert.False(t, res) + res, _ = user.IsFileAllowed("/dir3/ic35/sub2/sub1/file.txt") + assert.True(t, res) + + dirContents = 
[]os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3/ic35") + require.Len(t, filtered, 1) + + dirContents = []os.FileInfo{ + vfs.NewFileInfo("file1.jpg", false, 123, time.Now(), false), + vfs.NewFileInfo("file1.txt", false, 123, time.Now(), false), + vfs.NewFileInfo("file2.txt", false, 123, time.Now(), false), + } + filtered = getFilteredInfo(dirContents, "/dir3/ic35/abc") + require.Len(t, filtered, 1) +} + +func TestStatForOngoingTransfers(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: xid.New().String(), + Password: xid.New().String(), + HomeDir: filepath.Clean(os.TempDir()), + Status: 1, + Permissions: map[string][]string{ + "/": {"*"}, + }, + }, + } + fileName := "file.txt" + conn := NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + tr := NewBaseTransfer(nil, conn, nil, filepath.Join(os.TempDir(), fileName), filepath.Join(os.TempDir(), fileName), + fileName, TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) + _, err := conn.DoStat("/file.txt", 0, false) + assert.NoError(t, err) + err = tr.Close() + assert.NoError(t, err) + tr = NewBaseTransfer(nil, conn, nil, filepath.Join(os.TempDir(), fileName), filepath.Join(os.TempDir(), fileName), + fileName, TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) + _, err = conn.DoStat("/file.txt", 0, false) + assert.Error(t, err) + err = tr.Close() + assert.NoError(t, err) + err = conn.CloseFS() + assert.NoError(t, err) +} + +func TestListerAt(t *testing.T) { + dir := t.TempDir() + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "u", + Password: "p", + HomeDir: dir, + Status: 1, + Permissions: map[string][]string{ + "/": {"*"}, + }, + }, + } + conn := 
NewBaseConnection(xid.New().String(), ProtocolSFTP, "", "", user) + lister, err := conn.ListDir("/") + require.NoError(t, err) + files, err := lister.Next(1) + require.ErrorIs(t, err, io.EOF) + require.Len(t, files, 0) + err = lister.Close() + require.NoError(t, err) + + conn.User.VirtualFolders = []vfs.VirtualFolder{ + { + VirtualPath: "p1", + }, + { + VirtualPath: "p2", + }, + { + VirtualPath: "p3", + }, + } + lister, err = conn.ListDir("/") + require.NoError(t, err) + files, err = lister.Next(2) + // virtual directories exceeds the limit + require.ErrorIs(t, err, io.EOF) + require.Len(t, files, 3) + files, err = lister.Next(2) + require.ErrorIs(t, err, io.EOF) + require.Len(t, files, 0) + _, err = lister.Next(-1) + require.ErrorContains(t, err, conn.GetGenericError(err).Error()) + err = lister.Close() + require.NoError(t, err) + + lister, err = conn.ListDir("/") + require.NoError(t, err) + _, err = lister.ListAt(nil, 0) + require.ErrorContains(t, err, "zero size") + err = lister.Close() + require.NoError(t, err) + + for i := 0; i < 100; i++ { + f, err := os.Create(filepath.Join(dir, strconv.Itoa(i))) + require.NoError(t, err) + err = f.Close() + require.NoError(t, err) + } + lister, err = conn.ListDir("/") + require.NoError(t, err) + files = make([]os.FileInfo, 18) + n, err := lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 18, n) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 18, n) + files = make([]os.FileInfo, 100) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 64+3, n) + n, err = lister.ListAt(files, 0) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + n, err = lister.ListAt(files, 0) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + err = lister.Close() + require.NoError(t, err) + n, err = lister.ListAt(files, 0) + assert.Error(t, err) + assert.NotErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + lister, err = conn.ListDir("/") + require.NoError(t, 
err) + lister.Prepend(vfs.NewFileInfo("..", true, 0, time.Unix(0, 0), false)) + lister.Prepend(vfs.NewFileInfo(".", true, 0, time.Unix(0, 0), false)) + files = make([]os.FileInfo, 1) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 1, n) + assert.Equal(t, ".", files[0].Name()) + files = make([]os.FileInfo, 2) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 2, n) + assert.Equal(t, "..", files[0].Name()) + vfolders := []string{files[1].Name()} + files = make([]os.FileInfo, 200) + n, err = lister.ListAt(files, 0) + require.NoError(t, err) + require.Equal(t, 102, n) + vfolders = append(vfolders, files[0].Name()) + vfolders = append(vfolders, files[1].Name()) + assert.Contains(t, vfolders, "p1") + assert.Contains(t, vfolders, "p2") + assert.Contains(t, vfolders, "p3") + err = lister.Close() + require.NoError(t, err) +} + +func TestGetFsAndResolvedPath(t *testing.T) { + homeDir := filepath.Join(os.TempDir(), "home_test") + localVdir := filepath.Join(os.TempDir(), "local_mount_test") + + err := os.MkdirAll(homeDir, 0777) + require.NoError(t, err) + err = os.MkdirAll(localVdir, 0777) + require.NoError(t, err) + + t.Cleanup(func() { + os.RemoveAll(homeDir) + os.RemoveAll(localVdir) + }) + + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: xid.New().String(), + Status: 1, + HomeDir: homeDir, + }, + VirtualFolders: []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: "s3", + MappedPath: "", + FsConfig: vfs.Filesystem{ + Provider: sdk.S3FilesystemProvider, + S3Config: vfs.S3FsConfig{ + BaseS3FsConfig: sdk.BaseS3FsConfig{ + Bucket: "my-test-bucket", + Region: "us-east-1", + }, + }, + }, + }, + VirtualPath: "/s3", + }, + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: "local", + MappedPath: localVdir, + FsConfig: vfs.Filesystem{ + Provider: sdk.LocalFilesystemProvider, + }, + }, + VirtualPath: "/local", + }, + }, + } + + conn := NewBaseConnection(xid.New().String(), 
ProtocolSFTP, "", "", user) + + tests := []struct { + name string + inputVirtualPath string + expectedFsType string + expectedPhyPath string // The resolved path on the target FS + expectedRelativePath string + }{ + { + name: "Root File", + inputVirtualPath: "/file.txt", + expectedFsType: "osfs", + expectedPhyPath: filepath.Join(homeDir, "file.txt"), + expectedRelativePath: "/file.txt", + }, + { + name: "Standard S3 File", + inputVirtualPath: "/s3/image.png", + expectedFsType: "S3Fs", + expectedPhyPath: "image.png", + expectedRelativePath: "/s3/image.png", + }, + { + name: "Standard Local Mount File", + inputVirtualPath: "/local/config.json", + expectedFsType: "osfs", + expectedPhyPath: filepath.Join(localVdir, "config.json"), + expectedRelativePath: "/local/config.json", + }, + + { + name: "Backslash Separator -> Should hit S3", + inputVirtualPath: "\\s3\\doc.txt", + expectedFsType: "S3Fs", + expectedPhyPath: "doc.txt", + expectedRelativePath: "/s3/doc.txt", + }, + { + name: "Mixed Separators -> Should hit Local Mount", + inputVirtualPath: "/local\\subdir/test.txt", + expectedFsType: "osfs", + expectedPhyPath: filepath.Join(localVdir, "subdir", "test.txt"), + expectedRelativePath: "/local/subdir/test.txt", + }, + { + name: "Double Slash -> Should normalize and hit S3", + inputVirtualPath: "//s3//dir @1/data.csv", + expectedFsType: "S3Fs", + expectedPhyPath: "dir @1/data.csv", + expectedRelativePath: "/s3/dir @1/data.csv", + }, + + { + name: "Local Mount Traversal (Attempt to escape)", + inputVirtualPath: "/local/../../etc/passwd", + expectedFsType: "osfs", + expectedPhyPath: filepath.Join(homeDir, "/etc/passwd"), + expectedRelativePath: "/etc/passwd", + }, + { + name: "Traversal Out of S3 (Valid)", + inputVirtualPath: "/s3/../../secret.txt", + expectedFsType: "osfs", + expectedPhyPath: filepath.Join(homeDir, "secret.txt"), + expectedRelativePath: "/secret.txt", + }, + { + name: "Traversal Inside S3", + inputVirtualPath: "/s3/subdir/../image.png", + expectedFsType: 
"S3Fs", + expectedPhyPath: "image.png", + expectedRelativePath: "/s3/image.png", + }, + { + name: "Mount Point Bypass -> Target Local Mount", + inputVirtualPath: "/s3\\..\\local\\secret.txt", + expectedFsType: "osfs", + expectedPhyPath: filepath.Join(localVdir, "secret.txt"), + expectedRelativePath: "/local/secret.txt", + }, + { + name: "Dirty Relative Path (Your Case)", + inputVirtualPath: "test\\..\\..\\oops/file.txt", + expectedFsType: "osfs", + expectedPhyPath: filepath.Join(homeDir, "oops", "file.txt"), + expectedRelativePath: "/oops/file.txt", + }, + { + name: "Relative Path targeting S3 (No leading slash)", + inputVirtualPath: "s3//sub/../image.png", + expectedFsType: "S3Fs", + expectedPhyPath: "image.png", + expectedRelativePath: "/s3/image.png", + }, + { + name: "Windows Path starting with Backslash", + inputVirtualPath: "\\s3\\doc/dir\\doc.txt", + expectedFsType: "S3Fs", + expectedPhyPath: "doc/dir/doc.txt", + expectedRelativePath: "/s3/doc/dir/doc.txt", + }, + { + name: "Filesystem Juggling (Relative)", + inputVirtualPath: "local/../s3/file.txt", + expectedFsType: "S3Fs", + expectedPhyPath: "file.txt", + expectedRelativePath: "/s3/file.txt", + }, + { + name: "Triple Dot Filename (Valid Name)", + inputVirtualPath: "/...hidden/secret", + expectedFsType: "osfs", + expectedPhyPath: filepath.Join(homeDir, "...hidden", "secret"), + expectedRelativePath: "/...hidden/secret", + }, + { + name: "Dot Slash Prefix", + inputVirtualPath: "./local/file.txt", + expectedFsType: "osfs", + expectedPhyPath: filepath.Join(localVdir, "file.txt"), + expectedRelativePath: "/local/file.txt", + }, + { + name: "Root of Local Mount Exactly", + inputVirtualPath: "/local/", + expectedFsType: "osfs", + expectedPhyPath: localVdir, + expectedRelativePath: "/local", + }, + { + name: "Root of S3 Mount Exactly", + inputVirtualPath: "/s3/", + expectedFsType: "S3Fs", + expectedPhyPath: "", + expectedRelativePath: "/s3", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t 
*testing.T) {
			// The input path is sanitized by the protocol handler
			// implementations before reaching GetFsAndResolvedPath.
			cleanInput := util.CleanPath(tc.inputVirtualPath)
			fs, resolvedPath, err := conn.GetFsAndResolvedPath(cleanInput)
			if assert.NoError(t, err, "did not expect error for path: %q, got: %v", tc.inputVirtualPath, err) {
				// 1) the request must be routed to the expected filesystem
				assert.Contains(t, fs.Name(), tc.expectedFsType,
					"routing error: input %q but expected fs %q, got %q", tc.inputVirtualPath, tc.expectedFsType, fs.Name())
				// 2) the physical/resolved path must match
				assert.Equal(t, tc.expectedPhyPath, resolvedPath,
					"resolution error: input %q resolved to %q expected %q", tc.inputVirtualPath, resolvedPath, tc.expectedPhyPath)
				// 3) mapping the resolved path back must give the normalized virtual path
				relativePath := fs.GetRelativePath(resolvedPath)
				assert.Equal(t, tc.expectedRelativePath, relativePath,
					"relative path error, input %q, got %q, expected %q", tc.inputVirtualPath, tc.expectedRelativePath, relativePath)
			}
		})
	}
}

// TestOsFsGetRelativePath verifies the reverse mapping from physical (on-disk)
// paths back to virtual paths for the local filesystem, both for the user's
// root filesystem and for a local virtual folder mount. Paths that escape the
// jail must be clamped to the filesystem's virtual root.
func TestOsFsGetRelativePath(t *testing.T) {
	homeDir := filepath.Join(os.TempDir(), "home_test")
	localVdir := filepath.Join(os.TempDir(), "local_mount_test")

	err := os.MkdirAll(homeDir, 0777)
	require.NoError(t, err)
	err = os.MkdirAll(localVdir, 0777)
	require.NoError(t, err)

	t.Cleanup(func() {
		os.RemoveAll(homeDir)
		os.RemoveAll(localVdir)
	})

	// user with a home dir plus one local virtual folder mounted at /local
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: xid.New().String(),
			Status:   1,
			HomeDir:  homeDir,
		},
		VirtualFolders: []vfs.VirtualFolder{
			{
				BaseVirtualFolder: vfs.BaseVirtualFolder{
					Name:       "local",
					MappedPath: localVdir,
					FsConfig: vfs.Filesystem{
						Provider: sdk.LocalFilesystemProvider,
					},
				},
				VirtualPath: "/local",
			},
		},
	}

	connID := xid.New().String()
	rootFs, err := user.GetFilesystemForPath("/", connID)
	require.NoError(t, err)

	localFs, err := user.GetFilesystemForPath("/local", connID)
	require.NoError(t, err)

	tests := []struct {
		name        string
		fs          vfs.Fs
		inputPath   string // The physical path to reverse-map
		expectedRel string // The expected virtual path
	}{
		{
			name:        "Root FS - Inside root",
			fs:          rootFs,
			inputPath:   filepath.Join(homeDir, "docs", "file.txt"),
			expectedRel: "/docs/file.txt",
		},
		{
			name:        "Root FS - Exact root directory",
			fs:          rootFs,
			inputPath:   homeDir,
			expectedRel: "/",
		},
		{
			name:        "Root FS - External absolute path (Jail to /)",
			fs:          rootFs,
			inputPath:   "/etc/passwd",
			expectedRel: "/",
		},
		{
			name:        "Root FS - Traversal escape (Jail to /)",
			fs:          rootFs,
			inputPath:   filepath.Join(homeDir, "..", "escaped.txt"),
			expectedRel: "/",
		},
		{
			name:        "Root FS - Valid file named with triple dots",
			fs:          rootFs,
			inputPath:   filepath.Join(homeDir, "..."),
			expectedRel: "/...",
		},
		{
			name:        "Local FS - Up path in dir",
			fs:          rootFs,
			inputPath:   homeDir + "/../" + filepath.Base(homeDir) + "/dir/test.txt",
			expectedRel: "/dir/test.txt",
		},

		{
			name:        "Local FS - Inside mount",
			fs:          localFs,
			inputPath:   filepath.Join(localVdir, "data", "config.json"),
			expectedRel: "/local/data/config.json",
		},
		{
			name:        "Local FS - Exact mount directory",
			fs:          localFs,
			inputPath:   localVdir,
			expectedRel: "/local",
		},
		{
			name:        "Local FS - External absolute path (Jail to /local)",
			fs:          localFs,
			inputPath:   "/var/log/syslog",
			expectedRel: "/local",
		},
		{
			name:        "Local FS - Traversal escape (Jail to /local)",
			fs:          localFs,
			inputPath:   filepath.Join(localVdir, "..", "..", "etc", "passwd"),
			expectedRel: "/local",
		},
		{
			// a sibling dir sharing the mount dir as a string prefix must not match
			name:        "Local FS - Partial prefix (Jail to /local)",
			fs:          localFs,
			inputPath:   localVdir + "_backup",
			expectedRel: "/local",
		},
		{
			name:        "Local FS - Relative traversal matching virual dir",
			fs:          localFs,
			inputPath:   localVdir + "/../" + filepath.Base(localVdir) + "/dir/test.txt",
			expectedRel: "/local/dir/test.txt",
		},
		{
			name:        "Local FS - Valid file starting with two dots",
			fs:          localFs,
			inputPath:   filepath.Join(localVdir, "..hidden_file.txt"),
			expectedRel: "/local/..hidden_file.txt",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			actualRel := tc.fs.GetRelativePath(tc.inputPath)
			assert.Equal(t, tc.expectedRel, actualRel,
				"Failed mapping physical path %q on FS %q", tc.inputPath, tc.fs.Name())
		})
	}
}
+ +package common + +import ( + "errors" + "fmt" + "io" + "os" + "path" + "sync" + "time" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +var ( + // RetentionChecks is the list of active retention checks + RetentionChecks ActiveRetentionChecks +) + +// ActiveRetentionChecks holds the active retention checks +type ActiveRetentionChecks struct { + sync.RWMutex + Checks []RetentionCheck +} + +// Get returns the active retention checks +func (c *ActiveRetentionChecks) Get(role string) []RetentionCheck { + c.RLock() + defer c.RUnlock() + + checks := make([]RetentionCheck, 0, len(c.Checks)) + for _, check := range c.Checks { + if role == "" || role == check.Role { + foldersCopy := make([]dataprovider.FolderRetention, len(check.Folders)) + copy(foldersCopy, check.Folders) + checks = append(checks, RetentionCheck{ + Username: check.Username, + StartTime: check.StartTime, + Folders: foldersCopy, + }) + } + } + return checks +} + +// Add a new retention check, returns nil if a retention check for the given +// username is already active. 
// Add registers a new retention check for the given user and returns a copy
// that can be used to start it. It returns nil if a retention check for the
// same username is already active.
func (c *ActiveRetentionChecks) Add(check RetentionCheck, user *dataprovider.User) *RetentionCheck {
	c.Lock()
	defer c.Unlock()

	// at most one active check per username
	for _, val := range c.Checks {
		if val.Username == user.Username {
			return nil
		}
	}
	// we silently ignore file patterns
	user.Filters.FilePatterns = nil
	conn := NewBaseConnection("", "", "", "", *user)
	conn.SetProtocol(ProtocolDataRetention)
	conn.ID = fmt.Sprintf("data_retention_%v", user.Username)
	check.Username = user.Username
	check.Role = user.Role
	check.StartTime = util.GetTimeAsMsSinceEpoch(time.Now())
	check.conn = conn
	check.updateUserPermissions()
	c.Checks = append(c.Checks, check)

	return &check
}

// remove a user from the ones with active retention checks
// and returns true if the user is removed
func (c *ActiveRetentionChecks) remove(username string) bool {
	c.Lock()
	defer c.Unlock()

	for idx, check := range c.Checks {
		if check.Username == username {
			// swap-with-last removal; ordering of active checks is not meaningful
			lastIdx := len(c.Checks) - 1
			c.Checks[idx] = c.Checks[lastIdx]
			c.Checks = c.Checks[:lastIdx]
			return true
		}
	}

	return false
}

// folderRetentionCheckResult stores the outcome of the retention check for a
// single folder.
type folderRetentionCheckResult struct {
	Path         string        `json:"path"`
	Retention    int           `json:"retention"`
	DeletedFiles int           `json:"deleted_files"`
	DeletedSize  int64         `json:"deleted_size"`
	Elapsed      time.Duration `json:"-"`
	Info         string        `json:"info,omitempty"`
	Error        string        `json:"error,omitempty"`
}

// RetentionCheck defines an active retention check
type RetentionCheck struct {
	// Username to which the retention check refers
	Username string `json:"username"`
	// retention check start time as unix timestamp in milliseconds
	StartTime int64 `json:"start_time"`
	// affected folders
	Folders []dataprovider.FolderRetention `json:"folders"`
	Role    string                         `json:"-"`
	// Cleanup results
	results []folderRetentionCheckResult `json:"-"`
	conn    *BaseConnection              `json:"-"`
}

// updateUserPermissions grants every permission on all the user's virtual
// directories so the check can list and delete files regardless of the
// user's configured permissions.
func (c *RetentionCheck) updateUserPermissions() {
	for k := range c.conn.User.Permissions {
		c.conn.User.Permissions[k] = []string{dataprovider.PermAny}
	}
}

// getFolderRetention walks from folderPath up towards the root and returns the
// retention settings of the closest configured ancestor folder, or an error if
// none of the ancestors (including folderPath itself) is configured.
func (c *RetentionCheck) getFolderRetention(folderPath string) (dataprovider.FolderRetention, error) {
	dirsForPath := util.GetDirsForVirtualPath(folderPath)
	for _, dirPath := range dirsForPath {
		for _, folder := range c.Folders {
			if folder.Path == dirPath {
				return folder, nil
			}
		}
	}

	return dataprovider.FolderRetention{}, fmt.Errorf("unable to find folder retention for %q", folderPath)
}

// removeFile resolves virtualPath on the proper filesystem and deletes it
// through the connection, so the usual quota/event handling applies.
func (c *RetentionCheck) removeFile(virtualPath string, info os.FileInfo) error {
	fs, fsPath, err := c.conn.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return err
	}
	return c.conn.RemoveFile(fs, fsPath, virtualPath, info)
}

// cleanupFolder deletes the files whose retention expired inside folderPath
// and recurses into subfolders. A result entry is always appended to
// c.results via the deferred append, even on failure, so partial progress is
// reported.
func (c *RetentionCheck) cleanupFolder(folderPath string, recursion int) error {
	startTime := time.Now()
	result := folderRetentionCheckResult{
		Path: folderPath,
	}
	defer func() {
		c.results = append(c.results, result)
	}()
	// guard against pathological/looping directory trees
	if recursion >= util.MaxRecursion {
		result.Elapsed = time.Since(startTime)
		result.Info = "data retention check skipped: recursion too deep"
		c.conn.Log(logger.LevelError, "data retention check skipped, recursion too depth for %q: %d",
			folderPath, recursion)
		return util.ErrRecursionTooDeep
	}
	recursion++

	folderRetention, err := c.getFolderRetention(folderPath)
	if err != nil {
		result.Elapsed = time.Since(startTime)
		result.Error = "unable to get folder retention"
		c.conn.Log(logger.LevelError, "unable to get folder retention for path %q", folderPath)
		return err
	}
	result.Retention = folderRetention.Retention
	// retention == 0 disables the check for this subtree's matched config
	if folderRetention.Retention == 0 {
		result.Elapsed = time.Since(startTime)
		result.Info = "data retention check skipped: retention is set to 0"
		c.conn.Log(logger.LevelDebug, "retention check skipped for folder %q, retention is set to 0", folderPath)
		return nil
	}
	c.conn.Log(logger.LevelDebug, "start retention check for folder %q, retention: %v hours, delete empty dirs? %v",
		folderPath, folderRetention.Retention, folderRetention.DeleteEmptyDirs)
	lister, err := c.conn.ListDir(folderPath)
	if err != nil {
		result.Elapsed = time.Since(startTime)
		// a missing folder is not an error: nothing to clean up
		if err == c.conn.GetNotExistError() {
			result.Info = "data retention check skipped, folder does not exist"
			c.conn.Log(logger.LevelDebug, "folder %q does not exist, retention check skipped", folderPath)
			return nil
		}
		result.Error = fmt.Sprintf("unable to get lister for directory %q", folderPath)
		c.conn.Log(logger.LevelError, "%s", result.Error)
		return err
	}
	defer lister.Close()

	for {
		files, err := lister.Next(vfs.ListerBatchSize)
		finished := errors.Is(err, io.EOF)
		// NOTE(review): assumes lister.convertError maps io.EOF to nil so the
		// final batch is still processed below — confirm against the vfs
		// lister implementation.
		if err := lister.convertError(err); err != nil {
			result.Elapsed = time.Since(startTime)
			result.Error = fmt.Sprintf("unable to list directory %q", folderPath)
			c.conn.Log(logger.LevelError, "unable to list dir %q: %v", folderPath, err)
			return err
		}
		for _, info := range files {
			virtualPath := path.Join(folderPath, info.Name())
			if info.IsDir() {
				if err := c.cleanupFolder(virtualPath, recursion); err != nil {
					result.Elapsed = time.Since(startTime)
					result.Error = fmt.Sprintf("unable to check folder: %v", err)
					c.conn.Log(logger.LevelError, "unable to cleanup folder %q: %v", virtualPath, err)
					return err
				}
			} else {
				// file expires when modTime + retention hours is in the past
				retentionTime := info.ModTime().Add(time.Duration(folderRetention.Retention) * time.Hour)
				if retentionTime.Before(time.Now()) {
					if err := c.removeFile(virtualPath, info); err != nil {
						result.Elapsed = time.Since(startTime)
						result.Error = fmt.Sprintf("unable to remove file %q: %v", virtualPath, err)
						c.conn.Log(logger.LevelError, "unable to remove file %q, retention %v: %v",
							virtualPath, retentionTime, err)
						return err
					}
					c.conn.Log(logger.LevelDebug, "removed file %q, modification time: %v, retention: %v hours, retention time: %v",
						virtualPath, info.ModTime(), folderRetention.Retention, retentionTime)
					result.DeletedFiles++
					result.DeletedSize += info.Size()
				}
			}
		}
		if finished {
			break
		}
	}

	lister.Close()
	c.checkEmptyDirRemoval(folderPath, folderRetention.DeleteEmptyDirs)
	result.Elapsed = time.Since(startTime)
	c.conn.Log(logger.LevelDebug, "retention check completed for folder %q, deleted files: %v, deleted size: %v bytes",
		folderPath, result.DeletedFiles, result.DeletedSize)

	return nil
}

// checkEmptyDirRemoval deletes folderPath if it is now empty, removal is
// enabled, the folder is not itself a configured retention root and the user
// has a delete permission on the parent. Failures are logged, not returned:
// empty-dir removal is best effort.
func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string, checkVal bool) {
	if folderPath == "/" || !checkVal {
		return
	}
	for _, folder := range c.Folders {
		if folderPath == folder.Path {
			return
		}
	}
	if c.conn.User.HasAnyPerm([]string{
		dataprovider.PermDelete,
		dataprovider.PermDeleteDirs,
	}, path.Dir(folderPath),
	) {
		lister, err := c.conn.ListDir(folderPath)
		if err == nil {
			files, err := lister.Next(1)
			lister.Close()
			// empty means: zero entries and the lister is exhausted
			if len(files) == 0 && errors.Is(err, io.EOF) {
				err = c.conn.RemoveDir(folderPath)
				c.conn.Log(logger.LevelDebug, "tried to remove empty dir %q, error: %v", folderPath, err)
			}
		}
	}
}

// Start starts the retention check
// It runs the cleanup for every configured folder with a positive retention,
// stops at the first failure, and always deregisters the check and closes the
// filesystem on return.
func (c *RetentionCheck) Start() error {
	c.conn.Log(logger.LevelInfo, "retention check started")
	defer RetentionChecks.remove(c.conn.User.Username)
	defer c.conn.CloseFS() //nolint:errcheck

	startTime := time.Now()
	for _, folder := range c.Folders {
		if folder.Retention > 0 {
			if err := c.cleanupFolder(folder.Path, 0); err != nil {
				c.conn.Log(logger.LevelError, "retention check failed, unable to cleanup folder %q, elapsed: %s",
					folder.Path, time.Since(startTime))
				return err
			}
		}
	}

	c.conn.Log(logger.LevelInfo, "retention check completed, elapsed: %s", time.Since(startTime))
	return nil
}
package common

import (
	"fmt"
	"testing"

	"github.com/sftpgo/sdk"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// TestRetentionPermissionsAndGetFolder verifies that updateUserPermissions
// replaces every per-directory permission with PermAny and that
// getFolderRetention resolves a path to the retention config of its closest
// configured ancestor.
func TestRetentionPermissionsAndGetFolder(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "user1",
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDelete}
	user.Permissions["/dir1"] = []string{dataprovider.PermListItems}
	user.Permissions["/dir2/sub1"] = []string{dataprovider.PermCreateDirs}
	user.Permissions["/dir2/sub2"] = []string{dataprovider.PermDelete}

	check := RetentionCheck{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "/dir2",
				Retention: 24 * 7,
			},
			{
				Path:      "/dir3",
				Retention: 24 * 7,
			},
			{
				Path:      "/dir2/sub1/sub",
				Retention: 24,
			},
		},
	}

	conn := NewBaseConnection("", "", "", "", user)
	conn.SetProtocol(ProtocolDataRetention)
	conn.ID = fmt.Sprintf("data_retention_%v", user.Username)
	check.conn = conn
	check.updateUserPermissions()
	// every configured directory must now grant PermAny
	assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/"])
	assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir1"])
	assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir2/sub1"])
	assert.Equal(t, []string{dataprovider.PermAny}, conn.User.Permissions["/dir2/sub2"])

	// "/" has no configured ancestor -> error
	_, err := check.getFolderRetention("/")
	assert.Error(t, err)
	folder, err := check.getFolderRetention("/dir3")
	assert.NoError(t, err)
	assert.Equal(t, "/dir3", folder.Path)
	// unconfigured subdirs fall back to the nearest configured ancestor
	folder, err = check.getFolderRetention("/dir2/sub3")
	assert.NoError(t, err)
	assert.Equal(t, "/dir2", folder.Path)
	folder, err = check.getFolderRetention("/dir2/sub2")
	assert.NoError(t, err)
	assert.Equal(t, "/dir2", folder.Path)
	folder, err = check.getFolderRetention("/dir2/sub1")
	assert.NoError(t, err)
	assert.Equal(t, "/dir2", folder.Path)
	// the deepest configured ancestor wins
	folder, err = check.getFolderRetention("/dir2/sub1/sub/sub")
	assert.NoError(t, err)
	assert.Equal(t, "/dir2/sub1/sub", folder.Path)
}

// TestRetentionCheckAddRemove verifies that a check can be added once per
// username, that Get returns a populated snapshot, and that remove works
// exactly once.
func TestRetentionCheckAddRemove(t *testing.T) {
	username := "username"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	check := RetentionCheck{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "/",
				Retention: 48,
			},
		},
	}
	assert.NotNil(t, RetentionChecks.Add(check, &user))
	checks := RetentionChecks.Get("")
	require.Len(t, checks, 1)
	assert.Equal(t, username, checks[0].Username)
	assert.Greater(t, checks[0].StartTime, int64(0))
	require.Len(t, checks[0].Folders, 1)
	assert.Equal(t, check.Folders[0].Path, checks[0].Folders[0].Path)
	assert.Equal(t, check.Folders[0].Retention, checks[0].Folders[0].Retention)

	// a second Add for the same username must be rejected
	assert.Nil(t, RetentionChecks.Add(check, &user))
	assert.True(t, RetentionChecks.remove(username))
	require.Len(t, RetentionChecks.Get(""), 0)
	// removing an already-removed user returns false
	assert.False(t, RetentionChecks.remove(username))
}

// TestRetentionCheckRole verifies the role filter of Get: an empty role
// matches everything, a matching role returns the check, a different role
// returns nothing, and the returned snapshot does not expose the role.
func TestRetentionCheckRole(t *testing.T) {
	username := "retuser"
	role1 := "retrole1"
	role2 := "retrole2"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Role:     role1,
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	check := RetentionCheck{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "/",
				Retention: 48,
			},
		},
	}
	assert.NotNil(t, RetentionChecks.Add(check, &user))
	checks := RetentionChecks.Get("")
	require.Len(t, checks, 1)
	// Role is internal and must not leak into the snapshot
	assert.Empty(t, checks[0].Role)
	checks = RetentionChecks.Get(role1)
	require.Len(t, checks, 1)
	checks = RetentionChecks.Get(role2)
	require.Len(t, checks, 0)
	user.Role = ""
	assert.Nil(t, RetentionChecks.Add(check, &user))
	assert.True(t, RetentionChecks.remove(username))
	require.Len(t, RetentionChecks.Get(""), 0)
}

// TestCleanupErrors exercises the error paths of removeFile and cleanupFolder,
// including the recursion-depth guard.
func TestCleanupErrors(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "u",
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = []string{dataprovider.PermAny}
	check := &RetentionCheck{
		Folders: []dataprovider.FolderRetention{
			{
				Path:      "/path",
				Retention: 48,
			},
		},
	}
	check = RetentionChecks.Add(*check, &user)
	require.NotNil(t, check)

	err := check.removeFile("missing file", nil)
	assert.Error(t, err)

	// "/" has no configured retention -> getFolderRetention fails
	err = check.cleanupFolder("/", 0)
	assert.Error(t, err)

	// recursion counter above util.MaxRecursion must short-circuit
	err = check.cleanupFolder("/", 1000)
	assert.ErrorIs(t, err, util.ErrRecursionTooDeep)

	assert.True(t, RetentionChecks.remove(user.Username))
}
package common

import (
	"fmt"
	"time"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
)

// HostEvent is the enumerable for the supported host events
type HostEvent string

// Supported host events
const (
	HostEventLoginFailed   HostEvent = "LoginFailed"
	HostEventUserNotFound  HostEvent = "UserNotFound"
	HostEventNoLoginTried  HostEvent = "NoLoginTried"
	HostEventLimitExceeded HostEvent = "LimitExceeded"
)

// Supported defender drivers
const (
	DefenderDriverMemory   = "memory"
	DefenderDriverProvider = "provider"
)

var (
	// supportedDefenderDrivers lists the accepted values for DefenderConfig.Driver
	supportedDefenderDrivers = []string{DefenderDriverMemory, DefenderDriverProvider}
)

// Defender defines the interface that a defender must implements
type Defender interface {
	GetHosts() ([]dataprovider.DefenderEntry, error)
	GetHost(ip string) (dataprovider.DefenderEntry, error)
	AddEvent(ip, protocol string, event HostEvent) bool
	IsBanned(ip, protocol string) bool
	IsSafe(ip, protocol string) bool
	GetBanTime(ip string) (*time.Time, error)
	GetScore(ip string) (int, error)
	DeleteHost(ip string) bool
	DelayLogin(err error)
}

// DefenderConfig defines the "defender" configuration
type DefenderConfig struct {
	// Set to true to enable the defender
	Enabled bool `json:"enabled" mapstructure:"enabled"`
	// Defender implementation to use, we support "memory" and "provider".
	// Using "provider" as driver you can share the defender events among
	// multiple SFTPGo instances. For a single instance "memory" provider will
	// be much faster
	Driver string `json:"driver" mapstructure:"driver"`
	// BanTime is the number of minutes that a host is banned
	BanTime int `json:"ban_time" mapstructure:"ban_time"`
	// Percentage increase of the ban time if a banned host tries to connect again
	BanTimeIncrement int `json:"ban_time_increment" mapstructure:"ban_time_increment"`
	// Threshold value for banning a client
	Threshold int `json:"threshold" mapstructure:"threshold"`
	// Score for invalid login attempts, eg. non-existent user accounts
	ScoreInvalid int `json:"score_invalid" mapstructure:"score_invalid"`
	// Score for valid login attempts, eg. user accounts that exist
	ScoreValid int `json:"score_valid" mapstructure:"score_valid"`
	// Score for limit exceeded events, generated from the rate limiters or for max connections
	// per-host exceeded
	ScoreLimitExceeded int `json:"score_limit_exceeded" mapstructure:"score_limit_exceeded"`
	// ScoreNoAuth defines the score for clients disconnected without authentication
	// attempts
	ScoreNoAuth int `json:"score_no_auth" mapstructure:"score_no_auth"`
	// Defines the time window, in minutes, for tracking client errors.
	// A host is banned if it has exceeded the defined threshold during
	// the last observation time minutes
	ObservationTime int `json:"observation_time" mapstructure:"observation_time"`
	// The number of banned IPs and host scores kept in memory will vary between the
	// soft and hard limit for the "memory" driver. For the "provider" driver the
	// soft limit is ignored and the hard limit is used to limit the number of entries
	// to return when you request for the entire host list from the defender
	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
	// Configuration to impose a delay between login attempts
	LoginDelay LoginDelay `json:"login_delay" mapstructure:"login_delay"`
}

// LoginDelay defines the delays to impose between login attempts.
type LoginDelay struct {
	// The number of milliseconds to pause prior to allowing a successful login
	Success int `json:"success" mapstructure:"success"`
	// The number of milliseconds to pause prior to reporting a failed login
	PasswordFailed int `json:"password_failed" mapstructure:"password_failed"`
}

// baseDefender provides the shared state and behavior for the concrete
// defender implementations (memory and provider drivers).
type baseDefender struct {
	config *DefenderConfig
	ipList *dataprovider.IPList
}

// isBanned reports whether ip is statically deny-listed for the given
// protocol. Lookup errors are treated as "not banned".
func (d *baseDefender) isBanned(ip, protocol string) bool {
	isListed, mode, err := d.ipList.IsListed(ip, protocol)
	if err != nil {
		return false
	}
	if isListed && mode == dataprovider.ListModeDeny {
		return true
	}

	return false
}

// IsSafe reports whether ip is allow-listed for the given protocol and so
// exempt from defender scoring/banning.
func (d *baseDefender) IsSafe(ip, protocol string) bool {
	isListed, mode, err := d.ipList.IsListed(ip, protocol)
	if err == nil && isListed && mode == dataprovider.ListModeAllow {
		return true
	}
	return false
}

// getScore maps a host event to its configured score; unknown events score 0.
func (d *baseDefender) getScore(event HostEvent) int {
	var score int

	switch event {
	case HostEventLoginFailed:
		score = d.config.ScoreValid
	case HostEventLimitExceeded:
		score = d.config.ScoreLimitExceeded
	case HostEventUserNotFound:
		score = d.config.ScoreInvalid
	case HostEventNoLoginTried:
		score = d.config.ScoreNoAuth
	}
	return score
}

// logEvent logs a defender event that changes a host's score
func (d *baseDefender) logEvent(ip, protocol string, event HostEvent, totalScore int) {
	// ignore events which do not change the host score
	eventScore := d.getScore(event)
	if eventScore == 0 {
		return
	}

	logger.GetLogger().Debug().
		Timestamp().
		Str("sender", "defender").
		Str("client_ip", ip).
		Str("protocol", protocol).
		Str("event", string(event)).
		Int("increase_score_by", eventScore).
		Int("score", totalScore).
		Send()
}

// logBan logs a host's ban due to a too high host score
func (d *baseDefender) logBan(ip, protocol string) {
	logger.GetLogger().Info().
		Timestamp().
		Str("sender", "defender").
		Str("client_ip", ip).
		Str("protocol", protocol).
		Str("event", "banned").
		Send()
}

// DelayLogin applies the configured login delay.
// A nil err means the login succeeded and the success delay is used,
// otherwise the password-failed delay applies. Zero delays are no-ops.
func (d *baseDefender) DelayLogin(err error) {
	if err == nil {
		if d.config.LoginDelay.Success > 0 {
			time.Sleep(time.Duration(d.config.LoginDelay.Success) * time.Millisecond)
		}
		return
	}
	if d.config.LoginDelay.PasswordFailed > 0 {
		time.Sleep(time.Duration(d.config.LoginDelay.PasswordFailed) * time.Millisecond)
	}
}

// hostEvent is a single scored event with its timestamp, used to expire
// events outside the observation window.
type hostEvent struct {
	dateTime time.Time
	score    int
}

// hostScore aggregates the events and the running total score for a host.
type hostScore struct {
	TotalScore int
	Events     []hostEvent
}

// checkScores clamps negative scores to zero and rejects a configuration
// where every score is disabled (the defender would never ban anyone).
func (c *DefenderConfig) checkScores() error {
	if c.ScoreInvalid < 0 {
		c.ScoreInvalid = 0
	}
	if c.ScoreValid < 0 {
		c.ScoreValid = 0
	}
	if c.ScoreLimitExceeded < 0 {
		c.ScoreLimitExceeded = 0
	}
	if c.ScoreNoAuth < 0 {
		c.ScoreNoAuth = 0
	}
	if c.ScoreInvalid == 0 && c.ScoreValid == 0 && c.ScoreLimitExceeded == 0 && c.ScoreNoAuth == 0 {
		return fmt.Errorf("invalid defender configuration: all scores are disabled")
	}
	return nil
}

// validate returns an error if the configuration is invalid
// A disabled defender is always valid. Each single-event score must stay
// below the ban threshold so one event alone can never ban a host.
// NOTE(review): the checks use >= while the error messages say "cannot be
// greater than" — the enforced rule is "must be strictly less than".
func (c *DefenderConfig) validate() error {
	if !c.Enabled {
		return nil
	}
	if err := c.checkScores(); err != nil {
		return err
	}
	if c.ScoreInvalid >= c.Threshold {
		return fmt.Errorf("score_invalid %d cannot be greater than threshold %d", c.ScoreInvalid, c.Threshold)
	}
	if c.ScoreValid >= c.Threshold {
		return fmt.Errorf("score_valid %d cannot be greater than threshold %d", c.ScoreValid, c.Threshold)
	}
	if c.ScoreLimitExceeded >= c.Threshold {
		return fmt.Errorf("score_limit_exceeded %d cannot be greater than threshold %d", c.ScoreLimitExceeded, c.Threshold)
	}
	if c.ScoreNoAuth >= c.Threshold {
		return fmt.Errorf("score_no_auth %d cannot be greater than threshold %d", c.ScoreNoAuth, c.Threshold)
	}
	if c.BanTime <= 0 {
		return fmt.Errorf("invalid ban_time %v", c.BanTime)
	}
	if c.BanTimeIncrement <= 0 {
		return fmt.Errorf("invalid ban_time_increment %v", c.BanTimeIncrement)
	}
	if c.ObservationTime <= 0 {
		return fmt.Errorf("invalid observation_time %v", c.ObservationTime)
	}
	if c.EntriesSoftLimit <= 0 {
		return fmt.Errorf("invalid entries_soft_limit %v", c.EntriesSoftLimit)
	}
	if c.EntriesHardLimit <= c.EntriesSoftLimit {
		return fmt.Errorf("invalid entries_hard_limit %v must be > %v", c.EntriesHardLimit, c.EntriesSoftLimit)
	}

	return nil
}
+ +package common + +import ( + "encoding/hex" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/yl2chen/cidranger" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" +) + +func TestBasicDefender(t *testing.T) { + entries := []dataprovider.IPListEntry{ + { + IPOrNet: "172.16.1.1/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + }, + { + IPOrNet: "172.16.1.2/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + }, + { + IPOrNet: "10.8.0.0/24", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + }, + { + IPOrNet: "192.168.1.1/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + }, + { + IPOrNet: "192.168.1.2/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + }, + { + IPOrNet: "10.8.9.0/24", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + }, + { + IPOrNet: "172.16.1.3/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + { + IPOrNet: "172.16.1.4/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + { + IPOrNet: "192.168.8.0/24", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + { + IPOrNet: "192.168.1.3/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + { + IPOrNet: "192.168.1.4/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + { + IPOrNet: "192.168.9.0/24", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + } + + for idx := range entries { + e := entries[idx] + err := dataprovider.AddIPListEntry(&e, "", "", "") + assert.NoError(t, err) + } + + config := &DefenderConfig{ + Enabled: true, + BanTime: 10, + BanTimeIncrement: 2, + Threshold: 5, + ScoreInvalid: 2, + ScoreValid: 1, + ScoreNoAuth: 2, + 
ScoreLimitExceeded: 3, + ObservationTime: 15, + EntriesSoftLimit: 1, + EntriesHardLimit: 2, + } + + d, err := newInMemoryDefender(config) + assert.NoError(t, err) + + defender := d.(*memoryDefender) + assert.True(t, defender.IsBanned("172.16.1.1", ProtocolSSH)) + assert.True(t, defender.IsBanned("192.168.1.1", ProtocolFTP)) + assert.False(t, defender.IsBanned("172.16.1.10", ProtocolSSH)) + assert.False(t, defender.IsBanned("192.168.1.10", ProtocolSSH)) + assert.False(t, defender.IsBanned("10.8.2.3", ProtocolSSH)) + assert.False(t, defender.IsBanned("10.9.2.3", ProtocolSSH)) + assert.True(t, defender.IsBanned("10.8.0.3", ProtocolSSH)) + assert.True(t, defender.IsBanned("10.8.9.3", ProtocolSSH)) + assert.False(t, defender.IsBanned("invalid ip", ProtocolSSH)) + assert.Equal(t, 0, defender.countBanned()) + assert.Equal(t, 0, defender.countHosts()) + hosts, err := defender.GetHosts() + assert.NoError(t, err) + assert.Len(t, hosts, 0) + _, err = defender.GetHost("10.8.0.4") + assert.Error(t, err) + + defender.AddEvent("172.16.1.4", ProtocolSSH, HostEventLoginFailed) + defender.AddEvent("192.168.1.4", ProtocolSSH, HostEventLoginFailed) + defender.AddEvent("192.168.8.4", ProtocolSSH, HostEventUserNotFound) + defender.AddEvent("172.16.1.3", ProtocolSSH, HostEventLimitExceeded) + defender.AddEvent("192.168.1.3", ProtocolSSH, HostEventLimitExceeded) + assert.Equal(t, 0, defender.countHosts()) + + testIP := "12.34.56.78" + defender.AddEvent(testIP, ProtocolSSH, HostEventLoginFailed) + assert.Equal(t, 1, defender.countHosts()) + assert.Equal(t, 0, defender.countBanned()) + score, err := defender.GetScore(testIP) + assert.NoError(t, err) + assert.Equal(t, 1, score) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + assert.Equal(t, 1, hosts[0].Score) + assert.True(t, hosts[0].BanTime.IsZero()) + assert.Empty(t, hosts[0].GetBanTime()) + } + host, err := defender.GetHost(testIP) + assert.NoError(t, err) + assert.Equal(t, 1, host.Score) + 
assert.Empty(t, host.GetBanTime()) + banTime, err := defender.GetBanTime(testIP) + assert.NoError(t, err) + assert.Nil(t, banTime) + defender.AddEvent(testIP, ProtocolSSH, HostEventLimitExceeded) + assert.Equal(t, 1, defender.countHosts()) + assert.Equal(t, 0, defender.countBanned()) + score, err = defender.GetScore(testIP) + assert.NoError(t, err) + assert.Equal(t, 4, score) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + assert.Equal(t, 4, hosts[0].Score) + assert.True(t, hosts[0].BanTime.IsZero()) + assert.Empty(t, hosts[0].GetBanTime()) + } + defender.AddEvent(testIP, ProtocolSSH, HostEventUserNotFound) + defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried) + assert.Equal(t, 0, defender.countHosts()) + assert.Equal(t, 1, defender.countBanned()) + score, err = defender.GetScore(testIP) + assert.NoError(t, err) + assert.Equal(t, 0, score) + banTime, err = defender.GetBanTime(testIP) + assert.NoError(t, err) + assert.NotNil(t, banTime) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + assert.Equal(t, 0, hosts[0].Score) + assert.False(t, hosts[0].BanTime.IsZero()) + assert.NotEmpty(t, hosts[0].GetBanTime()) + assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID()) + } + host, err = defender.GetHost(testIP) + assert.NoError(t, err) + assert.Equal(t, 0, host.Score) + assert.NotEmpty(t, host.GetBanTime()) + + // now test cleanup, testIP is already banned + testIP1 := "12.34.56.79" + testIP2 := "12.34.56.80" + testIP3 := "12.34.56.81" + + defender.AddEvent(testIP1, ProtocolSSH, HostEventNoLoginTried) + defender.AddEvent(testIP2, ProtocolSSH, HostEventNoLoginTried) + assert.Equal(t, 2, defender.countHosts()) + time.Sleep(20 * time.Millisecond) + defender.AddEvent(testIP3, ProtocolSSH, HostEventNoLoginTried) + assert.Equal(t, defender.config.EntriesSoftLimit, defender.countHosts()) + // testIP1 and testIP2 should be removed + assert.Equal(t, 
defender.config.EntriesSoftLimit, defender.countHosts()) + score, err = defender.GetScore(testIP1) + assert.NoError(t, err) + assert.Equal(t, 0, score) + score, err = defender.GetScore(testIP2) + assert.NoError(t, err) + assert.Equal(t, 0, score) + score, err = defender.GetScore(testIP3) + assert.NoError(t, err) + assert.Equal(t, 2, score) + + defender.AddEvent(testIP3, ProtocolSSH, HostEventNoLoginTried) + defender.AddEvent(testIP3, ProtocolSSH, HostEventNoLoginTried) + // IP3 is now banned + banTime, err = defender.GetBanTime(testIP3) + assert.NoError(t, err) + assert.NotNil(t, banTime) + assert.Equal(t, 0, defender.countHosts()) + + time.Sleep(20 * time.Millisecond) + for i := 0; i < 3; i++ { + defender.AddEvent(testIP1, ProtocolSSH, HostEventNoLoginTried) + } + assert.Equal(t, 0, defender.countHosts()) + assert.Equal(t, config.EntriesSoftLimit, defender.countBanned()) + banTime, err = defender.GetBanTime(testIP) + assert.NoError(t, err) + assert.Nil(t, banTime) + banTime, err = defender.GetBanTime(testIP3) + assert.NoError(t, err) + assert.Nil(t, banTime) + banTime, err = defender.GetBanTime(testIP1) + assert.NoError(t, err) + assert.NotNil(t, banTime) + + for i := 0; i < 3; i++ { + defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried) + time.Sleep(10 * time.Millisecond) + defender.AddEvent(testIP3, ProtocolSSH, HostEventNoLoginTried) + } + assert.Equal(t, 0, defender.countHosts()) + assert.Equal(t, defender.config.EntriesSoftLimit, defender.countBanned()) + + banTime, err = defender.GetBanTime(testIP3) + assert.NoError(t, err) + if assert.NotNil(t, banTime) { + assert.True(t, defender.IsBanned(testIP3, ProtocolFTP)) + // ban time should increase + newBanTime, err := defender.GetBanTime(testIP3) + assert.NoError(t, err) + assert.True(t, newBanTime.After(*banTime)) + } + + assert.True(t, defender.DeleteHost(testIP3)) + assert.False(t, defender.DeleteHost(testIP3)) + + for _, e := range entries { + err := dataprovider.DeleteIPListEntry(e.IPOrNet, e.Type, 
"", "", "") + assert.NoError(t, err) + } +} + +func TestExpiredHostBans(t *testing.T) { + config := &DefenderConfig{ + Enabled: true, + BanTime: 10, + BanTimeIncrement: 2, + Threshold: 5, + ScoreInvalid: 2, + ScoreValid: 1, + ScoreLimitExceeded: 3, + ObservationTime: 15, + EntriesSoftLimit: 1, + EntriesHardLimit: 2, + } + + d, err := newInMemoryDefender(config) + assert.NoError(t, err) + + defender := d.(*memoryDefender) + + testIP := "1.2.3.4" + defender.banned[testIP] = time.Now().Add(-24 * time.Hour) + + // the ban is expired testIP should not be listed + res, err := defender.GetHosts() + assert.NoError(t, err) + assert.Len(t, res, 0) + + assert.False(t, defender.IsBanned(testIP, ProtocolFTP)) + _, err = defender.GetHost(testIP) + assert.Error(t, err) + _, ok := defender.banned[testIP] + assert.True(t, ok) + // now add an event for an expired banned ip, it should be removed + defender.AddEvent(testIP, ProtocolFTP, HostEventLoginFailed) + assert.False(t, defender.IsBanned(testIP, ProtocolFTP)) + entry, err := defender.GetHost(testIP) + assert.NoError(t, err) + assert.Equal(t, testIP, entry.IP) + assert.Empty(t, entry.GetBanTime()) + assert.Equal(t, 1, entry.Score) + + res, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, res, 1) { + assert.Equal(t, testIP, res[0].IP) + assert.Empty(t, res[0].GetBanTime()) + assert.Equal(t, 1, res[0].Score) + } + + events := []hostEvent{ + { + dateTime: time.Now().Add(-24 * time.Hour), + score: 2, + }, + { + dateTime: time.Now().Add(-24 * time.Hour), + score: 3, + }, + } + + hs := hostScore{ + Events: events, + TotalScore: 5, + } + + defender.hosts[testIP] = hs + // the recorded scored are too old + res, err = defender.GetHosts() + assert.NoError(t, err) + assert.Len(t, res, 0) + _, err = defender.GetHost(testIP) + assert.Error(t, err) + _, ok = defender.hosts[testIP] + assert.True(t, ok) +} + +func TestDefenderCleanup(t *testing.T) { + d := memoryDefender{ + baseDefender: baseDefender{ + config: 
&DefenderConfig{ + ObservationTime: 1, + EntriesSoftLimit: 2, + EntriesHardLimit: 3, + }, + }, + banned: make(map[string]time.Time), + hosts: make(map[string]hostScore), + } + + d.banned["1.1.1.1"] = time.Now().Add(-24 * time.Hour) + d.banned["1.1.1.2"] = time.Now().Add(-24 * time.Hour) + d.banned["1.1.1.3"] = time.Now().Add(-24 * time.Hour) + d.banned["1.1.1.4"] = time.Now().Add(-24 * time.Hour) + + d.cleanupBanned() + assert.Equal(t, 0, d.countBanned()) + + d.banned["2.2.2.2"] = time.Now().Add(2 * time.Minute) + d.banned["2.2.2.3"] = time.Now().Add(1 * time.Minute) + d.banned["2.2.2.4"] = time.Now().Add(3 * time.Minute) + d.banned["2.2.2.5"] = time.Now().Add(4 * time.Minute) + + d.cleanupBanned() + assert.Equal(t, d.config.EntriesSoftLimit, d.countBanned()) + banTime, err := d.GetBanTime("2.2.2.3") + assert.NoError(t, err) + assert.Nil(t, banTime) + + d.hosts["3.3.3.3"] = hostScore{ + TotalScore: 0, + Events: []hostEvent{ + { + dateTime: time.Now().Add(-5 * time.Minute), + score: 1, + }, + { + dateTime: time.Now().Add(-3 * time.Minute), + score: 1, + }, + { + dateTime: time.Now(), + score: 1, + }, + }, + } + d.hosts["3.3.3.4"] = hostScore{ + TotalScore: 1, + Events: []hostEvent{ + { + dateTime: time.Now().Add(-3 * time.Minute), + score: 1, + }, + }, + } + d.hosts["3.3.3.5"] = hostScore{ + TotalScore: 1, + Events: []hostEvent{ + { + dateTime: time.Now().Add(-2 * time.Minute), + score: 1, + }, + }, + } + d.hosts["3.3.3.6"] = hostScore{ + TotalScore: 1, + Events: []hostEvent{ + { + dateTime: time.Now().Add(-1 * time.Minute), + score: 1, + }, + }, + } + + score, err := d.GetScore("3.3.3.3") + assert.NoError(t, err) + assert.Equal(t, 1, score) + + d.cleanupHosts() + assert.Equal(t, d.config.EntriesSoftLimit, d.countHosts()) + score, err = d.GetScore("3.3.3.4") + assert.NoError(t, err) + assert.Equal(t, 0, score) +} + +func TestDefenderDelay(t *testing.T) { + d := memoryDefender{ + baseDefender: baseDefender{ + config: &DefenderConfig{ + ObservationTime: 1, + 
EntriesSoftLimit: 2, + EntriesHardLimit: 3, + LoginDelay: LoginDelay{ + Success: 50, + PasswordFailed: 200, + }, + }, + }, + } + startTime := time.Now() + d.DelayLogin(nil) + elapsed := time.Since(startTime) + assert.Less(t, elapsed, time.Millisecond*100) + + startTime = time.Now() + d.DelayLogin(ErrInternalFailure) + elapsed = time.Since(startTime) + assert.Greater(t, elapsed, time.Millisecond*150) +} + +func TestDefenderConfig(t *testing.T) { + c := DefenderConfig{} + err := c.validate() + require.NoError(t, err) + + c.Enabled = true + c.Threshold = 10 + c.ScoreInvalid = 10 + err = c.validate() + require.Error(t, err) + + c.ScoreInvalid = 2 + c.ScoreLimitExceeded = 10 + err = c.validate() + require.Error(t, err) + + c.ScoreLimitExceeded = 2 + c.ScoreValid = 10 + err = c.validate() + require.Error(t, err) + + c.ScoreValid = 1 + c.ScoreNoAuth = 10 + err = c.validate() + require.Error(t, err) + + c.ScoreNoAuth = 2 + c.BanTime = 0 + err = c.validate() + require.Error(t, err) + + c.BanTime = 30 + c.BanTimeIncrement = 0 + err = c.validate() + require.Error(t, err) + + c.BanTimeIncrement = 50 + c.ObservationTime = 0 + err = c.validate() + require.Error(t, err) + + c.ObservationTime = 30 + err = c.validate() + require.Error(t, err) + + c.EntriesSoftLimit = 10 + err = c.validate() + require.Error(t, err) + + c.EntriesHardLimit = 10 + err = c.validate() + require.Error(t, err) + + c.EntriesHardLimit = 20 + err = c.validate() + require.NoError(t, err) + + c = DefenderConfig{ + Enabled: true, + ScoreInvalid: -1, + ScoreLimitExceeded: -1, + ScoreNoAuth: -1, + ScoreValid: -1, + } + err = c.validate() + require.Error(t, err) + assert.Equal(t, 0, c.ScoreInvalid) + assert.Equal(t, 0, c.ScoreValid) + assert.Equal(t, 0, c.ScoreLimitExceeded) + assert.Equal(t, 0, c.ScoreNoAuth) +} + +func BenchmarkDefenderBannedSearch(b *testing.B) { + d := getDefenderForBench() + + ip, ipnet, err := net.ParseCIDR("10.8.0.0/12") // 1048574 ip addresses + if err != nil { + panic(err) + } + + for ip 
:= ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { + d.banned[ip.String()] = time.Now().Add(10 * time.Minute) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + d.IsBanned("192.168.1.1", ProtocolSSH) + } +} + +func BenchmarkCleanup(b *testing.B) { + d := getDefenderForBench() + + ip, ipnet, err := net.ParseCIDR("192.168.4.0/24") + if err != nil { + panic(err) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { + d.AddEvent(ip.String(), ProtocolSSH, HostEventLoginFailed) + if d.countHosts() > d.config.EntriesHardLimit { + panic("too many hosts") + } + if d.countBanned() > d.config.EntriesSoftLimit { + panic("too many ip banned") + } + } + } +} + +func BenchmarkCIDRanger(b *testing.B) { + ranger := cidranger.NewPCTrieRanger() + for i := 0; i < 255; i++ { + cidr := fmt.Sprintf("192.168.%d.1/24", i) + _, network, _ := net.ParseCIDR(cidr) + if err := ranger.Insert(cidranger.NewBasicRangerEntry(*network)); err != nil { + panic(err) + } + } + + ipToMatch := net.ParseIP("192.167.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := ranger.Contains(ipToMatch); err != nil { + panic(err) + } + } +} + +func BenchmarkNetContains(b *testing.B) { + var nets []*net.IPNet + for i := 0; i < 255; i++ { + cidr := fmt.Sprintf("192.168.%d.1/24", i) + _, network, _ := net.ParseCIDR(cidr) + nets = append(nets, network) + } + + ipToMatch := net.ParseIP("192.167.1.1") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, n := range nets { + n.Contains(ipToMatch) + } + } +} + +func getDefenderForBench() *memoryDefender { + config := &DefenderConfig{ + Enabled: true, + BanTime: 30, + BanTimeIncrement: 50, + Threshold: 10, + ScoreInvalid: 2, + ScoreValid: 2, + ObservationTime: 30, + EntriesSoftLimit: 50, + EntriesHardLimit: 100, + } + return &memoryDefender{ + baseDefender: baseDefender{ + config: config, + }, + hosts: make(map[string]hostScore), + banned: make(map[string]time.Time), + } +} + +func 
inc(ip net.IP) { + for j := len(ip) - 1; j >= 0; j-- { + ip[j]++ + if ip[j] > 0 { + break + } + } +} diff --git a/internal/common/defenderdb.go b/internal/common/defenderdb.go new file mode 100644 index 00000000..63995862 --- /dev/null +++ b/internal/common/defenderdb.go @@ -0,0 +1,181 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "sync/atomic" + "time" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +type dbDefender struct { + baseDefender + lastCleanup atomic.Int64 +} + +func newDBDefender(config *DefenderConfig) (Defender, error) { + err := config.validate() + if err != nil { + return nil, err + } + ipList, err := dataprovider.NewIPList(dataprovider.IPListTypeDefender) + if err != nil { + return nil, err + } + defender := &dbDefender{ + baseDefender: baseDefender{ + config: config, + ipList: ipList, + }, + } + defender.lastCleanup.Store(0) + + return defender, nil +} + +// GetHosts returns hosts that are banned or for which some violations have been detected +func (d *dbDefender) GetHosts() ([]dataprovider.DefenderEntry, error) { + return dataprovider.GetDefenderHosts(d.getStartObservationTime(), d.config.EntriesHardLimit) +} + +// GetHost returns a defender host by ip, if any +func (d *dbDefender) GetHost(ip string) (dataprovider.DefenderEntry, error) { 
+ return dataprovider.GetDefenderHostByIP(ip, d.getStartObservationTime()) +} + +// IsBanned returns true if the specified IP is banned +// and increase ban time if the IP is found. +// This method must be called as soon as the client connects +func (d *dbDefender) IsBanned(ip, protocol string) bool { + if d.isBanned(ip, protocol) { + return true + } + + _, err := dataprovider.IsDefenderHostBanned(ip) + if err != nil { + // not found or another error, we allow this host + return false + } + increment := d.config.BanTime * d.config.BanTimeIncrement / 100 + if increment == 0 { + increment++ + } + dataprovider.UpdateDefenderBanTime(ip, increment) //nolint:errcheck + return true +} + +// DeleteHost removes the specified IP from the defender lists +func (d *dbDefender) DeleteHost(ip string) bool { + if _, err := d.GetHost(ip); err != nil { + return false + } + return dataprovider.DeleteDefenderHost(ip) == nil +} + +// AddEvent adds an event for the given IP. +// This method must be called for clients not yet banned. +// Returns true if the IP is in the defender's safe list. 
+func (d *dbDefender) AddEvent(ip, protocol string, event HostEvent) bool { + if d.IsSafe(ip, protocol) { + return true + } + + score := d.getScore(event) + + host, err := dataprovider.AddDefenderEvent(ip, score, d.getStartObservationTime()) + if err != nil { + return false + } + d.logEvent(ip, protocol, event, host.Score) + if host.Score > d.config.Threshold { + d.logBan(ip, protocol) + banTime := time.Now().Add(time.Duration(d.config.BanTime) * time.Minute) + err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(banTime)) + if err == nil { + eventManager.handleIPBlockedEvent(EventParams{ + Event: ipBlockedEventName, + IP: ip, + Timestamp: time.Now(), + Status: 1, + }) + } + } + + if err == nil { + d.cleanup() + } + return false +} + +// GetBanTime returns the ban time for the given IP or nil if the IP is not banned +func (d *dbDefender) GetBanTime(ip string) (*time.Time, error) { + host, err := d.GetHost(ip) + if err != nil { + return nil, err + } + if host.BanTime.IsZero() { + return nil, nil + } + return &host.BanTime, nil +} + +// GetScore returns the score for the given IP +func (d *dbDefender) GetScore(ip string) (int, error) { + host, err := d.GetHost(ip) + if err != nil { + return 0, err + } + return host.Score, nil +} + +func (d *dbDefender) cleanup() { + lastCleanup := d.getLastCleanup() + if lastCleanup.IsZero() || lastCleanup.Add(time.Duration(d.config.ObservationTime)*time.Minute*3).Before(time.Now()) { + // FIXME: this could be racy in rare cases but it is better than acquire the lock for the cleanup duration + // or to always acquire a read/write lock. 
+ // Concurrent cleanups could happen anyway from multiple SFTPGo instances and should not cause any issues + d.setLastCleanup(time.Now()) + expireTime := time.Now().Add(-time.Duration(d.config.ObservationTime+1) * time.Minute) + logger.Debug(logSender, "", "cleanup defender hosts before %v, last cleanup %v", expireTime, lastCleanup) + if err := dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(expireTime)); err != nil { + logger.Error(logSender, "", "defender cleanup error, reset last cleanup to %v", lastCleanup) + d.setLastCleanup(lastCleanup) + } + } +} + +func (d *dbDefender) getStartObservationTime() int64 { + t := time.Now().Add(-time.Duration(d.config.ObservationTime) * time.Minute) + return util.GetTimeAsMsSinceEpoch(t) +} + +func (d *dbDefender) getLastCleanup() time.Time { + val := d.lastCleanup.Load() + if val == 0 { + return time.Time{} + } + return util.GetTimeFromMsecSinceEpoch(val) +} + +func (d *dbDefender) setLastCleanup(when time.Time) { + if when.IsZero() { + d.lastCleanup.Store(0) + return + } + d.lastCleanup.Store(util.GetTimeAsMsSinceEpoch(when)) +} diff --git a/internal/common/defenderdb_test.go b/internal/common/defenderdb_test.go new file mode 100644 index 00000000..6cc60725 --- /dev/null +++ b/internal/common/defenderdb_test.go @@ -0,0 +1,320 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package common + +import ( + "encoding/hex" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func TestBasicDbDefender(t *testing.T) { + if !isDbDefenderSupported() { + t.Skip("this test is not supported with the current database provider") + } + entries := []dataprovider.IPListEntry{ + { + IPOrNet: "172.16.1.1/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + }, + { + IPOrNet: "172.16.1.2/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + }, + { + IPOrNet: "10.8.0.0/24", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + }, + { + IPOrNet: "172.16.1.3/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + { + IPOrNet: "172.16.1.4/32", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + { + IPOrNet: "192.168.8.0/24", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeAllow, + }, + } + + for idx := range entries { + e := entries[idx] + err := dataprovider.AddIPListEntry(&e, "", "", "") + assert.NoError(t, err) + } + + config := &DefenderConfig{ + Enabled: true, + BanTime: 10, + BanTimeIncrement: 2, + Threshold: 5, + ScoreInvalid: 2, + ScoreValid: 1, + ScoreNoAuth: 2, + ScoreLimitExceeded: 3, + ObservationTime: 15, + EntriesSoftLimit: 1, + EntriesHardLimit: 10, + } + d, err := newDBDefender(config) + assert.NoError(t, err) + defender := d.(*dbDefender) + assert.True(t, defender.IsBanned("172.16.1.1", ProtocolFTP)) + assert.False(t, defender.IsBanned("172.16.1.10", ProtocolSSH)) + assert.False(t, defender.IsBanned("10.8.1.3", ProtocolHTTP)) + assert.True(t, defender.IsBanned("10.8.0.4", ProtocolWebDAV)) + assert.False(t, defender.IsBanned("invalid ip", ProtocolSSH)) + hosts, err := defender.GetHosts() + assert.NoError(t, err) + assert.Len(t, hosts, 0) + _, err = 
defender.GetHost("10.8.0.3") + assert.Error(t, err) + + defender.AddEvent("172.16.1.4", ProtocolSSH, HostEventLoginFailed) + defender.AddEvent("192.168.8.4", ProtocolSSH, HostEventUserNotFound) + defender.AddEvent("172.16.1.3", ProtocolSSH, HostEventLimitExceeded) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + assert.Len(t, hosts, 0) + assert.True(t, defender.getLastCleanup().IsZero()) + + testIP := "123.45.67.89" + defender.AddEvent(testIP, ProtocolSSH, HostEventLoginFailed) + lastCleanup := defender.getLastCleanup() + assert.False(t, lastCleanup.IsZero()) + score, err := defender.GetScore(testIP) + assert.NoError(t, err) + assert.Equal(t, 1, score) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + assert.Equal(t, 1, hosts[0].Score) + assert.True(t, hosts[0].BanTime.IsZero()) + assert.Empty(t, hosts[0].GetBanTime()) + } + host, err := defender.GetHost(testIP) + assert.NoError(t, err) + assert.Equal(t, 1, host.Score) + assert.Empty(t, host.GetBanTime()) + banTime, err := defender.GetBanTime(testIP) + assert.NoError(t, err) + assert.Nil(t, banTime) + defender.AddEvent(testIP, ProtocolSSH, HostEventLimitExceeded) + score, err = defender.GetScore(testIP) + assert.NoError(t, err) + assert.Equal(t, 4, score) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + assert.Equal(t, 4, hosts[0].Score) + assert.True(t, hosts[0].BanTime.IsZero()) + assert.Empty(t, hosts[0].GetBanTime()) + } + defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried) + defender.AddEvent(testIP, ProtocolSSH, HostEventNoLoginTried) + score, err = defender.GetScore(testIP) + assert.NoError(t, err) + assert.Equal(t, 0, score) + banTime, err = defender.GetBanTime(testIP) + assert.NoError(t, err) + assert.NotNil(t, banTime) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + assert.Equal(t, 0, hosts[0].Score) + assert.False(t, hosts[0].BanTime.IsZero()) + 
assert.NotEmpty(t, hosts[0].GetBanTime()) + assert.Equal(t, hex.EncodeToString([]byte(testIP)), hosts[0].GetID()) + } + host, err = defender.GetHost(testIP) + assert.NoError(t, err) + assert.Equal(t, 0, host.Score) + assert.NotEmpty(t, host.GetBanTime()) + // ban time should increase + assert.True(t, defender.IsBanned(testIP, ProtocolSSH)) + newBanTime, err := defender.GetBanTime(testIP) + assert.NoError(t, err) + assert.True(t, newBanTime.After(*banTime)) + + assert.True(t, defender.DeleteHost(testIP)) + assert.False(t, defender.DeleteHost(testIP)) + // test cleanup + testIP1 := "123.45.67.90" + testIP2 := "123.45.67.91" + testIP3 := "123.45.67.92" + for i := 0; i < 3; i++ { + defender.AddEvent(testIP, ProtocolSSH, HostEventUserNotFound) + defender.AddEvent(testIP1, ProtocolSSH, HostEventNoLoginTried) + defender.AddEvent(testIP2, ProtocolSSH, HostEventUserNotFound) + } + hosts, err = defender.GetHosts() + assert.NoError(t, err) + assert.Len(t, hosts, 3) + for _, host := range hosts { + assert.Equal(t, 0, host.Score) + assert.False(t, host.BanTime.IsZero()) + assert.NotEmpty(t, host.GetBanTime()) + } + defender.AddEvent(testIP3, ProtocolSSH, HostEventLoginFailed) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + assert.Len(t, hosts, 4) + // now set a ban time in the past, so the host will be cleanead up + for _, ip := range []string{testIP1, testIP2} { + err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute))) + assert.NoError(t, err) + } + hosts, err = defender.GetHosts() + assert.NoError(t, err) + assert.Len(t, hosts, 4) + for _, host := range hosts { + switch host.IP { + case testIP: + assert.Equal(t, 0, host.Score) + assert.False(t, host.BanTime.IsZero()) + assert.NotEmpty(t, host.GetBanTime()) + case testIP3: + assert.Equal(t, 1, host.Score) + assert.True(t, host.BanTime.IsZero()) + assert.Empty(t, host.GetBanTime()) + default: + assert.Equal(t, 6, host.Score) + assert.True(t, host.BanTime.IsZero()) + 
assert.Empty(t, host.GetBanTime()) + } + } + host, err = defender.GetHost(testIP) + assert.NoError(t, err) + assert.Equal(t, 0, host.Score) + assert.False(t, host.BanTime.IsZero()) + assert.NotEmpty(t, host.GetBanTime()) + host, err = defender.GetHost(testIP3) + assert.NoError(t, err) + assert.Equal(t, 1, host.Score) + assert.True(t, host.BanTime.IsZero()) + assert.Empty(t, host.GetBanTime()) + // set a negative observation time so the from field in the queries will be in the future + // we still should get the banned hosts + defender.config.ObservationTime = -2 + assert.Greater(t, defender.getStartObservationTime(), time.Now().UnixMilli()) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + assert.Equal(t, testIP, hosts[0].IP) + assert.Equal(t, 0, hosts[0].Score) + assert.False(t, hosts[0].BanTime.IsZero()) + assert.NotEmpty(t, hosts[0].GetBanTime()) + } + _, err = defender.GetHost(testIP) + assert.NoError(t, err) + // cleanup db + err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute))) + assert.NoError(t, err) + // the banned host must still be there + hosts, err = defender.GetHosts() + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + assert.Equal(t, testIP, hosts[0].IP) + assert.Equal(t, 0, hosts[0].Score) + assert.False(t, hosts[0].BanTime.IsZero()) + assert.NotEmpty(t, hosts[0].GetBanTime()) + } + _, err = defender.GetHost(testIP) + assert.NoError(t, err) + err = dataprovider.SetDefenderBanTime(testIP, util.GetTimeAsMsSinceEpoch(time.Now().Add(-1*time.Minute))) + assert.NoError(t, err) + err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(10 * time.Minute))) + assert.NoError(t, err) + hosts, err = defender.GetHosts() + assert.NoError(t, err) + assert.Len(t, hosts, 0) + + for _, e := range entries { + err := dataprovider.DeleteIPListEntry(e.IPOrNet, e.Type, "", "", "") + assert.NoError(t, err) + } +} + +func TestDbDefenderCleanup(t *testing.T) { + 
if !isDbDefenderSupported() { + t.Skip("this test is not supported with the current database provider") + } + config := &DefenderConfig{ + Enabled: true, + BanTime: 10, + BanTimeIncrement: 2, + Threshold: 5, + ScoreInvalid: 2, + ScoreValid: 1, + ScoreLimitExceeded: 3, + ObservationTime: 15, + EntriesSoftLimit: 1, + EntriesHardLimit: 10, + } + d, err := newDBDefender(config) + assert.NoError(t, err) + defender := d.(*dbDefender) + lastCleanup := defender.getLastCleanup() + assert.True(t, lastCleanup.IsZero()) + defender.cleanup() + lastCleanup = defender.getLastCleanup() + assert.False(t, lastCleanup.IsZero()) + defender.cleanup() + assert.Equal(t, lastCleanup, defender.getLastCleanup()) + defender.setLastCleanup(time.Time{}) + assert.True(t, defender.getLastCleanup().IsZero()) + defender.setLastCleanup(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4)) + time.Sleep(20 * time.Millisecond) + defender.cleanup() + assert.True(t, lastCleanup.Before(defender.getLastCleanup())) + + providerConf := dataprovider.GetProviderConfig() + err = dataprovider.Close() + assert.NoError(t, err) + + lastCleanup = util.GetTimeFromMsecSinceEpoch(time.Now().Add(-time.Duration(config.ObservationTime) * time.Minute * 4).UnixMilli()) + defender.setLastCleanup(lastCleanup) + defender.cleanup() + // cleanup will fail and so last cleanup should be reset to the previous value + assert.Equal(t, lastCleanup, defender.getLastCleanup()) + + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func isDbDefenderSupported() bool { + // SQLite shares the implementation with other SQL-based provider but it makes no sense + // to use it outside test cases + switch dataprovider.GetProviderStatus().Driver { + case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName, + dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName: + return true + default: + return false + } +} diff --git 
a/internal/common/defendermem.go b/internal/common/defendermem.go new file mode 100644 index 00000000..0f59b37a --- /dev/null +++ b/internal/common/defendermem.go @@ -0,0 +1,354 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "sort" + "sync" + "time" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +type memoryDefender struct { + baseDefender + sync.RWMutex + // IP addresses of the clients trying to connect are stored inside hosts, + // they are added to banned once the threshold is reached. 
+ // A violation from a banned host will increase the ban time + // based on the configured BanTimeIncrement + hosts map[string]hostScore // the key is the host IP + banned map[string]time.Time // the key is the host IP +} + +func newInMemoryDefender(config *DefenderConfig) (Defender, error) { + err := config.validate() + if err != nil { + return nil, err + } + ipList, err := dataprovider.NewIPList(dataprovider.IPListTypeDefender) + if err != nil { + return nil, err + } + defender := &memoryDefender{ + baseDefender: baseDefender{ + config: config, + ipList: ipList, + }, + hosts: make(map[string]hostScore), + banned: make(map[string]time.Time), + } + + return defender, nil +} + +// GetHosts returns hosts that are banned or for which some violations have been detected +func (d *memoryDefender) GetHosts() ([]dataprovider.DefenderEntry, error) { + d.RLock() + defer d.RUnlock() + + var result []dataprovider.DefenderEntry + for k, v := range d.banned { + if v.After(time.Now()) { + result = append(result, dataprovider.DefenderEntry{ + IP: k, + BanTime: v, + }) + } + } + for k, v := range d.hosts { + score := 0 + for _, event := range v.Events { + if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { + score += event.score + } + } + if score > 0 { + result = append(result, dataprovider.DefenderEntry{ + IP: k, + Score: score, + }) + } + } + + return result, nil +} + +// GetHost returns a defender host by ip, if any +func (d *memoryDefender) GetHost(ip string) (dataprovider.DefenderEntry, error) { + d.RLock() + defer d.RUnlock() + + if banTime, ok := d.banned[ip]; ok { + if banTime.After(time.Now()) { + return dataprovider.DefenderEntry{ + IP: ip, + BanTime: banTime, + }, nil + } + } + + if hs, ok := d.hosts[ip]; ok { + score := 0 + for _, event := range hs.Events { + if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { + score += event.score + } + } + if score > 0 { + return 
dataprovider.DefenderEntry{ + IP: ip, + Score: score, + }, nil + } + } + + return dataprovider.DefenderEntry{}, util.NewRecordNotFoundError("host not found") +} + +// IsBanned returns true if the specified IP is banned +// and increase ban time if the IP is found. +// This method must be called as soon as the client connects +func (d *memoryDefender) IsBanned(ip, protocol string) bool { + d.RLock() + + if banTime, ok := d.banned[ip]; ok { + if banTime.After(time.Now()) { + increment := d.config.BanTime * d.config.BanTimeIncrement / 100 + if increment == 0 { + increment++ + } + + d.RUnlock() + + // we can save an earlier ban time if there are contemporary updates + // but this should not make much difference. I prefer to hold a read lock + // until possible for performance reasons, this method is called each + // time a new client connects and it must be as fast as possible + d.Lock() + d.banned[ip] = banTime.Add(time.Duration(increment) * time.Minute) + d.Unlock() + + return true + } + } + + defer d.RUnlock() + + return d.isBanned(ip, protocol) +} + +// DeleteHost removes the specified IP from the defender lists +func (d *memoryDefender) DeleteHost(ip string) bool { + d.Lock() + defer d.Unlock() + + if _, ok := d.banned[ip]; ok { + delete(d.banned, ip) + return true + } + + if _, ok := d.hosts[ip]; ok { + delete(d.hosts, ip) + return true + } + + return false +} + +// AddEvent adds an event for the given IP. +// This method must be called for clients not yet banned. +// Returns true if the IP is in the defender's safe list. 
+func (d *memoryDefender) AddEvent(ip, protocol string, event HostEvent) bool { + if d.IsSafe(ip, protocol) { + return true + } + + d.Lock() + defer d.Unlock() + + // ignore events for already banned hosts + if v, ok := d.banned[ip]; ok { + if v.After(time.Now()) { + return false + } + delete(d.banned, ip) + } + + score := d.getScore(event) + + ev := hostEvent{ + dateTime: time.Now(), + score: score, + } + + if hs, ok := d.hosts[ip]; ok { + hs.Events = append(hs.Events, ev) + hs.TotalScore = 0 + + idx := 0 + for _, event := range hs.Events { + if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { + hs.Events[idx] = event + hs.TotalScore += event.score + idx++ + } + } + d.logEvent(ip, protocol, event, hs.TotalScore) + + hs.Events = hs.Events[:idx] + if hs.TotalScore >= d.config.Threshold { + d.logBan(ip, protocol) + d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute) + delete(d.hosts, ip) + d.cleanupBanned() + eventManager.handleIPBlockedEvent(EventParams{ + Event: ipBlockedEventName, + IP: ip, + Timestamp: time.Now(), + Status: 1, + }) + } else { + d.hosts[ip] = hs + } + } else { + d.logEvent(ip, protocol, event, ev.score) + d.hosts[ip] = hostScore{ + TotalScore: ev.score, + Events: []hostEvent{ev}, + } + d.cleanupHosts() + } + return false +} + +func (d *memoryDefender) countBanned() int { + d.RLock() + defer d.RUnlock() + + return len(d.banned) +} + +func (d *memoryDefender) countHosts() int { + d.RLock() + defer d.RUnlock() + + return len(d.hosts) +} + +// GetBanTime returns the ban time for the given IP or nil if the IP is not banned +func (d *memoryDefender) GetBanTime(ip string) (*time.Time, error) { + d.RLock() + defer d.RUnlock() + + if banTime, ok := d.banned[ip]; ok { + return &banTime, nil + } + + return nil, nil +} + +// GetScore returns the score for the given IP +func (d *memoryDefender) GetScore(ip string) (int, error) { + d.RLock() + defer d.RUnlock() + + score := 0 + + if hs, ok := 
d.hosts[ip]; ok { + for _, event := range hs.Events { + if event.dateTime.Add(time.Duration(d.config.ObservationTime) * time.Minute).After(time.Now()) { + score += event.score + } + } + } + + return score, nil +} + +func (d *memoryDefender) cleanupBanned() { + if len(d.banned) > d.config.EntriesHardLimit { + kvList := make(kvList, 0, len(d.banned)) + + for k, v := range d.banned { + if v.Before(time.Now()) { + delete(d.banned, k) + } + + kvList = append(kvList, kv{ + Key: k, + Value: v.UnixNano(), + }) + } + + // we removed expired ip addresses, if any, above, this could be enough + numToRemove := len(d.banned) - d.config.EntriesSoftLimit + + if numToRemove <= 0 { + return + } + + sort.Sort(kvList) + + for idx, kv := range kvList { + if idx >= numToRemove { + break + } + + delete(d.banned, kv.Key) + } + } +} + +func (d *memoryDefender) cleanupHosts() { + if len(d.hosts) > d.config.EntriesHardLimit { + kvList := make(kvList, 0, len(d.hosts)) + + for k, v := range d.hosts { + value := int64(0) + if len(v.Events) > 0 { + value = v.Events[len(v.Events)-1].dateTime.UnixNano() + } + kvList = append(kvList, kv{ + Key: k, + Value: value, + }) + } + + sort.Sort(kvList) + + numToRemove := len(d.hosts) - d.config.EntriesSoftLimit + + for idx, kv := range kvList { + if idx >= numToRemove { + break + } + + delete(d.hosts, kv.Key) + } + } +} + +type kv struct { + Key string + Value int64 +} + +type kvList []kv + +func (p kvList) Len() int { return len(p) } +func (p kvList) Less(i, j int) bool { return p[i].Value < p[j].Value } +func (p kvList) Swap(i, j int) { p[i], p[j] = p[j], p[i] } diff --git a/internal/common/eventmanager.go b/internal/common/eventmanager.go new file mode 100644 index 00000000..ca0f7494 --- /dev/null +++ b/internal/common/eventmanager.go @@ -0,0 +1,2949 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the 
Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "bytes" + "context" + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "html" + "io" + "mime" + "mime/multipart" + "net/http" + "net/textproto" + "net/url" + "os" + "os/exec" + "path" + "path/filepath" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/bmatcuk/doublestar/v4" + "github.com/klauspost/compress/zip" + "github.com/robfig/cron/v3" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + "github.com/wneessen/go-mail" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + ipBlockedEventName = "IP Blocked" + maxAttachmentsSize = int64(10 * 1024 * 1024) + objDataPlaceholder = "{{.ObjectData}}" + objDataPlaceholderString = "{{.ObjectDataString}}" + dateTimeMillisFormat = "2006-01-02T15:04:05.000" +) + +// Supported IDP login events +const ( + IDPLoginUser = "IDP login user" + IDPLoginAdmin = "IDP login admin" +) + +var ( + // eventManager handle the supported event rules actions + eventManager eventRulesContainer + multipartQuoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + fsEventsWithSize = []string{operationPreDelete, OperationPreUpload, operationDelete, + operationCopy, operationDownload, operationFirstUpload, operationFirstDownload, + operationUpload} +) + +func init() { + eventManager = eventRulesContainer{ + 
schedulesMapping: make(map[string][]cron.EntryID), + // arbitrary maximum number of concurrent asynchronous tasks, + // each task could execute multiple actions + concurrencyGuard: make(chan struct{}, 200), + } + dataprovider.SetEventRulesCallbacks(eventManager.loadRules, eventManager.RemoveRule, + func(operation, executor, ip, objectType, objectName, role string, object plugin.Renderer) { + p := EventParams{ + Name: executor, + ObjectName: objectName, + Event: operation, + Status: 1, + ObjectType: objectType, + IP: ip, + Role: role, + Timestamp: time.Now(), + Object: object, + } + if u, ok := object.(*dataprovider.User); ok { + p.Email = u.Email + p.Groups = u.Groups + } else if a, ok := object.(*dataprovider.Admin); ok { + p.Email = a.Email + } + eventManager.handleProviderEvent(p) + }) +} + +// HandleCertificateEvent checks and executes action rules for certificate events +func HandleCertificateEvent(params EventParams) { + eventManager.handleCertificateEvent(params) +} + +// HandleIDPLoginEvent executes actions defined for a successful login from an Identity Provider +func HandleIDPLoginEvent(params EventParams, customFields *map[string]any) (*dataprovider.User, *dataprovider.Admin, error) { + return eventManager.handleIDPLoginEvent(params, customFields) +} + +// eventRulesContainer stores event rules by trigger +type eventRulesContainer struct { + sync.RWMutex + lastLoad atomic.Int64 + FsEvents []dataprovider.EventRule + ProviderEvents []dataprovider.EventRule + Schedules []dataprovider.EventRule + IPBlockedEvents []dataprovider.EventRule + CertificateEvents []dataprovider.EventRule + IPDLoginEvents []dataprovider.EventRule + schedulesMapping map[string][]cron.EntryID + concurrencyGuard chan struct{} +} + +func (r *eventRulesContainer) addAsyncTask() { + activeHooks.Add(1) + r.concurrencyGuard <- struct{}{} +} + +func (r *eventRulesContainer) removeAsyncTask() { + activeHooks.Add(-1) + <-r.concurrencyGuard +} + +func (r *eventRulesContainer) getLastLoadTime() 
int64 { + return r.lastLoad.Load() +} + +func (r *eventRulesContainer) setLastLoadTime(modTime int64) { + r.lastLoad.Store(modTime) +} + +// RemoveRule deletes the rule with the specified name +func (r *eventRulesContainer) RemoveRule(name string) { + r.Lock() + defer r.Unlock() + + r.removeRuleInternal(name) + eventManagerLog(logger.LevelDebug, "event rules updated after delete, fs events: %d, provider events: %d, schedules: %d", + len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules)) +} + +func (r *eventRulesContainer) removeRuleInternal(name string) { + for idx := range r.FsEvents { + if r.FsEvents[idx].Name == name { + lastIdx := len(r.FsEvents) - 1 + r.FsEvents[idx] = r.FsEvents[lastIdx] + r.FsEvents = r.FsEvents[:lastIdx] + eventManagerLog(logger.LevelDebug, "removed rule %q from fs events", name) + return + } + } + for idx := range r.ProviderEvents { + if r.ProviderEvents[idx].Name == name { + lastIdx := len(r.ProviderEvents) - 1 + r.ProviderEvents[idx] = r.ProviderEvents[lastIdx] + r.ProviderEvents = r.ProviderEvents[:lastIdx] + eventManagerLog(logger.LevelDebug, "removed rule %q from provider events", name) + return + } + } + for idx := range r.IPBlockedEvents { + if r.IPBlockedEvents[idx].Name == name { + lastIdx := len(r.IPBlockedEvents) - 1 + r.IPBlockedEvents[idx] = r.IPBlockedEvents[lastIdx] + r.IPBlockedEvents = r.IPBlockedEvents[:lastIdx] + eventManagerLog(logger.LevelDebug, "removed rule %q from IP blocked events", name) + return + } + } + for idx := range r.CertificateEvents { + if r.CertificateEvents[idx].Name == name { + lastIdx := len(r.CertificateEvents) - 1 + r.CertificateEvents[idx] = r.CertificateEvents[lastIdx] + r.CertificateEvents = r.CertificateEvents[:lastIdx] + eventManagerLog(logger.LevelDebug, "removed rule %q from certificate events", name) + return + } + } + for idx := range r.IPDLoginEvents { + if r.IPDLoginEvents[idx].Name == name { + lastIdx := len(r.IPDLoginEvents) - 1 + r.IPDLoginEvents[idx] = r.IPDLoginEvents[lastIdx] + 
r.IPDLoginEvents = r.IPDLoginEvents[:lastIdx] + eventManagerLog(logger.LevelDebug, "removed rule %q from IDP login events", name) + return + } + } + for idx := range r.Schedules { + if r.Schedules[idx].Name == name { + if schedules, ok := r.schedulesMapping[name]; ok { + for _, entryID := range schedules { + eventManagerLog(logger.LevelDebug, "removing scheduled entry id %d for rule %q", entryID, name) + eventScheduler.Remove(entryID) + } + delete(r.schedulesMapping, name) + } + + lastIdx := len(r.Schedules) - 1 + r.Schedules[idx] = r.Schedules[lastIdx] + r.Schedules = r.Schedules[:lastIdx] + eventManagerLog(logger.LevelDebug, "removed rule %q from scheduled events", name) + return + } + } +} + +func (r *eventRulesContainer) addUpdateRuleInternal(rule dataprovider.EventRule) { + r.removeRuleInternal(rule.Name) + if rule.DeletedAt > 0 { + deletedAt := util.GetTimeFromMsecSinceEpoch(rule.DeletedAt) + if deletedAt.Add(30 * time.Minute).Before(time.Now()) { + eventManagerLog(logger.LevelDebug, "removing rule %q deleted at %s", rule.Name, deletedAt) + go dataprovider.RemoveEventRule(rule) //nolint:errcheck + } + return + } + if rule.Status != 1 || rule.Trigger == dataprovider.EventTriggerOnDemand { + return + } + switch rule.Trigger { + case dataprovider.EventTriggerFsEvent: + r.FsEvents = append(r.FsEvents, rule) + eventManagerLog(logger.LevelDebug, "added rule %q to fs events", rule.Name) + case dataprovider.EventTriggerProviderEvent: + r.ProviderEvents = append(r.ProviderEvents, rule) + eventManagerLog(logger.LevelDebug, "added rule %q to provider events", rule.Name) + case dataprovider.EventTriggerIPBlocked: + r.IPBlockedEvents = append(r.IPBlockedEvents, rule) + eventManagerLog(logger.LevelDebug, "added rule %q to IP blocked events", rule.Name) + case dataprovider.EventTriggerCertificate: + r.CertificateEvents = append(r.CertificateEvents, rule) + eventManagerLog(logger.LevelDebug, "added rule %q to certificate events", rule.Name) + case 
dataprovider.EventTriggerIDPLogin: + r.IPDLoginEvents = append(r.IPDLoginEvents, rule) + eventManagerLog(logger.LevelDebug, "added rule %q to IDP login events", rule.Name) + case dataprovider.EventTriggerSchedule: + for _, schedule := range rule.Conditions.Schedules { + cronSpec := schedule.GetCronSpec() + job := &eventCronJob{ + ruleName: dataprovider.ConvertName(rule.Name), + } + entryID, err := eventScheduler.AddJob(cronSpec, job) + if err != nil { + eventManagerLog(logger.LevelError, "unable to add scheduled rule %q, cron string %q: %v", rule.Name, cronSpec, err) + return + } + r.schedulesMapping[rule.Name] = append(r.schedulesMapping[rule.Name], entryID) + eventManagerLog(logger.LevelDebug, "schedule for rule %q added, id: %d, cron string %q, active scheduling rules: %d", + rule.Name, entryID, cronSpec, len(r.schedulesMapping)) + } + r.Schedules = append(r.Schedules, rule) + eventManagerLog(logger.LevelDebug, "added rule %q to scheduled events", rule.Name) + default: + eventManagerLog(logger.LevelError, "unsupported trigger: %d", rule.Trigger) + } +} + +func (r *eventRulesContainer) loadRules() { + eventManagerLog(logger.LevelDebug, "loading updated rules") + modTime := util.GetTimeAsMsSinceEpoch(time.Now()) + lastLoadTime := r.getLastLoadTime() + rules, err := dataprovider.GetRecentlyUpdatedRules(lastLoadTime) + if err != nil { + eventManagerLog(logger.LevelError, "unable to load event rules: %v", err) + return + } + eventManagerLog(logger.LevelDebug, "recently updated event rules loaded: %d", len(rules)) + + if len(rules) > 0 { + r.Lock() + defer r.Unlock() + + for _, rule := range rules { + r.addUpdateRuleInternal(rule) + } + } + eventManagerLog(logger.LevelDebug, "event rules updated, fs events: %d, provider events: %d, schedules: %d, ip blocked events: %d, certificate events: %d, IDP login events: %d", + len(r.FsEvents), len(r.ProviderEvents), len(r.Schedules), len(r.IPBlockedEvents), len(r.CertificateEvents), len(r.IPDLoginEvents)) + + 
r.setLastLoadTime(modTime) +} + +func (*eventRulesContainer) checkIPDLoginEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool { + switch conditions.IDPLoginEvent { + case dataprovider.IDPLoginUser: + if params.Event != IDPLoginUser { + return false + } + case dataprovider.IDPLoginAdmin: + if params.Event != IDPLoginAdmin { + return false + } + } + return checkEventConditionPatterns(params.Name, conditions.Options.Names) +} + +func (*eventRulesContainer) checkProviderEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool { + if !slices.Contains(conditions.ProviderEvents, params.Event) { + return false + } + if !checkEventConditionPatterns(params.Name, conditions.Options.Names) { + return false + } + if !checkEventGroupConditionPatterns(params.Groups, conditions.Options.GroupNames) { + return false + } + if !checkEventConditionPatterns(params.Role, conditions.Options.RoleNames) { + return false + } + if len(conditions.Options.ProviderObjects) > 0 && !slices.Contains(conditions.Options.ProviderObjects, params.ObjectType) { + return false + } + return true +} + +func (*eventRulesContainer) checkFsEventMatch(conditions *dataprovider.EventConditions, params *EventParams) bool { + if !slices.Contains(conditions.FsEvents, params.Event) { + return false + } + if !checkEventConditionPatterns(params.Name, conditions.Options.Names) { + return false + } + if !checkEventConditionPatterns(params.Role, conditions.Options.RoleNames) { + return false + } + if !checkEventGroupConditionPatterns(params.Groups, conditions.Options.GroupNames) { + return false + } + if !checkEventConditionPatterns(params.VirtualPath, conditions.Options.FsPaths) { + return false + } + if len(conditions.Options.Protocols) > 0 && !slices.Contains(conditions.Options.Protocols, params.Protocol) { + return false + } + if slices.Contains(fsEventsWithSize, params.Event) { + if conditions.Options.MinFileSize > 0 { + if params.FileSize < 
conditions.Options.MinFileSize { + return false + } + } + if conditions.Options.MaxFileSize > 0 { + if params.FileSize > conditions.Options.MaxFileSize { + return false + } + } + } + return true +} + +// hasFsRules returns true if there are any rules for filesystem event triggers +func (r *eventRulesContainer) hasFsRules() bool { + r.RLock() + defer r.RUnlock() + + return len(r.FsEvents) > 0 +} + +// handleFsEvent executes the rules actions defined for the specified event. +// The boolean parameter indicates whether a sync action was executed +func (r *eventRulesContainer) handleFsEvent(params EventParams) (bool, error) { + if params.Protocol == protocolEventAction { + return false, nil + } + r.RLock() + + var rulesWithSyncActions, rulesAsync []dataprovider.EventRule + for _, rule := range r.FsEvents { + if r.checkFsEventMatch(&rule.Conditions, ¶ms) { + if err := rule.CheckActionsConsistency(""); err != nil { + eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q", + rule.Name, err, params.Event) + continue + } + hasSyncActions := false + for _, action := range rule.Actions { + if action.Options.ExecuteSync { + hasSyncActions = true + break + } + } + if hasSyncActions { + rulesWithSyncActions = append(rulesWithSyncActions, rule) + } else { + rulesAsync = append(rulesAsync, rule) + } + } + } + + r.RUnlock() + + params.sender = params.Name + params.addUID() + if len(rulesAsync) > 0 { + go executeAsyncRulesActions(rulesAsync, params) + } + + if len(rulesWithSyncActions) > 0 { + return true, executeSyncRulesActions(rulesWithSyncActions, params) + } + return false, nil +} + +func (r *eventRulesContainer) handleIDPLoginEvent(params EventParams, customFields *map[string]any) (*dataprovider.User, + *dataprovider.Admin, error, +) { + r.RLock() + + var rulesWithSyncActions, rulesAsync []dataprovider.EventRule + for _, rule := range r.IPDLoginEvents { + if r.checkIPDLoginEventMatch(&rule.Conditions, ¶ms) { + if err := rule.CheckActionsConsistency(""); err != nil { 
+ eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q", + rule.Name, err, params.Event) + continue + } + hasSyncActions := false + for _, action := range rule.Actions { + if action.Options.ExecuteSync { + hasSyncActions = true + break + } + } + if hasSyncActions { + rulesWithSyncActions = append(rulesWithSyncActions, rule) + } else { + rulesAsync = append(rulesAsync, rule) + } + } + } + + r.RUnlock() + + if len(rulesAsync) == 0 && len(rulesWithSyncActions) == 0 { + return nil, nil, nil + } + + params.addIDPCustomFields(customFields) + if len(rulesWithSyncActions) > 1 { + var ruleNames []string + for _, r := range rulesWithSyncActions { + ruleNames = append(ruleNames, r.Name) + } + return nil, nil, fmt.Errorf("more than one account check action rules matches: %q", strings.Join(ruleNames, ",")) + } + + params.addUID() + if len(rulesAsync) > 0 { + go executeAsyncRulesActions(rulesAsync, params) + } + + if len(rulesWithSyncActions) > 0 { + return executeIDPAccountCheckRule(rulesWithSyncActions[0], params) + } + return nil, nil, nil +} + +// username is populated for user objects +func (r *eventRulesContainer) handleProviderEvent(params EventParams) { + r.RLock() + defer r.RUnlock() + + var rules []dataprovider.EventRule + for _, rule := range r.ProviderEvents { + if r.checkProviderEventMatch(&rule.Conditions, ¶ms) { + if err := rule.CheckActionsConsistency(params.ObjectType); err == nil { + rules = append(rules, rule) + } else { + eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q object type %q", + rule.Name, err, params.Event, params.ObjectType) + } + } + } + + if len(rules) > 0 { + params.sender = params.ObjectName + go executeAsyncRulesActions(rules, params) + } +} + +func (r *eventRulesContainer) handleIPBlockedEvent(params EventParams) { + r.RLock() + defer r.RUnlock() + + if len(r.IPBlockedEvents) == 0 { + return + } + var rules []dataprovider.EventRule + for _, rule := range r.IPBlockedEvents { + if err := 
rule.CheckActionsConsistency(""); err == nil { + rules = append(rules, rule) + } else { + eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q", + rule.Name, err, params.Event) + } + } + + if len(rules) > 0 { + go executeAsyncRulesActions(rules, params) + } +} + +func (r *eventRulesContainer) handleCertificateEvent(params EventParams) { + r.RLock() + defer r.RUnlock() + + if len(r.CertificateEvents) == 0 { + return + } + var rules []dataprovider.EventRule + for _, rule := range r.CertificateEvents { + if err := rule.CheckActionsConsistency(""); err == nil { + rules = append(rules, rule) + } else { + eventManagerLog(logger.LevelWarn, "rule %q skipped: %v, event %q", + rule.Name, err, params.Event) + } + } + + if len(rules) > 0 { + go executeAsyncRulesActions(rules, params) + } +} + +type executedRetentionCheck struct { + Username string + ActionName string + Results []folderRetentionCheckResult +} + +// EventParams defines the supported event parameters +type EventParams struct { + Name string + Groups []sdk.GroupMapping + Event string + Status int + VirtualPath string + FsPath string + VirtualTargetPath string + FsTargetPath string + ObjectName string + Extension string + ObjectType string + FileSize int64 + Elapsed int64 + Protocol string + IP string + Role string + Email string + Timestamp time.Time + UID string + IDPCustomFields *map[string]string + Object plugin.Renderer + Metadata map[string]string + sender string + updateStatusFromError bool + errors []string + retentionChecks []executedRetentionCheck +} + +func (p *EventParams) getACopy() *EventParams { + params := *p + params.errors = make([]string, len(p.errors)) + copy(params.errors, p.errors) + retentionChecks := make([]executedRetentionCheck, 0, len(p.retentionChecks)) + for _, c := range p.retentionChecks { + executedCheck := executedRetentionCheck{ + Username: c.Username, + ActionName: c.ActionName, + } + executedCheck.Results = make([]folderRetentionCheckResult, len(c.Results)) + 
copy(executedCheck.Results, c.Results) + retentionChecks = append(retentionChecks, executedCheck) + } + params.retentionChecks = retentionChecks + if p.IDPCustomFields != nil { + fields := make(map[string]string) + for k, v := range *p.IDPCustomFields { + fields[k] = v + } + params.IDPCustomFields = &fields + } + if len(params.Metadata) > 0 { + metadata := make(map[string]string) + for k, v := range p.Metadata { + metadata[k] = v + } + params.Metadata = metadata + } + + return ¶ms +} + +func (p *EventParams) addIDPCustomFields(customFields *map[string]any) { + if customFields == nil || len(*customFields) == 0 { + return + } + + fields := make(map[string]string) + for k, v := range *customFields { + switch val := v.(type) { + case string: + fields[k] = val + } + } + p.IDPCustomFields = &fields +} + +// AddError adds a new error to the event params and update the status if needed +func (p *EventParams) AddError(err error) { + if err == nil { + return + } + if p.updateStatusFromError && p.Status == 1 { + p.Status = 2 + } + p.errors = append(p.errors, err.Error()) +} + +func (p *EventParams) addUID() { + if p.UID == "" { + p.UID = util.GenerateUniqueID() + } +} + +func (p *EventParams) setBackupParams(backupPath string) { + if p.sender != "" { + return + } + p.sender = dataprovider.ActionExecutorSystem + p.FsPath = backupPath + p.ObjectName = filepath.Base(backupPath) + p.VirtualPath = "/" + p.ObjectName + p.Timestamp = time.Now() + info, err := os.Stat(backupPath) + if err == nil { + p.FileSize = info.Size() + } +} + +func (p *EventParams) getStatusString() string { + switch p.Status { + case 1: + return "OK" + default: + return "KO" + } +} + +// getUsers returns users with group settings not applied +func (p *EventParams) getUsers() ([]dataprovider.User, error) { + if p.sender == "" { + dump, err := dataprovider.DumpData([]string{dataprovider.DumpScopeUsers}) + if err != nil { + eventManagerLog(logger.LevelError, "unable to get users: %+v", err) + return nil, 
errors.New("unable to get users") + } + return dump.Users, nil + } + user, err := p.getUserFromSender() + if err != nil { + return nil, err + } + return []dataprovider.User{user}, nil +} + +func (p *EventParams) getUserFromSender() (dataprovider.User, error) { + if p.sender == dataprovider.ActionExecutorSystem { + return dataprovider.User{ + BaseUser: sdk.BaseUser{ + Status: 1, + Username: p.sender, + HomeDir: dataprovider.GetBackupsPath(), + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + }, nil + } + user, err := dataprovider.UserExists(p.sender, "") + if err != nil { + eventManagerLog(logger.LevelError, "unable to get user %q: %+v", p.sender, err) + return user, fmt.Errorf("error getting user %q", p.sender) + } + return user, nil +} + +func (p *EventParams) getFolders() ([]vfs.BaseVirtualFolder, error) { + if p.sender == "" { + dump, err := dataprovider.DumpData([]string{dataprovider.DumpScopeFolders}) + return dump.Folders, err + } + folder, err := dataprovider.GetFolderByName(p.sender) + if err != nil { + return nil, fmt.Errorf("error getting folder %q: %w", p.sender, err) + } + return []vfs.BaseVirtualFolder{folder}, nil +} + +func (p *EventParams) getCompressedDataRetentionReport() ([]byte, error) { + if len(p.retentionChecks) == 0 { + return nil, errors.New("no data retention report available") + } + var b bytes.Buffer + if _, err := p.writeCompressedDataRetentionReports(&b); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +func (p *EventParams) writeCompressedDataRetentionReports(w io.Writer) (int64, error) { + var n int64 + wr := zip.NewWriter(w) + + for _, check := range p.retentionChecks { + data, err := getCSVRetentionReport(check.Results) + if err != nil { + return n, fmt.Errorf("unable to get CSV report: %w", err) + } + dataSize := int64(len(data)) + n += dataSize + // we suppose a 3:1 compression ratio + if n > (maxAttachmentsSize * 3) { + eventManagerLog(logger.LevelError, "unable to get retention 
report, size too large: %s", + util.ByteCountIEC(n)) + return n, fmt.Errorf("unable to get retention report, size too large: %s", util.ByteCountIEC(n)) + } + + fh := &zip.FileHeader{ + Name: fmt.Sprintf("%s-%s.csv", check.ActionName, check.Username), + Method: zip.Deflate, + Modified: time.Now().UTC(), + } + f, err := wr.CreateHeader(fh) + if err != nil { + return n, fmt.Errorf("unable to create zip header for file %q: %w", fh.Name, err) + } + _, err = io.CopyN(f, bytes.NewBuffer(data), dataSize) + if err != nil { + return n, fmt.Errorf("unable to write content to zip file %q: %w", fh.Name, err) + } + } + if err := wr.Close(); err != nil { + return n, fmt.Errorf("unable to close zip writer: %w", err) + } + return n, nil +} + +func (p *EventParams) getRetentionReportsAsMailAttachment() (*mail.File, error) { + if len(p.retentionChecks) == 0 { + return nil, errors.New("no data retention report available") + } + return &mail.File{ + Name: "retention-reports.zip", + Header: make(map[string][]string), + Writer: p.writeCompressedDataRetentionReports, + }, nil +} + +func (*EventParams) getStringReplacement(val string, escapeMode int) string { + switch escapeMode { + case 1: + return util.JSONEscape(val) + case 2: + return html.EscapeString(val) + default: + return val + } +} + +func (p *EventParams) getStringReplacements(addObjectData bool, escapeMode int) []string { + var dateTimeString string + if Config.TZ == "local" { + dateTimeString = p.Timestamp.Local().Format(dateTimeMillisFormat) + } else { + dateTimeString = p.Timestamp.UTC().Format(dateTimeMillisFormat) + } + year := dateTimeString[0:4] + month := dateTimeString[5:7] + day := dateTimeString[8:10] + hour := dateTimeString[11:13] + minute := dateTimeString[14:16] + + replacements := []string{ + "{{.Name}}", p.getStringReplacement(p.Name, escapeMode), + "{{.Event}}", p.Event, + "{{.Status}}", fmt.Sprintf("%d", p.Status), + "{{.VirtualPath}}", p.getStringReplacement(p.VirtualPath, escapeMode), + 
"{{.EscapedVirtualPath}}", p.getStringReplacement(url.QueryEscape(p.VirtualPath), escapeMode), + "{{.FsPath}}", p.getStringReplacement(p.FsPath, escapeMode), + "{{.VirtualTargetPath}}", p.getStringReplacement(p.VirtualTargetPath, escapeMode), + "{{.FsTargetPath}}", p.getStringReplacement(p.FsTargetPath, escapeMode), + "{{.ObjectName}}", p.getStringReplacement(p.ObjectName, escapeMode), + "{{.ObjectBaseName}}", p.getStringReplacement(strings.TrimSuffix(p.ObjectName, p.Extension), escapeMode), + "{{.ObjectType}}", p.ObjectType, + "{{.FileSize}}", strconv.FormatInt(p.FileSize, 10), + "{{.Elapsed}}", strconv.FormatInt(p.Elapsed, 10), + "{{.Protocol}}", p.Protocol, + "{{.IP}}", p.IP, + "{{.Role}}", p.getStringReplacement(p.Role, escapeMode), + "{{.Email}}", p.getStringReplacement(p.Email, escapeMode), + "{{.Timestamp}}", strconv.FormatInt(p.Timestamp.UnixNano(), 10), + "{{.DateTime}}", dateTimeString, + "{{.Year}}", year, + "{{.Month}}", month, + "{{.Day}}", day, + "{{.Hour}}", hour, + "{{.Minute}}", minute, + "{{.StatusString}}", p.getStatusString(), + "{{.UID}}", p.getStringReplacement(p.UID, escapeMode), + "{{.Ext}}", p.getStringReplacement(p.Extension, escapeMode), + } + if p.VirtualPath != "" { + replacements = append(replacements, "{{.VirtualDirPath}}", p.getStringReplacement(path.Dir(p.VirtualPath), escapeMode)) + } + if p.VirtualTargetPath != "" { + replacements = append(replacements, "{{.VirtualTargetDirPath}}", p.getStringReplacement(path.Dir(p.VirtualTargetPath), escapeMode)) + replacements = append(replacements, "{{.TargetName}}", p.getStringReplacement(path.Base(p.VirtualTargetPath), escapeMode)) + } + if len(p.errors) > 0 { + replacements = append(replacements, "{{.ErrorString}}", p.getStringReplacement(strings.Join(p.errors, ", "), escapeMode)) + } else { + replacements = append(replacements, "{{.ErrorString}}", "") + } + replacements = append(replacements, objDataPlaceholder, "{}") + replacements = append(replacements, objDataPlaceholderString, "") + if 
addObjectData { + data, err := p.Object.RenderAsJSON(p.Event != operationDelete) + if err == nil { + dataString := util.BytesToString(data) + replacements[len(replacements)-3] = p.getStringReplacement(dataString, 0) + replacements[len(replacements)-1] = p.getStringReplacement(dataString, 1) + } + } + if p.IDPCustomFields != nil { + for k, v := range *p.IDPCustomFields { + replacements = append(replacements, fmt.Sprintf("{{.IDPField%s}}", k), p.getStringReplacement(v, escapeMode)) + } + } + replacements = append(replacements, "{{.Metadata}}", "{}") + replacements = append(replacements, "{{.MetadataString}}", "") + if len(p.Metadata) > 0 { + data, err := json.Marshal(p.Metadata) + if err == nil { + dataString := util.BytesToString(data) + replacements[len(replacements)-3] = p.getStringReplacement(dataString, 0) + replacements[len(replacements)-1] = p.getStringReplacement(dataString, 1) + } + } + return replacements +} + +func getCSVRetentionReport(results []folderRetentionCheckResult) ([]byte, error) { + var b bytes.Buffer + csvWriter := csv.NewWriter(&b) + err := csvWriter.Write([]string{"path", "retention (hours)", "deleted files", "deleted size (bytes)", + "elapsed (ms)", "info", "error"}) + if err != nil { + return nil, err + } + + for _, result := range results { + err = csvWriter.Write([]string{result.Path, strconv.Itoa(result.Retention), strconv.Itoa(result.DeletedFiles), + strconv.FormatInt(result.DeletedSize, 10), strconv.FormatInt(result.Elapsed.Milliseconds(), 10), + result.Info, result.Error}) + if err != nil { + return nil, err + } + } + + csvWriter.Flush() + err = csvWriter.Error() + return b.Bytes(), err +} + +func closeWriterAndUpdateQuota(w io.WriteCloser, conn *BaseConnection, virtualSourcePath, virtualTargetPath string, + numFiles int, truncatedSize int64, errTransfer error, operation string, startTime time.Time, +) error { + var fsDstPath string + var errDstFs error + errWrite := w.Close() + targetPath := virtualSourcePath + if virtualTargetPath 
!= "" { + targetPath = virtualTargetPath + var fsDst vfs.Fs + fsDst, fsDstPath, errDstFs = conn.GetFsAndResolvedPath(virtualTargetPath) + if errTransfer != nil && errDstFs == nil { + // try to remove a partial file on error. If this fails, we can't do anything + errRemove := fsDst.Remove(fsDstPath, false) + conn.Log(logger.LevelDebug, "removing partial file %q after write error, result: %v", virtualTargetPath, errRemove) + } + } + info, err := conn.doStatInternal(targetPath, 0, false, false) + if err == nil { + updateUserQuotaAfterFileWrite(conn, targetPath, numFiles, info.Size()-truncatedSize) + var fsSrcPath string + var errSrcFs error + if virtualSourcePath != "" { + _, fsSrcPath, errSrcFs = conn.GetFsAndResolvedPath(virtualSourcePath) + } + if errSrcFs == nil && errDstFs == nil { + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + if errTransfer == nil { + errTransfer = errWrite + } + if operation == operationCopy { + logger.CommandLog(copyLogSender, fsSrcPath, fsDstPath, conn.User.Username, "", conn.ID, conn.protocol, -1, -1, + "", "", "", info.Size(), conn.localAddr, conn.remoteAddr, elapsed) + } + ExecuteActionNotification(conn, operation, fsSrcPath, virtualSourcePath, fsDstPath, virtualTargetPath, "", info.Size(), errTransfer, elapsed, nil) //nolint:errcheck + } + } else { + eventManagerLog(logger.LevelWarn, "unable to update quota after writing %q: %v", targetPath, err) + } + if errTransfer != nil { + return errTransfer + } + return errWrite +} + +func updateUserQuotaAfterFileWrite(conn *BaseConnection, virtualPath string, numFiles int, fileSize int64) { + vfolder, err := conn.User.GetVirtualFolderForPath(path.Dir(virtualPath)) + if err != nil { + dataprovider.UpdateUserQuota(&conn.User, numFiles, fileSize, false) //nolint:errcheck + return + } + dataprovider.UpdateUserFolderQuota(&vfolder, &conn.User, numFiles, fileSize, false) +} + +func checkWriterPermsAndQuota(conn *BaseConnection, virtualPath string, numFiles int, expectedSize, truncatedSize 
int64) error {
	// numFiles == 0 means the target already exists, so overwrite permission
	// is required; otherwise this is a new file and upload permission applies.
	if numFiles == 0 {
		if !conn.User.HasPerm(dataprovider.PermOverwrite, path.Dir(virtualPath)) {
			return conn.GetPermissionDeniedError()
		}
	} else {
		if !conn.User.HasPerm(dataprovider.PermUpload, path.Dir(virtualPath)) {
			return conn.GetPermissionDeniedError()
		}
	}
	q, _ := conn.HasSpace(numFiles > 0, false, virtualPath)
	if !q.HasSpace {
		return conn.GetQuotaExceededError()
	}
	// When the expected size is known (-1 means unknown), reject writes whose
	// net growth over the truncated size would exceed the remaining quota.
	if expectedSize != -1 {
		sizeDiff := expectedSize - truncatedSize
		if sizeDiff > 0 {
			remainingSize := q.GetRemainingSize()
			if remainingSize > 0 && remainingSize < sizeDiff {
				return conn.GetQuotaExceededError()
			}
		}
	}
	return nil
}

// getFileWriter opens virtualPath for writing on the user's filesystem after
// enforcing permissions and quota. It returns the writer, the number of new
// files (0 when overwriting an existing regular file, 1 otherwise), the
// truncated size (the previous size when overwriting on a filesystem with
// truncate support), a cancel function and an error.
func getFileWriter(conn *BaseConnection, virtualPath string, expectedSize int64) (io.WriteCloser, int, int64, func(), error) {
	fs, fsPath, err := conn.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return nil, 0, 0, nil, err
	}
	var truncatedSize, fileSize int64
	numFiles := 1
	isFileOverwrite := false

	info, err := fs.Lstat(fsPath)
	if err == nil {
		fileSize = info.Size()
		if info.IsDir() {
			return nil, numFiles, truncatedSize, nil, fmt.Errorf("cannot write to a directory: %q", virtualPath)
		}
		if info.Mode().IsRegular() {
			isFileOverwrite = true
			truncatedSize = fileSize
		}
		// The path already exists: no new file is added.
		numFiles = 0
	}
	if err != nil && !fs.IsNotExist(err) {
		return nil, numFiles, truncatedSize, nil, conn.GetFsError(fs, err)
	}
	if err := checkWriterPermsAndQuota(conn, virtualPath, numFiles, expectedSize, truncatedSize); err != nil {
		return nil, numFiles, truncatedSize, nil, err
	}
	f, w, cancelFn, err := fs.Create(fsPath, 0, conn.GetCreateChecks(virtualPath, numFiles == 1, false))
	if err != nil {
		return nil, numFiles, truncatedSize, nil, conn.GetFsError(fs, err)
	}
	vfs.SetPathPermissions(fs, fsPath, conn.User.GetUID(), conn.User.GetGID())

	if isFileOverwrite {
		// Creating the writer truncated the previous content, so the old size
		// is subtracted from the quota here; final usage is settled after the write.
		if vfs.HasTruncateSupport(fs) || vfs.IsCryptOsFs(fs) {
			updateUserQuotaAfterFileWrite(conn, virtualPath, numFiles, -fileSize)
truncatedSize = 0 + } + } + if cancelFn == nil { + cancelFn = func() {} + } + if f != nil { + return f, numFiles, truncatedSize, cancelFn, nil + } + return w, numFiles, truncatedSize, cancelFn, nil +} + +func addZipEntry(wr *zipWriterWrapper, conn *BaseConnection, entryPath, baseDir string, info os.FileInfo, recursion int) error { //nolint:gocyclo + if entryPath == wr.Name { + // skip the archive itself + return nil + } + if recursion >= util.MaxRecursion { + eventManagerLog(logger.LevelError, "unable to add zip entry %q, recursion too deep: %v", entryPath, recursion) + return util.ErrRecursionTooDeep + } + recursion++ + var err error + if info == nil { + info, err = conn.DoStat(entryPath, 1, false) + if err != nil { + eventManagerLog(logger.LevelError, "unable to add zip entry %q, stat error: %v", entryPath, err) + return err + } + } + entryName, err := getZipEntryName(entryPath, baseDir) + if err != nil { + eventManagerLog(logger.LevelError, "unable to get zip entry name: %v", err) + return err + } + if _, ok := wr.Entries[entryName]; ok { + eventManagerLog(logger.LevelInfo, "skipping duplicate zip entry %q, is dir %t", entryPath, info.IsDir()) + return nil + } + wr.Entries[entryName] = true + if info.IsDir() { + _, err = wr.Writer.CreateHeader(&zip.FileHeader{ + Name: entryName + "/", + Method: zip.Deflate, + Modified: info.ModTime(), + }) + if err != nil { + eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) + return fmt.Errorf("unable to create zip entry %q: %w", entryPath, err) + } + lister, err := conn.ListDir(entryPath) + if err != nil { + eventManagerLog(logger.LevelError, "unable to add zip entry %q, get dir lister error: %v", entryPath, err) + return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err) + } + defer lister.Close() + + for { + contents, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err := lister.convertError(err); err != nil { + 
eventManagerLog(logger.LevelError, "unable to add zip entry %q, read dir error: %v", entryPath, err) + return fmt.Errorf("unable to add zip entry %q: %w", entryPath, err) + } + for _, info := range contents { + fullPath := util.CleanPath(path.Join(entryPath, info.Name())) + if err := addZipEntry(wr, conn, fullPath, baseDir, info, recursion); err != nil { + eventManagerLog(logger.LevelError, "unable to add zip entry: %v", err) + return err + } + } + if finished { + return nil + } + } + } + if !info.Mode().IsRegular() { + // we only allow regular files + eventManagerLog(logger.LevelInfo, "skipping zip entry for non regular file %q", entryPath) + return nil + } + + return addFileToZip(wr, conn, entryPath, entryName, info.ModTime()) +} + +func addFileToZip(wr *zipWriterWrapper, conn *BaseConnection, entryPath, entryName string, modTime time.Time) error { + reader, cancelFn, err := getFileReader(conn, entryPath) + if err != nil { + eventManagerLog(logger.LevelError, "unable to add zip entry %q, cannot open file: %v", entryPath, err) + return fmt.Errorf("unable to open %q: %w", entryPath, err) + } + defer cancelFn() + defer reader.Close() + + f, err := wr.Writer.CreateHeader(&zip.FileHeader{ + Name: entryName, + Method: zip.Deflate, + Modified: modTime, + }) + if err != nil { + eventManagerLog(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) + return fmt.Errorf("unable to create zip entry %q: %w", entryPath, err) + } + _, err = io.Copy(f, reader) + return err +} + +func getZipEntryName(entryPath, baseDir string) (string, error) { + if !strings.HasPrefix(entryPath, baseDir) { + return "", fmt.Errorf("entry path %q is outside base dir %q", entryPath, baseDir) + } + entryPath = strings.TrimPrefix(entryPath, baseDir) + return strings.TrimPrefix(entryPath, "/"), nil +} + +func getFileReader(conn *BaseConnection, virtualPath string) (io.ReadCloser, func(), error) { + if !conn.User.HasPerm(dataprovider.PermDownload, path.Dir(virtualPath)) { + return nil, 
nil, conn.GetPermissionDeniedError()
	}
	fs, fsPath, err := conn.GetFsAndResolvedPath(virtualPath)
	if err != nil {
		return nil, nil, err
	}
	f, r, cancelFn, err := fs.Open(fsPath, 0)
	if err != nil {
		return nil, nil, conn.GetFsError(fs, err)
	}
	// Normalize the cancel function so callers can always defer it.
	if cancelFn == nil {
		cancelFn = func() {}
	}

	// fs.Open may return either a file handle or a pipe-style reader; prefer
	// the file when both are possible.
	if f != nil {
		return f, cancelFn, nil
	}
	return r, cancelFn, nil
}

// writeFileContent copies the whole content of virtualPath to w.
func writeFileContent(conn *BaseConnection, virtualPath string, w io.Writer) error {
	reader, cancelFn, err := getFileReader(conn, virtualPath)
	if err != nil {
		return err
	}

	defer cancelFn()
	defer reader.Close()

	_, err = io.Copy(w, reader)
	return err
}

// getFileContentFn returns a lazy writer function that copies at most size
// bytes of virtualPath to w. The file is opened when the returned function is
// invoked, not when it is created.
func getFileContentFn(conn *BaseConnection, virtualPath string, size int64) func(w io.Writer) (int64, error) {
	return func(w io.Writer) (int64, error) {
		reader, cancelFn, err := getFileReader(conn, virtualPath)
		if err != nil {
			return 0, err
		}

		defer cancelFn()
		defer reader.Close()

		return io.CopyN(w, reader, size)
	}
}

// getMailAttachments resolves the configured attachment paths (after
// placeholder replacement) to mail.File attachments. It rejects non regular
// files and enforces maxAttachmentsSize on the cumulative size.
func getMailAttachments(conn *BaseConnection, attachments []string, replacer *strings.Replacer) ([]*mail.File, error) {
	var files []*mail.File
	totalSize := int64(0)

	for _, virtualPath := range replacePathsPlaceholders(attachments, replacer) {
		info, err := conn.DoStat(virtualPath, 0, false)
		if err != nil {
			return nil, fmt.Errorf("unable to get info for file %q, user %q: %w", virtualPath, conn.User.Username, err)
		}
		if !info.Mode().IsRegular() {
			return nil, fmt.Errorf("cannot attach non regular file %q", virtualPath)
		}
		totalSize += info.Size()
		if totalSize > maxAttachmentsSize {
			return nil, fmt.Errorf("unable to send files as attachment, size too large: %s", util.ByteCountIEC(totalSize))
		}
		files = append(files, &mail.File{
			Name:   path.Base(virtualPath),
			Header: make(map[string][]string),
			// Content is streamed lazily when the mail is sent.
			Writer: getFileContentFn(conn, virtualPath, info.Size()),
		})
	}
	return files, nil
}

// replaceWithReplacer applies the replacer only when the input contains a
// template placeholder marker, avoiding useless work for plain strings.
func replaceWithReplacer(input string, replacer *strings.Replacer)
string {
	if !strings.Contains(input, "{{.") {
		return input
	}
	return replacer.Replace(input)
}

// checkEventConditionPattern reports whether name matches the given pattern.
// Patterns containing "**" use doublestar (recursive) matching, the others
// plain path.Match. On a pattern error the function logs and returns false,
// even for inverse patterns; otherwise InverseMatch flips the result.
func checkEventConditionPattern(p dataprovider.ConditionPattern, name string) bool {
	var matched bool
	var err error
	if strings.Contains(p.Pattern, "**") {
		matched, err = doublestar.Match(p.Pattern, name)
	} else {
		matched, err = path.Match(p.Pattern, name)
	}
	if err != nil {
		eventManagerLog(logger.LevelError, "pattern matching error %q, err: %v", p.Pattern, err)
		return false
	}
	if p.InverseMatch {
		return !matched
	}
	return matched
}

// checkUserConditionOptions reports whether the user matches all the
// configured name, role and group conditions.
func checkUserConditionOptions(user *dataprovider.User, conditions *dataprovider.ConditionOptions) bool {
	if !checkEventConditionPatterns(user.Username, conditions.Names) {
		return false
	}
	if !checkEventConditionPatterns(user.Role, conditions.RoleNames) {
		return false
	}
	if !checkEventGroupConditionPatterns(user.Groups, conditions.GroupNames) {
		return false
	}
	return true
}

// checkEventConditionPatterns returns false if patterns are defined and no match is found
func checkEventConditionPatterns(name string, patterns []dataprovider.ConditionPattern) bool {
	if len(patterns) == 0 {
		return true
	}
	matches := false
	for _, p := range patterns {
		// assume, that multiple InverseMatches are set
		// An inverse pattern that fails vetoes the whole check immediately;
		// one that succeeds only counts as a tentative match.
		if p.InverseMatch {
			if checkEventConditionPattern(p, name) {
				matches = true
			} else {
				return false
			}
		} else if checkEventConditionPattern(p, name) {
			return true
		}
	}
	return matches
}

// checkEventGroupConditionPatterns returns false if group patterns are
// defined and no group of the user matches. The same inverse-match veto
// semantics as checkEventConditionPatterns apply, evaluated per group.
func checkEventGroupConditionPatterns(groups []sdk.GroupMapping, patterns []dataprovider.ConditionPattern) bool {
	if len(patterns) == 0 {
		return true
	}
	matches := false
	for _, group := range groups {
		for _, p := range patterns {
			// assume, that multiple InverseMatches are set
			if p.InverseMatch {
				if checkEventConditionPattern(p, group.Name) {
					matches = true
				} else {
					return false
				}
			} else {
				if checkEventConditionPattern(p, group.Name) {
					return true
				}
			}
+ } + } + return matches +} + +func getHTTPRuleActionEndpoint(c *dataprovider.EventActionHTTPConfig, replacer *strings.Replacer) (string, error) { + u, err := url.Parse(c.Endpoint) + if err != nil { + return "", fmt.Errorf("invalid endpoint: %w", err) + } + if strings.Contains(u.Path, "{{.") { + pathComponents := strings.Split(u.Path, "/") + for idx := range pathComponents { + part := replaceWithReplacer(pathComponents[idx], replacer) + if part != pathComponents[idx] { + pathComponents[idx] = url.PathEscape(part) + } + } + u.Path = "" + u = u.JoinPath(pathComponents...) + } + if len(c.QueryParameters) > 0 { + q := u.Query() + + for _, keyVal := range c.QueryParameters { + q.Add(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) + } + + u.RawQuery = q.Encode() + } + return u.String(), nil +} + +func writeHTTPPart(m *multipart.Writer, part dataprovider.HTTPPart, h textproto.MIMEHeader, + conn *BaseConnection, replacer *strings.Replacer, params *EventParams, addObjectData bool, +) error { + partWriter, err := m.CreatePart(h) + if err != nil { + eventManagerLog(logger.LevelError, "unable to create part %q, err: %v", part.Name, err) + return err + } + if part.Body != "" { + cType := h.Get("Content-Type") + if strings.Contains(strings.ToLower(cType), "application/json") { + replacements := params.getStringReplacements(addObjectData, 1) + jsonReplacer := strings.NewReplacer(replacements...) 
+ _, err = partWriter.Write(util.StringToBytes(replaceWithReplacer(part.Body, jsonReplacer))) + } else { + _, err = partWriter.Write(util.StringToBytes(replaceWithReplacer(part.Body, replacer))) + } + if err != nil { + eventManagerLog(logger.LevelError, "unable to write part %q, err: %v", part.Name, err) + return err + } + return nil + } + if part.Filepath == dataprovider.RetentionReportPlaceHolder { + data, err := params.getCompressedDataRetentionReport() + if err != nil { + return err + } + _, err = partWriter.Write(data) + if err != nil { + eventManagerLog(logger.LevelError, "unable to write part %q, err: %v", part.Name, err) + return err + } + return nil + } + err = writeFileContent(conn, util.CleanPath(replacer.Replace(part.Filepath)), partWriter) + if err != nil { + eventManagerLog(logger.LevelError, "unable to write file part %q, err: %v", part.Name, err) + return err + } + return nil +} + +func getHTTPRuleActionBody(c *dataprovider.EventActionHTTPConfig, replacer *strings.Replacer, //nolint:gocyclo + cancel context.CancelFunc, user dataprovider.User, params *EventParams, addObjectData bool, +) (io.Reader, string, error) { + var body io.Reader + if c.Method == http.MethodGet { + return body, "", nil + } + if c.Body != "" { + if c.Body == dataprovider.RetentionReportPlaceHolder { + data, err := params.getCompressedDataRetentionReport() + if err != nil { + return body, "", err + } + return bytes.NewBuffer(data), "", nil + } + if c.HasJSONBody() { + replacements := params.getStringReplacements(addObjectData, 1) + jsonReplacer := strings.NewReplacer(replacements...) 
+ return bytes.NewBufferString(replaceWithReplacer(c.Body, jsonReplacer)), "", nil + } + return bytes.NewBufferString(replaceWithReplacer(c.Body, replacer)), "", nil + } + if len(c.Parts) > 0 { + r, w := io.Pipe() + m := multipart.NewWriter(w) + + var conn *BaseConnection + if user.Username != "" { + var err error + if err := getUserForEventAction(&user); err != nil { + return body, "", err + } + connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String()) + err = user.CheckFsRoot(connectionID) + if err != nil { + user.CloseFs() //nolint:errcheck + return body, "", fmt.Errorf("error getting multipart file/s, unable to check root fs for user %q: %w", + user.Username, err) + } + conn = NewBaseConnection(connectionID, protocolEventAction, "", "", user) + } + + go func() { + defer w.Close() + defer user.CloseFs() //nolint:errcheck + if conn != nil { + defer conn.CloseFS() //nolint:errcheck + } + + for _, part := range c.Parts { + h := make(textproto.MIMEHeader) + if part.Body != "" { + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"`, multipartQuoteEscaper.Replace(part.Name))) + } else { + h.Set("Content-Disposition", + fmt.Sprintf(`form-data; name="%s"; filename="%s"`, + multipartQuoteEscaper.Replace(part.Name), + multipartQuoteEscaper.Replace((path.Base(replaceWithReplacer(part.Filepath, replacer)))))) + contentType := mime.TypeByExtension(path.Ext(part.Filepath)) + if contentType == "" { + contentType = "application/octet-stream" + } + h.Set("Content-Type", contentType) + } + for _, keyVal := range part.Headers { + h.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) + } + if err := writeHTTPPart(m, part, h, conn, replacer, params, addObjectData); err != nil { + cancel() + return + } + } + m.Close() + }() + + return r, m.FormDataContentType(), nil + } + return body, "", nil +} + +func setHTTPReqHeaders(req *http.Request, c *dataprovider.EventActionHTTPConfig, replacer *strings.Replacer, + contentType string, +) { + if 
contentType != "" { + req.Header.Set("Content-Type", contentType) + } + if c.Username != "" || c.Password.GetPayload() != "" { + req.SetBasicAuth(replaceWithReplacer(c.Username, replacer), c.Password.GetPayload()) + } + for _, keyVal := range c.Headers { + req.Header.Set(keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)) + } +} + +func executeHTTPRuleAction(c dataprovider.EventActionHTTPConfig, params *EventParams) error { + if err := c.TryDecryptPassword(); err != nil { + return err + } + addObjectData := false + if params.Object != nil { + addObjectData = c.HasObjectData() + } + + replacements := params.getStringReplacements(addObjectData, 0) + replacer := strings.NewReplacer(replacements...) + endpoint, err := getHTTPRuleActionEndpoint(&c, replacer) + if err != nil { + return err + } + + ctx, cancel := c.GetContext() + defer cancel() + + var user dataprovider.User + if c.HasMultipartFiles() { + user, err = params.getUserFromSender() + if err != nil { + return err + } + } + body, contentType, err := getHTTPRuleActionBody(&c, replacer, cancel, user, params, addObjectData) + if err != nil { + return err + } + if body != nil { + rc, ok := body.(io.ReadCloser) + if ok { + defer rc.Close() + } + } + req, err := http.NewRequestWithContext(ctx, c.Method, endpoint, body) + if err != nil { + return err + } + setHTTPReqHeaders(req, &c, replacer, contentType) + + client := c.GetHTTPClient() + defer client.CloseIdleConnections() + + startTime := time.Now() + resp, err := client.Do(req) + if err != nil { + eventManagerLog(logger.LevelDebug, "unable to send http notification, endpoint: %s, elapsed: %s, err: %v", + endpoint, time.Since(startTime), err) + return fmt.Errorf("error sending HTTP request: %w", err) + } + defer resp.Body.Close() + + eventManagerLog(logger.LevelDebug, "http notification sent, endpoint: %s, elapsed: %s, status code: %d", + endpoint, time.Since(startTime), resp.StatusCode) + if resp.StatusCode < http.StatusOK || resp.StatusCode > 
http.StatusNoContent { + if rb, err := io.ReadAll(io.LimitReader(resp.Body, 2048)); err == nil { + eventManagerLog(logger.LevelDebug, "error notification response from endpoint %q: %s", + endpoint, rb) + } + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return nil +} + +func executeCommandRuleAction(c dataprovider.EventActionCommandConfig, params *EventParams) error { + if !dataprovider.IsActionCommandAllowed(c.Cmd) { + return fmt.Errorf("command %q is not allowed", c.Cmd) + } + addObjectData := false + if params.Object != nil { + for _, k := range c.EnvVars { + if strings.Contains(k.Value, objDataPlaceholder) || strings.Contains(k.Value, objDataPlaceholderString) { + addObjectData = true + break + } + } + } + replacements := params.getStringReplacements(addObjectData, 0) + replacer := strings.NewReplacer(replacements...) + + args := make([]string, 0, len(c.Args)) + for _, arg := range c.Args { + args = append(args, replaceWithReplacer(arg, replacer)) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.Timeout)*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, c.Cmd, args...) 
	cmd.Env = []string{}
	// Build an explicit environment: a configured value of "$" means "pass
	// through the OS value for this key", except for SFTPGO_-prefixed keys,
	// which are never forwarded from the process environment.
	for _, keyVal := range c.EnvVars {
		if keyVal.Value == "$" && !strings.HasPrefix(strings.ToUpper(keyVal.Key), "SFTPGO_") {
			val := os.Getenv(keyVal.Key)
			if val == "" {
				eventManagerLog(logger.LevelDebug, "empty value for environment variable %q", keyVal.Key)
			}
			cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", keyVal.Key, val))
		} else {
			cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", keyVal.Key, replaceWithReplacer(keyVal.Value, replacer)))
		}
	}

	startTime := time.Now()
	err := cmd.Run()

	eventManagerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v",
		c.Cmd, time.Since(startTime), err)

	return err
}

// getEmailAddressesWithReplacer applies placeholder replacement to the
// configured addresses and drops the ones that resolve to an empty string.
// It returns nil for an empty input slice.
func getEmailAddressesWithReplacer(addrs []string, replacer *strings.Replacer) []string {
	if len(addrs) == 0 {
		return nil
	}
	recipients := make([]string, 0, len(addrs))
	for _, recipient := range addrs {
		rcpt := replaceWithReplacer(recipient, replacer)
		if rcpt != "" {
			recipients = append(recipients, rcpt)
		}
	}
	return recipients
}

// executeEmailRuleAction sends the configured email notification, expanding
// placeholders in subject, body and recipients and resolving any configured
// file attachments.
func executeEmailRuleAction(c dataprovider.EventActionEmailConfig, params *EventParams) error {
	addObjectData := false
	if params.Object != nil {
		if strings.Contains(c.Body, objDataPlaceholder) || strings.Contains(c.Body, objDataPlaceholderString) {
			addObjectData = true
		}
	}
	replacements := params.getStringReplacements(addObjectData, 0)
	replacer := strings.NewReplacer(replacements...)
	var body string
	if c.ContentType == 1 {
		// ContentType 1 uses escape mode 2 for the body replacements —
		// presumably the HTML content type; TODO confirm against
		// smtp.EmailContentType.
		replacements := params.getStringReplacements(addObjectData, 2)
		bodyReplacer := strings.NewReplacer(replacements...)
+ body = replaceWithReplacer(c.Body, bodyReplacer) + } else { + body = replaceWithReplacer(c.Body, replacer) + } + subject := replaceWithReplacer(c.Subject, replacer) + recipients := getEmailAddressesWithReplacer(c.Recipients, replacer) + bcc := getEmailAddressesWithReplacer(c.Bcc, replacer) + startTime := time.Now() + var files []*mail.File + fileAttachments := make([]string, 0, len(c.Attachments)) + for _, attachment := range c.Attachments { + if attachment == dataprovider.RetentionReportPlaceHolder { + f, err := params.getRetentionReportsAsMailAttachment() + if err != nil { + return err + } + files = append(files, f) + continue + } + fileAttachments = append(fileAttachments, attachment) + } + if len(fileAttachments) > 0 { + user, err := params.getUserFromSender() + if err != nil { + return err + } + if err := getUserForEventAction(&user); err != nil { + return err + } + connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String()) + err = user.CheckFsRoot(connectionID) + defer user.CloseFs() //nolint:errcheck + if err != nil { + return fmt.Errorf("error getting email attachments, unable to check root fs for user %q: %w", user.Username, err) + } + conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) + defer conn.CloseFS() //nolint:errcheck + + res, err := getMailAttachments(conn, fileAttachments, replacer) + if err != nil { + return err + } + files = append(files, res...) + } + err := smtp.SendEmail(recipients, bcc, subject, body, smtp.EmailContentType(c.ContentType), files...) 
	eventManagerLog(logger.LevelDebug, "executed email notification action, elapsed: %s, error: %v",
		time.Since(startTime), err)
	if err != nil {
		return fmt.Errorf("unable to send email: %w", err)
	}
	return nil
}

// getUserForEventAction prepares a user object for internal event actions:
// group settings are applied, transfer/bandwidth limits are cleared,
// filesystem checks are re-enabled, file pattern filters and per-source
// bandwidth limits are removed, and every permission is granted so the
// action is not constrained by the user's normal restrictions.
func getUserForEventAction(user *dataprovider.User) error {
	err := user.LoadAndApplyGroupSettings()
	if err != nil {
		eventManagerLog(logger.LevelError, "unable to get group for user %q: %+v", user.Username, err)
		return fmt.Errorf("unable to get groups for user %q", user.Username)
	}
	user.UploadDataTransfer = 0
	user.UploadBandwidth = 0
	user.DownloadBandwidth = 0
	user.Filters.DisableFsChecks = false
	user.Filters.FilePatterns = nil
	user.Filters.BandwidthLimits = nil
	for k := range user.Permissions {
		user.Permissions[k] = []string{dataprovider.PermAny}
	}
	return nil
}

// replacePathsPlaceholders expands placeholders in the given virtual paths,
// cleans them and removes duplicates.
func replacePathsPlaceholders(paths []string, replacer *strings.Replacer) []string {
	results := make([]string, 0, len(paths))
	for _, p := range paths {
		results = append(results, util.CleanPath(replaceWithReplacer(p, replacer)))
	}
	return util.RemoveDuplicates(results, false)
}

// executeDeleteFileFsAction removes a single file on the user's filesystem.
func executeDeleteFileFsAction(conn *BaseConnection, item string, info os.FileInfo) error {
	fs, fsPath, err := conn.GetFsAndResolvedPath(item)
	if err != nil {
		return err
	}
	return conn.RemoveFile(fs, fsPath, item, info)
}

// executeDeleteFsActionForUser deletes the given paths (files or directories)
// acting as the specified user. Paths that do not exist are skipped.
func executeDeleteFsActionForUser(deletes []string, replacer *strings.Replacer, user dataprovider.User) error {
	if err := getUserForEventAction(&user); err != nil {
		return err
	}
	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
	err := user.CheckFsRoot(connectionID)
	// CloseFs is deferred before the error check so the fs is released even
	// when CheckFsRoot fails.
	defer user.CloseFs() //nolint:errcheck
	if err != nil {
		return fmt.Errorf("delete error, unable to check root fs for user %q: %w", user.Username, err)
	}
	conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
	defer conn.CloseFS() //nolint:errcheck

	for _, item := range replacePathsPlaceholders(deletes,
replacer) { + info, err := conn.DoStat(item, 0, false) + if err != nil { + if conn.IsNotExistError(err) { + continue + } + return fmt.Errorf("unable to check item to delete %q, user %q: %w", item, user.Username, err) + } + if info.IsDir() { + if err = conn.RemoveDir(item); err != nil { + return fmt.Errorf("unable to remove dir %q, user %q: %w", item, user.Username, err) + } + } else { + if err = executeDeleteFileFsAction(conn, item, info); err != nil { + return fmt.Errorf("unable to remove file %q, user %q: %w", item, user.Username, err) + } + } + eventManagerLog(logger.LevelDebug, "item %q removed for user %q", item, user.Username) + } + return nil +} + +func executeDeleteFsRuleAction(deletes []string, replacer *strings.Replacer, + conditions dataprovider.ConditionOptions, params *EventParams, +) error { + users, err := params.getUsers() + if err != nil { + return fmt.Errorf("unable to get users: %w", err) + } + var failures []string + executed := 0 + for _, user := range users { + // if sender is set, the conditions have already been evaluated + if params.sender == "" { + if !checkUserConditionOptions(&user, &conditions) { + eventManagerLog(logger.LevelDebug, "skipping fs delete for user %s, condition options don't match", + user.Username) + continue + } + } + executed++ + if err = executeDeleteFsActionForUser(deletes, replacer, user); err != nil { + params.AddError(err) + failures = append(failures, user.Username) + } + } + if len(failures) > 0 { + return fmt.Errorf("fs delete failed for users: %s", strings.Join(failures, ", ")) + } + if executed == 0 { + eventManagerLog(logger.LevelError, "no delete executed") + return errors.New("no delete executed") + } + return nil +} + +func executeMkDirsFsActionForUser(dirs []string, replacer *strings.Replacer, user dataprovider.User) error { + if err := getUserForEventAction(&user); err != nil { + return err + } + connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String()) + err := 
user.CheckFsRoot(connectionID) + defer user.CloseFs() //nolint:errcheck + if err != nil { + return fmt.Errorf("mkdir error, unable to check root fs for user %q: %w", user.Username, err) + } + conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) + defer conn.CloseFS() //nolint:errcheck + + for _, item := range replacePathsPlaceholders(dirs, replacer) { + if err = conn.CheckParentDirs(path.Dir(item)); err != nil { + return fmt.Errorf("unable to check parent dirs for %q, user %q: %w", item, user.Username, err) + } + if err = conn.createDirIfMissing(item); err != nil { + return fmt.Errorf("unable to create dir %q, user %q: %w", item, user.Username, err) + } + eventManagerLog(logger.LevelDebug, "directory %q created for user %q", item, user.Username) + } + return nil +} + +func executeMkdirFsRuleAction(dirs []string, replacer *strings.Replacer, + conditions dataprovider.ConditionOptions, params *EventParams, +) error { + users, err := params.getUsers() + if err != nil { + return fmt.Errorf("unable to get users: %w", err) + } + var failures []string + executed := 0 + for _, user := range users { + // if sender is set, the conditions have already been evaluated + if params.sender == "" { + if !checkUserConditionOptions(&user, &conditions) { + eventManagerLog(logger.LevelDebug, "skipping fs mkdir for user %s, condition options don't match", + user.Username) + continue + } + } + executed++ + if err = executeMkDirsFsActionForUser(dirs, replacer, user); err != nil { + failures = append(failures, user.Username) + } + } + if len(failures) > 0 { + return fmt.Errorf("fs mkdir failed for users: %s", strings.Join(failures, ", ")) + } + if executed == 0 { + eventManagerLog(logger.LevelError, "no mkdir executed") + return errors.New("no mkdir executed") + } + return nil +} + +func executeRenameFsActionForUser(renames []dataprovider.RenameConfig, replacer *strings.Replacer, + user dataprovider.User, +) error { + if err := getUserForEventAction(&user); err != nil 
{ + return err + } + connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String()) + err := user.CheckFsRoot(connectionID) + defer user.CloseFs() //nolint:errcheck + if err != nil { + return fmt.Errorf("rename error, unable to check root fs for user %q: %w", user.Username, err) + } + conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) + defer conn.CloseFS() //nolint:errcheck + + for _, item := range renames { + source := util.CleanPath(replaceWithReplacer(item.Key, replacer)) + target := util.CleanPath(replaceWithReplacer(item.Value, replacer)) + checks := 0 + if item.UpdateModTime { + checks += vfs.CheckUpdateModTime + } + if err = conn.renameInternal(source, target, true, checks); err != nil { + return fmt.Errorf("unable to rename %q->%q, user %q: %w", source, target, user.Username, err) + } + eventManagerLog(logger.LevelDebug, "rename %q->%q ok, user %q", source, target, user.Username) + } + return nil +} + +func executeCopyFsActionForUser(keyVals []dataprovider.KeyValue, replacer *strings.Replacer, + user dataprovider.User, +) error { + if err := getUserForEventAction(&user); err != nil { + return err + } + connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String()) + err := user.CheckFsRoot(connectionID) + defer user.CloseFs() //nolint:errcheck + if err != nil { + return fmt.Errorf("copy error, unable to check root fs for user %q: %w", user.Username, err) + } + conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) + defer conn.CloseFS() //nolint:errcheck + + for _, item := range keyVals { + source := util.CleanPath(replaceWithReplacer(item.Key, replacer)) + target := util.CleanPath(replaceWithReplacer(item.Value, replacer)) + if strings.HasSuffix(item.Key, "/") { + source += "/" + } + if strings.HasSuffix(item.Value, "/") { + target += "/" + } + if err = conn.Copy(source, target); err != nil { + return fmt.Errorf("unable to copy %q->%q, user %q: %w", source, target, user.Username, 
err) + } + eventManagerLog(logger.LevelDebug, "copy %q->%q ok, user %q", source, target, user.Username) + } + return nil +} + +func executeExistFsActionForUser(exist []string, replacer *strings.Replacer, + user dataprovider.User, +) error { + if err := getUserForEventAction(&user); err != nil { + return err + } + connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String()) + err := user.CheckFsRoot(connectionID) + defer user.CloseFs() //nolint:errcheck + if err != nil { + return fmt.Errorf("existence check error, unable to check root fs for user %q: %w", user.Username, err) + } + conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user) + defer conn.CloseFS() //nolint:errcheck + + for _, item := range replacePathsPlaceholders(exist, replacer) { + if _, err = conn.DoStat(item, 0, false); err != nil { + return fmt.Errorf("error checking existence for path %q, user %q: %w", item, user.Username, err) + } + eventManagerLog(logger.LevelDebug, "path %q exists for user %q", item, user.Username) + } + return nil +} + +func executeRenameFsRuleAction(renames []dataprovider.RenameConfig, replacer *strings.Replacer, + conditions dataprovider.ConditionOptions, params *EventParams, +) error { + users, err := params.getUsers() + if err != nil { + return fmt.Errorf("unable to get users: %w", err) + } + var failures []string + executed := 0 + for _, user := range users { + // if sender is set, the conditions have already been evaluated + if params.sender == "" { + if !checkUserConditionOptions(&user, &conditions) { + eventManagerLog(logger.LevelDebug, "skipping fs rename for user %s, condition options don't match", + user.Username) + continue + } + } + executed++ + if err = executeRenameFsActionForUser(renames, replacer, user); err != nil { + failures = append(failures, user.Username) + params.AddError(err) + } + } + if len(failures) > 0 { + return fmt.Errorf("fs rename failed for users: %s", strings.Join(failures, ", ")) + } + if executed == 0 { + 
eventManagerLog(logger.LevelError, "no rename executed")
		return errors.New("no rename executed")
	}
	return nil
}

// executeCopyFsRuleAction runs the configured source/target copies for
// every user matching the rule conditions and aggregates per-user failures.
func executeCopyFsRuleAction(keyVals []dataprovider.KeyValue, replacer *strings.Replacer,
	conditions dataprovider.ConditionOptions, params *EventParams,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	var executed int
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping fs copy for user %s, condition options don't match",
					user.Username)
				continue
			}
		}
		executed++
		if err = executeCopyFsActionForUser(keyVals, replacer, user); err != nil {
			failures = append(failures, user.Username)
			params.AddError(err)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("fs copy failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no copy executed")
		return errors.New("no copy executed")
	}
	return nil
}

// getArchiveBaseDir returns the directory to use as the archive root: the
// single shared parent if all paths agree, "/" otherwise.
func getArchiveBaseDir(paths []string) string {
	var parentDirs []string
	for _, p := range paths {
		parentDirs = append(parentDirs, path.Dir(p))
	}
	parentDirs = util.RemoveDuplicates(parentDirs, false)
	baseDir := "/"
	if len(parentDirs) == 1 {
		baseDir = parentDirs[0]
	}
	return baseDir
}

// getSizeForPath returns the size in bytes of the regular file p or, for a
// directory, the recursive sum of the regular files it contains. Entries
// that are neither regular files nor directories count as zero.
func getSizeForPath(conn *BaseConnection, p string, info os.FileInfo) (int64, error) {
	if info.IsDir() {
		var dirSize int64
		lister, err := conn.ListDir(p)
		if err != nil {
			return 0, err
		}
		defer lister.Close()
		for {
			entries, err := lister.Next(vfs.ListerBatchSize)
			// io.EOF signals the last batch, which may still hold entries
			finished := errors.Is(err, io.EOF)
			if err != nil && !finished {
				return 0, err
			}
			for _, entry := range entries {
				size, err := getSizeForPath(conn, path.Join(p, entry.Name()), entry)
				if err != nil {
					return 0, err
				}
				dirSize += size
			}
			if finished {
				return dirSize, nil
			}
		}
	}
	if info.Mode().IsRegular() {
		return info.Size(), nil
	}
	return 0, nil
}

// estimateZipSize returns an estimate of the archive size for the given
// paths, or -1 when there is no remaining quota to compare against.
func estimateZipSize(conn *BaseConnection, zipPath string, paths []string) (int64, error) {
	q, _ := conn.HasSpace(false, false, zipPath)
	if q.HasSpace && q.GetRemainingSize() > 0 {
		var size int64
		for _, item := range paths {
			info, err := conn.DoStat(item, 1, false)
			if err != nil {
				return size, err
			}
			itemSize, err := getSizeForPath(conn, item, info)
			if err != nil {
				return size, err
			}
			size += itemSize
		}
		eventManagerLog(logger.LevelDebug, "archive paths %v, archive name %q, size: %d", paths, zipPath, size)
		// we assume the zip size will be half of the real size
		return size / 2, nil
	}
	return -1, nil
}

// executeCompressFsActionForUser creates the zip archive configured in c on
// the user's filesystem, after expanding placeholders in name and paths.
func executeCompressFsActionForUser(c dataprovider.EventActionFsCompress, replacer *strings.Replacer,
	user dataprovider.User,
) error {
	if err := getUserForEventAction(&user); err != nil {
		return err
	}
	connectionID := fmt.Sprintf("%s_%s", protocolEventAction, xid.New().String())
	err := user.CheckFsRoot(connectionID)
	defer user.CloseFs() //nolint:errcheck
	if err != nil {
		return fmt.Errorf("compress error, unable to check root fs for user %q: %w", user.Username, err)
	}
	conn := NewBaseConnection(connectionID, protocolEventAction, "", "", user)
	defer conn.CloseFS() //nolint:errcheck

	name := util.CleanPath(replaceWithReplacer(c.Name, replacer))
	conn.CheckParentDirs(path.Dir(name)) //nolint:errcheck
	paths := make([]string, 0, len(c.Paths))
	for idx := range c.Paths {
		p := util.CleanPath(replaceWithReplacer(c.Paths[idx], replacer))
		// refuse to add the archive being created to its own content
		if p == name {
			return fmt.Errorf("cannot compress the archive to create: %q", name)
		}
		paths = append(paths, p)
	}
	paths = util.RemoveDuplicates(paths, false)
	estimatedSize, err := estimateZipSize(conn, name, paths)
	if err != nil {
		eventManagerLog(logger.LevelError, "unable to estimate size for archive %q: %v", name,
err)
		return fmt.Errorf("unable to estimate archive size: %w", err)
	}
	writer, numFiles, truncatedSize, cancelFn, err := getFileWriter(conn, name, estimatedSize)
	if err != nil {
		eventManagerLog(logger.LevelError, "unable to create archive %q: %v", name, err)
		return fmt.Errorf("unable to create archive: %w", err)
	}
	defer cancelFn()

	baseDir := getArchiveBaseDir(paths)
	eventManagerLog(logger.LevelDebug, "creating archive %q for paths %+v", name, paths)

	zipWriter := &zipWriterWrapper{
		Name:    name,
		Writer:  zip.NewWriter(writer),
		Entries: make(map[string]bool),
	}
	startTime := time.Now()
	for _, item := range paths {
		if err := addZipEntry(zipWriter, conn, item, baseDir, nil, 0); err != nil {
			// make sure quota is updated even on failure
			closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime) //nolint:errcheck
			return err
		}
	}
	if err := zipWriter.Writer.Close(); err != nil {
		eventManagerLog(logger.LevelError, "unable to close zip file %q: %v", name, err)
		closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime) //nolint:errcheck
		return fmt.Errorf("unable to close zip file %q: %w", name, err)
	}
	return closeWriterAndUpdateQuota(writer, conn, name, "", numFiles, truncatedSize, err, operationUpload, startTime)
}

// executeExistFsRuleAction runs the existence checks for every user
// matching the rule conditions and aggregates per-user failures.
func executeExistFsRuleAction(exist []string, replacer *strings.Replacer, conditions dataprovider.ConditionOptions,
	params *EventParams,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping fs exist for user %s, condition options don't match",
					user.Username)
				continue
			}
		}
		executed++
		if err =
executeExistFsActionForUser(exist, replacer, user); err != nil { + failures = append(failures, user.Username) + params.AddError(err) + } + } + if len(failures) > 0 { + return fmt.Errorf("fs existence check failed for users: %s", strings.Join(failures, ", ")) + } + if executed == 0 { + eventManagerLog(logger.LevelError, "no existence check executed") + return errors.New("no existence check executed") + } + return nil +} + +func executeCompressFsRuleAction(c dataprovider.EventActionFsCompress, replacer *strings.Replacer, + conditions dataprovider.ConditionOptions, params *EventParams, +) error { + users, err := params.getUsers() + if err != nil { + return fmt.Errorf("unable to get users: %w", err) + } + var failures []string + executed := 0 + for _, user := range users { + // if sender is set, the conditions have already been evaluated + if params.sender == "" { + if !checkUserConditionOptions(&user, &conditions) { + eventManagerLog(logger.LevelDebug, "skipping fs compress for user %s, condition options don't match", + user.Username) + continue + } + } + executed++ + if err = executeCompressFsActionForUser(c, replacer, user); err != nil { + failures = append(failures, user.Username) + params.AddError(err) + } + } + if len(failures) > 0 { + return fmt.Errorf("fs compress failed for users: %s", strings.Join(failures, ",")) + } + if executed == 0 { + eventManagerLog(logger.LevelError, "no file/folder compressed") + return errors.New("no file/folder compressed") + } + return nil +} + +func executeFsRuleAction(c dataprovider.EventActionFilesystemConfig, conditions dataprovider.ConditionOptions, + params *EventParams, +) error { + addObjectData := false + replacements := params.getStringReplacements(addObjectData, 0) + replacer := strings.NewReplacer(replacements...) 
switch c.Type {
	case dataprovider.FilesystemActionRename:
		return executeRenameFsRuleAction(c.Renames, replacer, conditions, params)
	case dataprovider.FilesystemActionDelete:
		return executeDeleteFsRuleAction(c.Deletes, replacer, conditions, params)
	case dataprovider.FilesystemActionMkdirs:
		return executeMkdirFsRuleAction(c.MkDirs, replacer, conditions, params)
	case dataprovider.FilesystemActionExist:
		return executeExistFsRuleAction(c.Exist, replacer, conditions, params)
	case dataprovider.FilesystemActionCompress:
		return executeCompressFsRuleAction(c.Compress, replacer, conditions, params)
	case dataprovider.FilesystemActionCopy:
		return executeCopyFsRuleAction(c.Copy, replacer, conditions, params)
	default:
		return fmt.Errorf("unsupported filesystem action %d", c.Type)
	}
}

// executeQuotaResetForUser rescans and persists the disk quota usage for a
// single user. Only one scan at a time is allowed per user.
func executeQuotaResetForUser(user *dataprovider.User) error {
	if err := user.LoadAndApplyGroupSettings(); err != nil {
		eventManagerLog(logger.LevelError, "skipping scheduled quota reset for user %s, cannot apply group settings: %v",
			user.Username, err)
		return err
	}
	if !QuotaScans.AddUserQuotaScan(user.Username, user.Role) {
		eventManagerLog(logger.LevelError, "another quota scan is already in progress for user %q", user.Username)
		return fmt.Errorf("another quota scan is in progress for user %q", user.Username)
	}
	defer QuotaScans.RemoveUserQuotaScan(user.Username)

	numFiles, size, err := user.ScanQuota()
	if err != nil {
		eventManagerLog(logger.LevelError, "error scanning quota for user %q: %v", user.Username, err)
		return fmt.Errorf("error scanning quota for user %q: %w", user.Username, err)
	}
	err = dataprovider.UpdateUserQuota(user, numFiles, size, true)
	if err != nil {
		eventManagerLog(logger.LevelError, "error updating quota for user %q: %v", user.Username, err)
		return fmt.Errorf("error updating quota for user %q: %w", user.Username, err)
	}
	return nil
}

// executeUsersQuotaResetRuleAction resets the quota for every user matching
// the rule conditions and aggregates per-user failures.
func executeUsersQuotaResetRuleAction(conditions dataprovider.ConditionOptions, params *EventParams) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping quota reset for user %q, condition options don't match",
					user.Username)
				continue
			}
		}
		executed++
		if err = executeQuotaResetForUser(&user); err != nil {
			params.AddError(err)
			failures = append(failures, user.Username)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("quota reset failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no user quota reset executed")
		return errors.New("no user quota reset executed")
	}
	return nil
}

// executeFoldersQuotaResetRuleAction rescans and persists the quota for
// every folder matching the rule name conditions.
func executeFoldersQuotaResetRuleAction(conditions dataprovider.ConditionOptions, params *EventParams) error {
	folders, err := params.getFolders()
	if err != nil {
		return fmt.Errorf("unable to get folders: %w", err)
	}
	var failures []string
	executed := 0
	for _, folder := range folders {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" && !checkEventConditionPatterns(folder.Name, conditions.Names) {
			eventManagerLog(logger.LevelDebug, "skipping scheduled quota reset for folder %s, name conditions don't match",
				folder.Name)
			continue
		}
		if !QuotaScans.AddVFolderQuotaScan(folder.Name) {
			eventManagerLog(logger.LevelError, "another quota scan is already in progress for folder %q", folder.Name)
			params.AddError(fmt.Errorf("another quota scan is already in progress for folder %q", folder.Name))
			failures = append(failures, folder.Name)
			continue
		}
		executed++
		f := vfs.VirtualFolder{
			BaseVirtualFolder: folder,
			VirtualPath:       "/",
		}
		numFiles, size, err :=
f.ScanQuota()
		QuotaScans.RemoveVFolderQuotaScan(folder.Name)
		if err != nil {
			eventManagerLog(logger.LevelError, "error scanning quota for folder %q: %v", folder.Name, err)
			params.AddError(fmt.Errorf("error scanning quota for folder %q: %w", folder.Name, err))
			failures = append(failures, folder.Name)
			continue
		}
		err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true)
		if err != nil {
			eventManagerLog(logger.LevelError, "error updating quota for folder %q: %v", folder.Name, err)
			params.AddError(fmt.Errorf("error updating quota for folder %q: %w", folder.Name, err))
			failures = append(failures, folder.Name)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("quota reset failed for folders: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no folder quota reset executed")
		return errors.New("no folder quota reset executed")
	}
	return nil
}

// executeTransferQuotaResetRuleAction resets the transfer quota counters to
// zero for every user matching the rule conditions.
func executeTransferQuotaResetRuleAction(conditions dataprovider.ConditionOptions, params *EventParams) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping scheduled transfer quota reset for user %s, condition options don't match",
					user.Username)
				continue
			}
		}
		executed++
		err = dataprovider.UpdateUserTransferQuota(&user, 0, 0, true)
		if err != nil {
			eventManagerLog(logger.LevelError, "error updating transfer quota for user %q: %v", user.Username, err)
			params.AddError(fmt.Errorf("error updating transfer quota for user %q: %w", user.Username, err))
			failures = append(failures, user.Username)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("transfer quota reset failed for users: %s",
strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no transfer quota reset executed")
		return errors.New("no transfer quota reset executed")
	}
	return nil
}

// executeDataRetentionCheckForUser starts a retention check on the given
// folders for a single user and records the outcome in params.
func executeDataRetentionCheckForUser(user dataprovider.User, folders []dataprovider.FolderRetention,
	params *EventParams, actionName string,
) error {
	if err := user.LoadAndApplyGroupSettings(); err != nil {
		eventManagerLog(logger.LevelError, "skipping scheduled retention check for user %s, cannot apply group settings: %v",
			user.Username, err)
		return err
	}
	check := RetentionCheck{
		Folders: folders,
	}
	c := RetentionChecks.Add(check, &user)
	if c == nil {
		eventManagerLog(logger.LevelError, "another retention check is already in progress for user %q", user.Username)
		return fmt.Errorf("another retention check is in progress for user %q", user.Username)
	}
	// record the results even if the check fails
	defer func() {
		params.retentionChecks = append(params.retentionChecks, executedRetentionCheck{
			Username:   user.Username,
			ActionName: actionName,
			Results:    c.results,
		})
	}()
	if err := c.Start(); err != nil {
		eventManagerLog(logger.LevelError, "error checking retention for user %q: %v", user.Username, err)
		return fmt.Errorf("error checking retention for user %q: %w", user.Username, err)
	}
	return nil
}

// executeDataRetentionCheckRuleAction runs a retention check for every user
// matching the rule conditions and aggregates per-user failures.
func executeDataRetentionCheckRuleAction(config dataprovider.EventActionDataRetentionConfig,
	conditions dataprovider.ConditionOptions, params *EventParams, actionName string,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	executed := 0
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping scheduled retention check for user %s, condition options don't match",
					user.Username)
				continue
			}
		}
executed++
		if err = executeDataRetentionCheckForUser(user, config.Folders, params, actionName); err != nil {
			failures = append(failures, user.Username)
			params.AddError(err)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("retention check failed for users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no retention check executed")
		return errors.New("no retention check executed")
	}
	return nil
}

// executeUserExpirationCheckRuleAction reports, as an error, the matching
// users whose expiration date is in the past.
func executeUserExpirationCheckRuleAction(conditions dataprovider.ConditionOptions, params *EventParams) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	var executed int
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping expiration check for user %q, condition options don't match",
					user.Username)
				continue
			}
		}
		executed++
		if user.ExpirationDate > 0 {
			expDate := util.GetTimeFromMsecSinceEpoch(user.ExpirationDate)
			if expDate.Before(time.Now()) {
				failures = append(failures, user.Username)
			}
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("expired users: %s", strings.Join(failures, ", "))
	}
	if executed == 0 {
		eventManagerLog(logger.LevelError, "no user expiration check executed")
		return errors.New("no user expiration check executed")
	}
	return nil
}

// executeInactivityCheckForUser deletes or disables the user when the days
// of inactivity exceed the configured thresholds. A non-nil error is also
// returned to report the action taken.
func executeInactivityCheckForUser(user *dataprovider.User, config dataprovider.EventActionUserInactivity, when time.Time) error {
	// deletion applies to already disabled users, or to any user when no
	// disable threshold is configured
	if config.DeleteThreshold > 0 && (user.Status == 0 || config.DisableThreshold == 0) {
		if inactivityDays := user.InactivityDays(when); inactivityDays > config.DeleteThreshold {
			err := dataprovider.DeleteUser(user.Username, dataprovider.ActionExecutorSystem, "", "")
			eventManagerLog(logger.LevelInfo, "deleting inactive user %q, days of inactivity: %d/%d, err: %v",
				user.Username, inactivityDays, config.DeleteThreshold, err)
			if err != nil {
				return fmt.Errorf("unable to delete inactive user %q", user.Username)
			}
			return fmt.Errorf("inactive user %q deleted. Number of days of inactivity: %d", user.Username, inactivityDays)
		}
	}
	if config.DisableThreshold > 0 && user.Status > 0 {
		if inactivityDays := user.InactivityDays(when); inactivityDays > config.DisableThreshold {
			user.Status = 0
			err := dataprovider.UpdateUser(user, dataprovider.ActionExecutorSystem, "", "")
			eventManagerLog(logger.LevelInfo, "disabling inactive user %q, days of inactivity: %d/%d, err: %v",
				user.Username, inactivityDays, config.DisableThreshold, err)
			if err != nil {
				return fmt.Errorf("unable to disable inactive user %q", user.Username)
			}
			return fmt.Errorf("inactive user %q disabled. Number of days of inactivity: %d", user.Username, inactivityDays)
		}
	}

	return nil
}

// executeUserInactivityCheckRuleAction runs the inactivity check, relative
// to when, for every user matching the rule conditions.
func executeUserInactivityCheckRuleAction(config dataprovider.EventActionUserInactivity,
	conditions dataprovider.ConditionOptions,
	params *EventParams,
	when time.Time,
) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping inactivity check for user %q, condition options don't match",
					user.Username)
				continue
			}
		}
		if err = executeInactivityCheckForUser(&user, config, when); err != nil {
			params.AddError(err)
			failures = append(failures, user.Username)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("executed inactivity check actions for users: %s", strings.Join(failures, ", "))
	}

	return nil
}

// executePwdExpirationCheckForUser emails the user if the password expires
// within the configured threshold of days.
func executePwdExpirationCheckForUser(user *dataprovider.User, config dataprovider.EventActionPasswordExpiration) error {
	if err := user.LoadAndApplyGroupSettings(); err != nil {
		eventManagerLog(logger.LevelError, "skipping password expiration check for user %q, cannot apply group settings: %v",
			user.Username, err)
		return err
	}
	// no point notifying accounts that are already expired
	if user.ExpirationDate > 0 {
		if expDate := util.GetTimeFromMsecSinceEpoch(user.ExpirationDate); expDate.Before(time.Now()) {
			eventManagerLog(logger.LevelDebug, "skipping password expiration check for expired user %q, expiration date: %s",
				user.Username, expDate)
			return nil
		}
	}
	if user.Filters.PasswordExpiration == 0 {
		eventManagerLog(logger.LevelDebug, "password expiration not set for user %q skipping check", user.Username)
		return nil
	}
	days := user.PasswordExpiresIn()
	if days > config.Threshold {
		eventManagerLog(logger.LevelDebug, "password for user %q expires in %d days, threshold %d, no need to notify",
			user.Username, days, config.Threshold)
		return nil
	}
	body := new(bytes.Buffer)
	data := make(map[string]any)
	data["Username"] = user.Username
	data["Days"] = days
	if err := smtp.RenderPasswordExpirationTemplate(body, data); err != nil {
		eventManagerLog(logger.LevelError, "unable to notify password expiration for user %s: %v",
			user.Username, err)
		return err
	}
	subject := "SFTPGo password expiration notification"
	startTime := time.Now()
	if err := smtp.SendEmail(user.GetEmailAddresses(), nil, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil {
		eventManagerLog(logger.LevelError, "unable to notify password expiration for user %s: %v, elapsed: %s",
			user.Username, err, time.Since(startTime))
		return err
	}
	eventManagerLog(logger.LevelDebug, "password expiration email sent to user %s, days: %d, elapsed: %s",
		user.Username, days, time.Since(startTime))
	return nil
}

// executePwdExpirationCheckRuleAction runs the password expiration check
// for every user matching the rule conditions.
func executePwdExpirationCheckRuleAction(config dataprovider.EventActionPasswordExpiration, conditions dataprovider.ConditionOptions,
	params *EventParams) error {
	users, err := params.getUsers()
	if err != nil {
		return fmt.Errorf("unable to get users: %w", err)
	}
	var failures []string
	for _, user := range users {
		// if sender is set, the conditions have already been evaluated
		if params.sender == "" {
			if !checkUserConditionOptions(&user, &conditions) {
				eventManagerLog(logger.LevelDebug, "skipping password check for user %q, condition options don't match",
					user.Username)
				continue
			}
		}
		if err = executePwdExpirationCheckForUser(&user, config); err != nil {
			params.AddError(err)
			failures = append(failures, user.Username)
		}
	}
	if len(failures) > 0 {
		return fmt.Errorf("password expiration check failed for users: %s", strings.Join(failures, ", "))
	}

	return nil
}

// executeAdminCheckAction creates or updates an admin account after an IDP
// login, rendering the configured JSON template with the event parameters.
func executeAdminCheckAction(c *dataprovider.EventActionIDPAccountCheck, params *EventParams) (*dataprovider.Admin, error) {
	admin, err := dataprovider.AdminExists(params.Name)
	exists := err == nil
	// with Mode == 1 an existing account is returned without updating it
	if exists && c.Mode == 1 {
		return &admin, nil
	}
	if err != nil && !errors.Is(err, util.ErrNotFound) {
		return nil, err
	}

	replacements := params.getStringReplacements(false, 1)
	replacer := strings.NewReplacer(replacements...)
	data := replaceWithReplacer(c.TemplateAdmin, replacer)

	var newAdmin dataprovider.Admin
	err = json.Unmarshal(util.StringToBytes(data), &newAdmin)
	if err != nil {
		return nil, err
	}
	if exists {
		eventManagerLog(logger.LevelDebug, "updating admin %q after IDP login", params.Name)
		// Not sure if this makes sense, but it shouldn't hurt.
if newAdmin.Password == "" {
			// keep the current password if the template does not set one
			newAdmin.Password = admin.Password
		}
		newAdmin.Filters.TOTPConfig = admin.Filters.TOTPConfig
		newAdmin.Filters.RecoveryCodes = admin.Filters.RecoveryCodes
		err = dataprovider.UpdateAdmin(&newAdmin, dataprovider.ActionExecutorSystem, "", "")
	} else {
		eventManagerLog(logger.LevelDebug, "creating admin %q after IDP login", params.Name)
		if newAdmin.Password == "" {
			newAdmin.Password = util.GenerateUniqueID()
		}
		err = dataprovider.AddAdmin(&newAdmin, dataprovider.ActionExecutorSystem, "", "")
	}
	return &newAdmin, err
}

// preserveUserProfile copies the self-managed settings (credentials,
// profile fields, MFA configuration) from the existing user into newUser,
// honoring the permissions granted to the new user definition.
func preserveUserProfile(user, newUser *dataprovider.User) {
	if newUser.CanChangePassword() && user.Password != "" {
		newUser.Password = user.Password
	}
	if newUser.CanManagePublicKeys() && len(user.PublicKeys) > 0 {
		newUser.PublicKeys = user.PublicKeys
	}
	if newUser.CanManageTLSCerts() {
		if len(user.Filters.TLSCerts) > 0 {
			newUser.Filters.TLSCerts = user.Filters.TLSCerts
		}
	}
	if newUser.CanChangeInfo() {
		if user.Description != "" {
			newUser.Description = user.Description
		}
		if user.Email != "" {
			newUser.Email = user.Email
		}
		if len(user.Filters.AdditionalEmails) > 0 {
			newUser.Filters.AdditionalEmails = user.Filters.AdditionalEmails
		}
	}
	if newUser.CanChangeAPIKeyAuth() {
		newUser.Filters.AllowAPIKeyAuth = user.Filters.AllowAPIKeyAuth
	}
	// MFA settings and last password change are always carried over
	newUser.Filters.RecoveryCodes = user.Filters.RecoveryCodes
	newUser.Filters.TOTPConfig = user.Filters.TOTPConfig
	newUser.LastPasswordChange = user.LastPasswordChange
	newUser.SetEmptySecretsIfNil()
}

// executeUserCheckAction creates or updates a user account after an IDP
// login, rendering the configured JSON template with the event parameters.
func executeUserCheckAction(c *dataprovider.EventActionIDPAccountCheck, params *EventParams) (*dataprovider.User, error) {
	user, err := dataprovider.UserExists(params.Name, "")
	exists := err == nil
	// with Mode == 1 an existing account is returned without updating it
	if exists && c.Mode == 1 {
		err = user.LoadAndApplyGroupSettings()
		return &user, err
	}
	if err != nil && !errors.Is(err, util.ErrNotFound) {
		return nil, err
	}
	replacements :=
params.getStringReplacements(false, 1)
	replacer := strings.NewReplacer(replacements...)
	data := replaceWithReplacer(c.TemplateUser, replacer)

	var newUser dataprovider.User
	err = json.Unmarshal(util.StringToBytes(data), &newUser)
	if err != nil {
		return nil, err
	}
	if exists {
		eventManagerLog(logger.LevelDebug, "updating user %q after IDP login", params.Name)
		preserveUserProfile(&user, &newUser)
		err = dataprovider.UpdateUser(&newUser, dataprovider.ActionExecutorSystem, "", "")
	} else {
		eventManagerLog(logger.LevelDebug, "creating user %q after IDP login", params.Name)
		err = dataprovider.AddUser(&newUser, dataprovider.ActionExecutorSystem, "", "")
	}
	if err != nil {
		return nil, err
	}
	// reload so the returned user has group settings applied
	u, err := dataprovider.GetUserWithGroupSettings(params.Name, "")
	return &u, err
}

// executeRuleAction executes a single event action and records any failure
// in params.
func executeRuleAction(action dataprovider.BaseEventAction, params *EventParams, //nolint:gocyclo
	conditions dataprovider.ConditionOptions,
) error {
	// skip the action if the event status does not match the configured ones
	if len(conditions.EventStatuses) > 0 && !slices.Contains(conditions.EventStatuses, params.Status) {
		eventManagerLog(logger.LevelDebug, "skipping action %s, event status %d does not match: %v",
			action.Name, params.Status, conditions.EventStatuses)
		return nil
	}
	var err error

	switch action.Type {
	case dataprovider.ActionTypeHTTP:
		err = executeHTTPRuleAction(action.Options.HTTPConfig, params)
	case dataprovider.ActionTypeCommand:
		err = executeCommandRuleAction(action.Options.CmdConfig, params)
	case dataprovider.ActionTypeEmail:
		err = executeEmailRuleAction(action.Options.EmailConfig, params)
	case dataprovider.ActionTypeBackup:
		var backupPath string
		backupPath, err = dataprovider.ExecuteBackup()
		if err == nil {
			params.setBackupParams(backupPath)
		}
	case dataprovider.ActionTypeUserQuotaReset:
		err = executeUsersQuotaResetRuleAction(conditions, params)
	case dataprovider.ActionTypeFolderQuotaReset:
		err = executeFoldersQuotaResetRuleAction(conditions, params)
	case
dataprovider.ActionTypeTransferQuotaReset:
		err = executeTransferQuotaResetRuleAction(conditions, params)
	case dataprovider.ActionTypeDataRetentionCheck:
		err = executeDataRetentionCheckRuleAction(action.Options.RetentionConfig, conditions, params, action.Name)
	case dataprovider.ActionTypeFilesystem:
		err = executeFsRuleAction(action.Options.FsConfig, conditions, params)
	case dataprovider.ActionTypePasswordExpirationCheck:
		err = executePwdExpirationCheckRuleAction(action.Options.PwdExpirationConfig, conditions, params)
	case dataprovider.ActionTypeUserExpirationCheck:
		err = executeUserExpirationCheckRuleAction(conditions, params)
	case dataprovider.ActionTypeUserInactivityCheck:
		err = executeUserInactivityCheckRuleAction(action.Options.UserInactivityConfig, conditions, params, time.Now())
	case dataprovider.ActionTypeRotateLogs:
		err = logger.RotateLogFile()
	default:
		err = fmt.Errorf("unsupported action type: %d", action.Type)
	}

	if err != nil {
		err = fmt.Errorf("action %q failed: %w", action.Name, err)
	}
	params.AddError(err)
	return err
}

// executeIDPAccountCheckRule executes the IDP account check action of the
// given rule and returns the resulting user or admin, depending on the
// login event.
func executeIDPAccountCheckRule(rule dataprovider.EventRule, params EventParams) (*dataprovider.User,
	*dataprovider.Admin, error,
) {
	for _, action := range rule.Actions {
		if action.Type == dataprovider.ActionTypeIDPAccountCheck {
			startTime := time.Now()
			var user *dataprovider.User
			var admin *dataprovider.Admin
			var err error
			var failedActions []string
			paramsCopy := params.getACopy()

			switch params.Event {
			case IDPLoginAdmin:
				admin, err = executeAdminCheckAction(&action.BaseEventAction.Options.IDPConfig, paramsCopy)
			case IDPLoginUser:
				user, err = executeUserCheckAction(&action.BaseEventAction.Options.IDPConfig, paramsCopy)
			default:
				err = fmt.Errorf("unsupported IDP login event: %q", params.Event)
			}
			if err != nil {
				paramsCopy.AddError(fmt.Errorf("unable to handle %q: %w", params.Event, err))
				eventManagerLog(logger.LevelError, "unable to handle IDP login event %q, err: %v", params.Event, err)
				failedActions = append(failedActions, action.Name)
			} else {
				eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s",
					action.Name, rule.Name, time.Since(startTime))
			}
			// execute async actions if any, including failure actions
			go executeRuleAsyncActions(rule, paramsCopy, failedActions)
			return user, admin, err
		}
	}
	eventManagerLog(logger.LevelError, "no action executed for IDP login event %q, event rule: %q", params.Event, rule.Name)
	return nil, nil, errors.New("no action executed")
}

// executeSyncRulesActions executes the synchronous actions of the given
// rules and schedules the asynchronous ones; the last sync error is
// returned.
func executeSyncRulesActions(rules []dataprovider.EventRule, params EventParams) error {
	var errRes error

	for _, rule := range rules {
		var failedActions []string
		paramsCopy := params.getACopy()
		for _, action := range rule.Actions {
			if !action.Options.IsFailureAction && action.Options.ExecuteSync {
				startTime := time.Now()
				if err := executeRuleAction(action.BaseEventAction, paramsCopy, rule.Conditions.Options); err != nil {
					eventManagerLog(logger.LevelError, "unable to execute sync action %q for rule %q, elapsed %s, err: %v",
						action.Name, rule.Name, time.Since(startTime), err)
					failedActions = append(failedActions, action.Name)
					// we return the last error, it is ok for now
					errRes = err
					if action.Options.StopOnFailure {
						break
					}
				} else {
					eventManagerLog(logger.LevelDebug, "executed sync action %q for rule %q, elapsed: %s",
						action.Name, rule.Name, time.Since(startTime))
				}
			}
		}
		// execute async actions if any, including failure actions
		go executeRuleAsyncActions(rule, paramsCopy, failedActions)
	}

	return errRes
}

// executeAsyncRulesActions executes the asynchronous actions of the given
// rules, registering with the event manager's async task counter.
func executeAsyncRulesActions(rules []dataprovider.EventRule, params EventParams) {
	eventManager.addAsyncTask()
	defer eventManager.removeAsyncTask()

	params.addUID()
	for _, rule := range rules {
		executeRuleAsyncActions(rule, params.getACopy(), nil)
	}
}

// executeRuleAsyncActions executes the non-sync actions of a rule and, if
// any action failed, the rule's failure actions.
func executeRuleAsyncActions(rule dataprovider.EventRule, params *EventParams, failedActions []string) {
	for _, action := range rule.Actions {
		if !action.Options.IsFailureAction && !action.Options.ExecuteSync {
			startTime := time.Now()
			if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
				eventManagerLog(logger.LevelError, "unable to execute action %q for rule %q, elapsed %s, err: %v",
					action.Name, rule.Name, time.Since(startTime), err)
				failedActions = append(failedActions, action.Name)
				if action.Options.StopOnFailure {
					break
				}
			} else {
				eventManagerLog(logger.LevelDebug, "executed action %q for rule %q, elapsed %s",
					action.Name, rule.Name, time.Since(startTime))
			}
		}
	}
	if len(failedActions) > 0 {
		params.updateStatusFromError = false
		// execute failure actions
		for _, action := range rule.Actions {
			if action.Options.IsFailureAction {
				startTime := time.Now()
				if err := executeRuleAction(action.BaseEventAction, params, rule.Conditions.Options); err != nil {
					eventManagerLog(logger.LevelError, "unable to execute failure action %q for rule %q, elapsed %s, err: %v",
						action.Name, rule.Name, time.Since(startTime), err)
					if action.Options.StopOnFailure {
						break
					}
				} else {
					eventManagerLog(logger.LevelDebug, "executed failure action %q for rule %q, elapsed: %s",
						action.Name, rule.Name, time.Since(startTime))
				}
			}
		}
	}
}

// eventCronJob runs a scheduled event rule.
type eventCronJob struct {
	// ruleName identifies the rule to execute
	ruleName string
}

// getTask returns the concurrency-guard task for the rule, creating it if
// missing. An empty task is returned when the rule does not request the
// guard.
func (j *eventCronJob) getTask(rule *dataprovider.EventRule) (dataprovider.Task, error) {
	if rule.GuardFromConcurrentExecution() {
		task, err := dataprovider.GetTaskByName(rule.Name)
		if err != nil {
			if errors.Is(err, util.ErrNotFound) {
				eventManagerLog(logger.LevelDebug, "adding task for rule %q", rule.Name)
				task = dataprovider.Task{
					Name:     rule.Name,
					UpdateAt: 0,
					Version:  0,
				}
				err = dataprovider.AddTask(rule.Name)
				if err != nil {
					eventManagerLog(logger.LevelWarn, "unable to add task for rule %q: %v", rule.Name, err)
					return task, err
				}
} else { + eventManagerLog(logger.LevelWarn, "unable to get task for rule %q: %v", rule.Name, err) + } + } + return task, err + } + + return dataprovider.Task{}, nil +} + +func (j *eventCronJob) getEventParams() EventParams { + return EventParams{ + Event: "Schedule", + Name: j.ruleName, + Status: 1, + Timestamp: time.Now(), + updateStatusFromError: true, + } +} + +func (j *eventCronJob) Run() { + eventManagerLog(logger.LevelDebug, "executing scheduled rule %q", j.ruleName) + rule, err := dataprovider.EventRuleExists(j.ruleName) + if err != nil { + eventManagerLog(logger.LevelError, "unable to load rule with name %q", j.ruleName) + return + } + if err := rule.CheckActionsConsistency(""); err != nil { + eventManagerLog(logger.LevelWarn, "scheduled rule %q skipped: %v", rule.Name, err) + return + } + task, err := j.getTask(&rule) + if err != nil { + return + } + if task.Name != "" { + updateInterval := 5 * time.Minute + updatedAt := util.GetTimeFromMsecSinceEpoch(task.UpdateAt) + if updatedAt.Add(updateInterval*2 + 1).After(time.Now()) { + eventManagerLog(logger.LevelDebug, "task for rule %q too recent: %s, skip execution", rule.Name, updatedAt) + return + } + err = dataprovider.UpdateTask(rule.Name, task.Version) + if err != nil { + eventManagerLog(logger.LevelInfo, "unable to update task timestamp for rule %q, skip execution, err: %v", + rule.Name, err) + return + } + ticker := time.NewTicker(updateInterval) + done := make(chan bool) + + defer func() { + done <- true + ticker.Stop() + }() + + go func(taskName string) { + eventManagerLog(logger.LevelDebug, "update task %q timestamp worker started", taskName) + for { + select { + case <-done: + eventManagerLog(logger.LevelDebug, "update task %q timestamp worker finished", taskName) + return + case <-ticker.C: + err := dataprovider.UpdateTaskTimestamp(taskName) + eventManagerLog(logger.LevelInfo, "updated timestamp for task %q, err: %v", taskName, err) + } + } + }(task.Name) + + 
executeAsyncRulesActions([]dataprovider.EventRule{rule}, j.getEventParams()) + } else { + executeAsyncRulesActions([]dataprovider.EventRule{rule}, j.getEventParams()) + } + eventManagerLog(logger.LevelDebug, "execution for scheduled rule %q finished", j.ruleName) +} + +// RunOnDemandRule executes actions for a rule with on-demand trigger +func RunOnDemandRule(name string) error { + eventManagerLog(logger.LevelDebug, "executing on demand rule %q", name) + rule, err := dataprovider.EventRuleExists(name) + if err != nil { + eventManagerLog(logger.LevelDebug, "unable to load rule with name %q", name) + return util.NewRecordNotFoundError(fmt.Sprintf("rule %q does not exist", name)) + } + if rule.Trigger != dataprovider.EventTriggerOnDemand { + eventManagerLog(logger.LevelDebug, "cannot run rule %q as on demand, trigger: %d", name, rule.Trigger) + return util.NewValidationError(fmt.Sprintf("rule %q is not defined as on-demand", name)) + } + if rule.Status != 1 { + eventManagerLog(logger.LevelDebug, "on-demand rule %q is inactive", name) + return util.NewValidationError(fmt.Sprintf("rule %q is inactive", name)) + } + if err := rule.CheckActionsConsistency(""); err != nil { + eventManagerLog(logger.LevelError, "on-demand rule %q has incompatible actions: %v", name, err) + return util.NewValidationError(fmt.Sprintf("rule %q has incosistent actions", name)) + } + eventManagerLog(logger.LevelDebug, "on-demand rule %q started", name) + go executeAsyncRulesActions([]dataprovider.EventRule{rule}, EventParams{Status: 1, updateStatusFromError: true}) + return nil +} + +type zipWriterWrapper struct { + Name string + Entries map[string]bool + Writer *zip.Writer +} + +func eventManagerLog(level logger.LogLevel, format string, v ...any) { + logger.Log(level, "eventmanager", "", format, v...) 
+} diff --git a/internal/common/eventmanager_test.go b/internal/common/eventmanager_test.go new file mode 100644 index 00000000..d350421e --- /dev/null +++ b/internal/common/eventmanager_test.go @@ -0,0 +1,2511 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/klauspost/compress/zip" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + sdkkms "github.com/sftpgo/sdk/kms" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +func TestEventRuleMatch(t *testing.T) { + role := "role1" + conditions := &dataprovider.EventConditions{ + ProviderEvents: []string{"add", "update"}, + Options: dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "user1", + InverseMatch: true, + }, + }, + RoleNames: []dataprovider.ConditionPattern{ + { + Pattern: role, + }, + }, + }, + } + res := eventManager.checkProviderEventMatch(conditions, &EventParams{ + Name: "user1", + Role: role, + Event: "add", + }) + assert.False(t, res) + res = 
eventManager.checkProviderEventMatch(conditions, &EventParams{ + Name: "user2", + Role: role, + Event: "update", + }) + assert.True(t, res) + res = eventManager.checkProviderEventMatch(conditions, &EventParams{ + Name: "user2", + Role: role, + Event: "delete", + }) + assert.False(t, res) + conditions.Options.ProviderObjects = []string{"api_key"} + res = eventManager.checkProviderEventMatch(conditions, &EventParams{ + Name: "user2", + Event: "update", + Role: role, + ObjectType: "share", + }) + assert.False(t, res) + res = eventManager.checkProviderEventMatch(conditions, &EventParams{ + Name: "user2", + Event: "update", + Role: role, + ObjectType: "api_key", + }) + assert.True(t, res) + res = eventManager.checkProviderEventMatch(conditions, &EventParams{ + Name: "user2", + Event: "update", + Role: role + "1", + ObjectType: "api_key", + }) + assert.False(t, res) + // now test fs events + conditions = &dataprovider.EventConditions{ + FsEvents: []string{operationUpload, operationDownload}, + Options: dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "user*", + }, + { + Pattern: "tester*", + }, + }, + RoleNames: []dataprovider.ConditionPattern{ + { + Pattern: role, + InverseMatch: true, + }, + }, + FsPaths: []dataprovider.ConditionPattern{ + { + Pattern: "/**/*.txt", + }, + }, + Protocols: []string{ProtocolSFTP}, + MinFileSize: 10, + MaxFileSize: 30, + }, + } + params := EventParams{ + Name: "tester4", + Event: operationDelete, + VirtualPath: "/path.txt", + Protocol: ProtocolSFTP, + ObjectName: "path.txt", + FileSize: 20, + } + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + params.Event = operationDownload + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.True(t, res) + params.Role = role + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + params.Role = "" + params.Name = "name" + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + 
params.Name = "user5" + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.True(t, res) + params.VirtualPath = "/sub/f.jpg" + params.ObjectName = path.Base(params.VirtualPath) + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + params.VirtualPath = "/sub/f.txt" + params.ObjectName = path.Base(params.VirtualPath) + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.True(t, res) + params.Protocol = ProtocolHTTP + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + params.Protocol = ProtocolSFTP + params.FileSize = 5 + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + params.FileSize = 50 + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + params.FileSize = 25 + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.True(t, res) + // bad pattern + conditions.Options.Names = []dataprovider.ConditionPattern{ + { + Pattern: "[-]", + }, + } + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + // check fs events with group name filters + conditions = &dataprovider.EventConditions{ + FsEvents: []string{operationUpload, operationDownload}, + Options: dataprovider.ConditionOptions{ + GroupNames: []dataprovider.ConditionPattern{ + { + Pattern: "group*", + }, + { + Pattern: "testgroup*", + }, + }, + }, + } + params = EventParams{ + Name: "user1", + Event: operationUpload, + } + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + params.Groups = []sdk.GroupMapping{ + { + Name: "g1", + Type: sdk.GroupTypePrimary, + }, + { + Name: "g2", + Type: sdk.GroupTypeSecondary, + }, + } + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.False(t, res) + params.Groups = []sdk.GroupMapping{ + { + Name: "testgroup2", + Type: sdk.GroupTypePrimary, + }, + { + Name: "g2", + Type: sdk.GroupTypeSecondary, + }, + } + res = eventManager.checkFsEventMatch(conditions, ¶ms) + assert.True(t, 
res) + // check user conditions + user := dataprovider.User{} + user.Username = "u1" + res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{}) + assert.True(t, res) + res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "user", + }, + }, + }) + assert.False(t, res) + res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{ + RoleNames: []dataprovider.ConditionPattern{ + { + Pattern: role, + }, + }, + }) + assert.False(t, res) + user.Role = role + res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{ + RoleNames: []dataprovider.ConditionPattern{ + { + Pattern: role, + }, + }, + }) + assert.True(t, res) + res = checkUserConditionOptions(&user, &dataprovider.ConditionOptions{ + GroupNames: []dataprovider.ConditionPattern{ + { + Pattern: "group", + }, + }, + RoleNames: []dataprovider.ConditionPattern{ + { + Pattern: role, + }, + }, + }) + assert.False(t, res) + res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{ + IDPLoginEvent: 0, + }, &EventParams{ + Event: IDPLoginAdmin, + }) + assert.True(t, res) + res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{ + IDPLoginEvent: 2, + }, &EventParams{ + Event: IDPLoginAdmin, + }) + assert.True(t, res) + res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{ + IDPLoginEvent: 1, + }, &EventParams{ + Event: IDPLoginAdmin, + }) + assert.False(t, res) + res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{ + IDPLoginEvent: 1, + }, &EventParams{ + Event: IDPLoginUser, + }) + assert.True(t, res) + res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{ + IDPLoginEvent: 1, + }, &EventParams{ + Name: "user", + Event: IDPLoginUser, + }) + assert.True(t, res) + res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{ + IDPLoginEvent: 1, + Options: dataprovider.ConditionOptions{ + Names: 
[]dataprovider.ConditionPattern{ + { + Pattern: "abc", + }, + }, + }, + }, &EventParams{ + Name: "user", + Event: IDPLoginUser, + }) + assert.False(t, res) + res = eventManager.checkIPDLoginEventMatch(&dataprovider.EventConditions{ + IDPLoginEvent: 2, + }, &EventParams{ + Name: "user", + Event: IDPLoginUser, + }) + assert.False(t, res) +} + +func TestDoubleStarMatching(t *testing.T) { + c := dataprovider.ConditionPattern{ + Pattern: "/mydir/**", + } + res := checkEventConditionPattern(c, "/mydir") + assert.True(t, res) + res = checkEventConditionPattern(c, "/mydirname") + assert.False(t, res) + res = checkEventConditionPattern(c, "/mydir/sub") + assert.True(t, res) + res = checkEventConditionPattern(c, "/mydir/sub/dir") + assert.True(t, res) + + c.Pattern = "/**/*" + res = checkEventConditionPattern(c, "/mydir") + assert.True(t, res) + res = checkEventConditionPattern(c, "/mydirname") + assert.True(t, res) + res = checkEventConditionPattern(c, "/mydir/sub/dir/file.txt") + assert.True(t, res) + + c.Pattern = "/**/*.filepart" + res = checkEventConditionPattern(c, "/file.filepart") + assert.True(t, res) + res = checkEventConditionPattern(c, "/mydir/sub/file.filepart") + assert.True(t, res) + res = checkEventConditionPattern(c, "/file.txt") + assert.False(t, res) + res = checkEventConditionPattern(c, "/mydir/file.txt") + assert.False(t, res) + + c.Pattern = "/mydir/**/*.txt" + res = checkEventConditionPattern(c, "/mydir") + assert.False(t, res) + res = checkEventConditionPattern(c, "/mydirname/f.txt") + assert.False(t, res) + res = checkEventConditionPattern(c, "/mydir/sub") + assert.False(t, res) + res = checkEventConditionPattern(c, "/mydir/sub/dir") + assert.False(t, res) + res = checkEventConditionPattern(c, "/mydir/sub/dir/a.txt") + assert.True(t, res) + + c.InverseMatch = true + assert.True(t, checkEventConditionPattern(c, "/mydir")) + assert.True(t, checkEventConditionPattern(c, "/mydirname/f.txt")) + assert.True(t, checkEventConditionPattern(c, "/mydir/sub")) + 
assert.True(t, checkEventConditionPattern(c, "/mydir/sub/dir")) + assert.False(t, checkEventConditionPattern(c, "/mydir/sub/dir/a.txt")) +} + +func TestMutlipleDoubleStarMatching(t *testing.T) { + patterns := []dataprovider.ConditionPattern{ + { + Pattern: "/**/*.txt", + InverseMatch: false, + }, + { + Pattern: "/**/*.tmp", + InverseMatch: false, + }, + } + assert.False(t, checkEventConditionPatterns("/mydir", patterns)) + assert.True(t, checkEventConditionPatterns("/mydir/test.tmp", patterns)) + assert.True(t, checkEventConditionPatterns("/mydir/test.txt", patterns)) + assert.False(t, checkEventConditionPatterns("/mydir/test.csv", patterns)) + assert.False(t, checkEventConditionPatterns("/mydir/sub", patterns)) + assert.True(t, checkEventConditionPatterns("/mydir/sub/test.tmp", patterns)) + assert.True(t, checkEventConditionPatterns("/mydir/sub/test.txt", patterns)) + assert.False(t, checkEventConditionPatterns("/mydir/sub/test.csv", patterns)) +} + +func TestMultipleDoubleStarMatchingInverse(t *testing.T) { + patterns := []dataprovider.ConditionPattern{ + { + Pattern: "/**/*.txt", + InverseMatch: true, + }, + { + Pattern: "/**/*.tmp", + InverseMatch: true, + }, + } + assert.True(t, checkEventConditionPatterns("/mydir", patterns)) + assert.False(t, checkEventConditionPatterns("/mydir/test.tmp", patterns)) + assert.False(t, checkEventConditionPatterns("/mydir/test.txt", patterns)) + assert.True(t, checkEventConditionPatterns("/mydir/test.csv", patterns)) + assert.True(t, checkEventConditionPatterns("/mydir/sub", patterns)) + assert.False(t, checkEventConditionPatterns("/mydir/sub/test.tmp", patterns)) + assert.False(t, checkEventConditionPatterns("/mydir/sub/test.txt", patterns)) + assert.True(t, checkEventConditionPatterns("/mydir/sub/test.csv", patterns)) +} + +func TestGroupConditionPatterns(t *testing.T) { + group1 := "group1" + group2 := "group2" + patterns := []dataprovider.ConditionPattern{ + { + Pattern: group1, + }, + { + Pattern: group2, + }, + } + 
inversePatterns := []dataprovider.ConditionPattern{ + { + Pattern: group1, + InverseMatch: true, + }, + { + Pattern: group2, + InverseMatch: true, + }, + } + groups := []sdk.GroupMapping{ + { + Name: "group3", + Type: sdk.GroupTypePrimary, + }, + } + assert.False(t, checkEventGroupConditionPatterns(groups, patterns)) + assert.True(t, checkEventGroupConditionPatterns(groups, inversePatterns)) + + groups = []sdk.GroupMapping{ + { + Name: group1, + Type: sdk.GroupTypePrimary, + }, + { + Name: "group4", + Type: sdk.GroupTypePrimary, + }, + } + assert.True(t, checkEventGroupConditionPatterns(groups, patterns)) + assert.False(t, checkEventGroupConditionPatterns(groups, inversePatterns)) + groups = []sdk.GroupMapping{ + { + Name: group1, + Type: sdk.GroupTypePrimary, + }, + } + assert.True(t, checkEventGroupConditionPatterns(groups, patterns)) + assert.False(t, checkEventGroupConditionPatterns(groups, inversePatterns)) + groups = []sdk.GroupMapping{ + { + Name: "group11", + Type: sdk.GroupTypePrimary, + }, + } + assert.False(t, checkEventGroupConditionPatterns(groups, patterns)) + assert.True(t, checkEventGroupConditionPatterns(groups, inversePatterns)) +} + +func TestEventManager(t *testing.T) { + startEventScheduler() + action := &dataprovider.BaseEventAction{ + Name: "test_action", + Type: dataprovider.ActionTypeHTTP, + Options: dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: "http://localhost", + Timeout: 20, + Method: http.MethodGet, + }, + }, + } + err := dataprovider.AddEventAction(action, "", "", "") + assert.NoError(t, err) + rule := &dataprovider.EventRule{ + Name: "rule", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{operationUpload}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action.Name, + }, + Order: 1, + }, + }, + } + + err = dataprovider.AddEventRule(rule, "", "", "") + 
assert.NoError(t, err) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 1) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + assert.Len(t, eventManager.schedulesMapping, 0) + eventManager.RUnlock() + + rule.Trigger = dataprovider.EventTriggerProviderEvent + rule.Conditions = dataprovider.EventConditions{ + ProviderEvents: []string{"add"}, + } + err = dataprovider.UpdateEventRule(rule, "", "", "") + assert.NoError(t, err) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 1) + assert.Len(t, eventManager.Schedules, 0) + assert.Len(t, eventManager.schedulesMapping, 0) + eventManager.RUnlock() + + rule.Trigger = dataprovider.EventTriggerSchedule + rule.Conditions = dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + Hours: "0", + DayOfWeek: "*", + DayOfMonth: "*", + Month: "*", + }, + }, + } + rule.DeletedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-12 * time.Hour)) + eventManager.addUpdateRuleInternal(*rule) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + assert.Len(t, eventManager.schedulesMapping, 0) + eventManager.RUnlock() + + assert.Eventually(t, func() bool { + _, err = dataprovider.EventRuleExists(rule.Name) + ok := errors.Is(err, util.ErrNotFound) + return ok + }, 2*time.Second, 100*time.Millisecond) + + rule.DeletedAt = 0 + err = dataprovider.AddEventRule(rule, "", "", "") + assert.NoError(t, err) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 1) + assert.Len(t, eventManager.schedulesMapping, 1) + eventManager.RUnlock() + + err = dataprovider.DeleteEventRule(rule.Name, "", "", "") + assert.NoError(t, err) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, 
eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + assert.Len(t, eventManager.schedulesMapping, 0) + eventManager.RUnlock() + + err = dataprovider.DeleteEventAction(action.Name, "", "", "") + assert.NoError(t, err) + stopEventScheduler() +} + +func TestEventManagerErrors(t *testing.T) { + startEventScheduler() + providerConf := dataprovider.GetProviderConfig() + err := dataprovider.Close() + assert.NoError(t, err) + + params := EventParams{ + sender: "sender", + } + _, err = params.getUsers() + assert.Error(t, err) + _, err = params.getFolders() + assert.Error(t, err) + + err = executeUsersQuotaResetRuleAction(dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = executeFoldersQuotaResetRuleAction(dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = executeTransferQuotaResetRuleAction(dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = executeUserExpirationCheckRuleAction(dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{}, + dataprovider.ConditionOptions{}, &EventParams{}, time.Time{}) + assert.Error(t, err) + err = executeDeleteFsRuleAction(nil, nil, dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = executeMkdirFsRuleAction(nil, nil, dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = executeRenameFsRuleAction(nil, nil, dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = executeExistFsRuleAction(nil, nil, dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = executeCopyFsRuleAction(nil, nil, dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = executeCompressFsRuleAction(dataprovider.EventActionFsCompress{}, nil, dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + err = 
executePwdExpirationCheckRuleAction(dataprovider.EventActionPasswordExpiration{}, + dataprovider.ConditionOptions{}, &EventParams{}) + assert.Error(t, err) + _, err = executeAdminCheckAction(&dataprovider.EventActionIDPAccountCheck{}, &EventParams{}) + assert.Error(t, err) + _, err = executeUserCheckAction(&dataprovider.EventActionIDPAccountCheck{}, &EventParams{}) + assert.Error(t, err) + + groupName := "agroup" + err = executeQuotaResetForUser(&dataprovider.User{ + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + }) + assert.Error(t, err) + err = executeDataRetentionCheckForUser(dataprovider.User{ + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + }, nil, &EventParams{}, "") + assert.Error(t, err) + err = executeDeleteFsActionForUser(nil, nil, dataprovider.User{ + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + }) + assert.Error(t, err) + err = executeMkDirsFsActionForUser(nil, nil, dataprovider.User{ + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + }) + assert.Error(t, err) + err = executeRenameFsActionForUser(nil, nil, dataprovider.User{ + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + }) + assert.Error(t, err) + err = executeExistFsActionForUser(nil, nil, dataprovider.User{ + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + }) + assert.Error(t, err) + err = executeCopyFsActionForUser(nil, nil, dataprovider.User{ + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + }) + assert.Error(t, err) + err = executeCompressFsActionForUser(dataprovider.EventActionFsCompress{}, nil, dataprovider.User{ + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + }) + assert.Error(t, err) + err = 
executePwdExpirationCheckForUser(&dataprovider.User{ + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }}, dataprovider.EventActionPasswordExpiration{}) + assert.Error(t, err) + + _, _, err = getHTTPRuleActionBody(&dataprovider.EventActionHTTPConfig{ + Method: http.MethodPost, + Parts: []dataprovider.HTTPPart{ + { + Name: "p1", + }, + }, + }, nil, nil, dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "u", + }, + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + }, &EventParams{}, false) + assert.Error(t, err) + + dataRetentionAction := dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeDataRetentionCheck, + Options: dataprovider.BaseEventActionOptions{ + RetentionConfig: dataprovider.EventActionDataRetentionConfig{ + Folders: []dataprovider.FolderRetention{ + { + Path: "/", + Retention: 24, + }, + }, + }, + }, + } + err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "username1", + }, + }, + }) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to get users") + } + + eventManager.loadRules() + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + eventManager.RUnlock() + + // rule with invalid trigger + eventManager.addUpdateRuleInternal(dataprovider.EventRule{ + Name: "test rule", + Status: 1, + Trigger: -1, + }) + + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + eventManager.RUnlock() + // rule with invalid cronspec + eventManager.addUpdateRuleInternal(dataprovider.EventRule{ + Name: "test rule", + Status: 1, + Trigger: dataprovider.EventTriggerSchedule, + Conditions: dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + 
Hours: "1000", + }, + }, + }, + }) + eventManager.RLock() + assert.Len(t, eventManager.FsEvents, 0) + assert.Len(t, eventManager.ProviderEvents, 0) + assert.Len(t, eventManager.Schedules, 0) + eventManager.RUnlock() + + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + stopEventScheduler() +} + +func TestDateTimePlaceholder(t *testing.T) { + oldTZ := Config.TZ + + Config.TZ = "" + dateTime := time.Now() + params := EventParams{ + Timestamp: dateTime, + } + replacements := params.getStringReplacements(false, 0) + r := strings.NewReplacer(replacements...) + res := r.Replace("{{.DateTime}}") + assert.Equal(t, dateTime.UTC().Format(dateTimeMillisFormat), res) + res = r.Replace("{{.Year}}-{{.Month}}-{{.Day}}T{{.Hour}}:{{.Minute}}") + assert.Equal(t, dateTime.UTC().Format(dateTimeMillisFormat)[:16], res) + + Config.TZ = "local" + replacements = params.getStringReplacements(false, 0) + r = strings.NewReplacer(replacements...) + res = r.Replace("{{.DateTime}}") + assert.Equal(t, dateTime.Local().Format(dateTimeMillisFormat), res) + res = r.Replace("{{.Year}}-{{.Month}}-{{.Day}}T{{.Hour}}:{{.Minute}}") + assert.Equal(t, dateTime.Local().Format(dateTimeMillisFormat)[:16], res) + + Config.TZ = oldTZ +} + +func TestEventRuleActions(t *testing.T) { + actionName := "test rule action" + action := dataprovider.BaseEventAction{ + Name: actionName, + Type: dataprovider.ActionTypeBackup, + } + err := executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{}) + assert.NoError(t, err) + action.Type = -1 + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{}) + assert.Error(t, err) + + action = dataprovider.BaseEventAction{ + Name: actionName, + Type: dataprovider.ActionTypeHTTP, + Options: dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: "http://foo\x7f.com/", // invalid URL + SkipTLSVerify: true, + Body: `"data": "{{.ObjectDataString}}"`, + Method: 
http.MethodPost, + QueryParameters: []dataprovider.KeyValue{ + { + Key: "param", + Value: "value", + }, + }, + Timeout: 5, + Headers: []dataprovider.KeyValue{ + { + Key: "Content-Type", + Value: "application/json", + }, + }, + Username: "httpuser", + }, + }, + } + action.Options.SetEmptySecretsIfNil() + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid endpoint") + } + action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v", httpAddr) + params := &EventParams{ + Name: "a", + Object: &dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "test user", + }, + }, + } + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + assert.NoError(t, err) + action.Options.HTTPConfig.Method = http.MethodGet + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + assert.NoError(t, err) + action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v/404", httpAddr) + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unexpected status code: 404") + } + action.Options.HTTPConfig.Endpoint = "http://invalid:1234" + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + assert.Error(t, err) + action.Options.HTTPConfig.QueryParameters = nil + action.Options.HTTPConfig.Endpoint = "http://bar\x7f.com/" + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + assert.Error(t, err) + action.Options.HTTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", "data") + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to decrypt HTTP password") + } + action.Options.HTTPConfig.Endpoint = fmt.Sprintf("http://%v", httpAddr) + action.Options.HTTPConfig.Password = kms.NewEmptySecret() + 
action.Options.HTTPConfig.Body = "" + action.Options.HTTPConfig.Parts = []dataprovider.HTTPPart{ + { + Name: "p1", + Filepath: "path", + }, + } + err = executeRuleAction(action, params, dataprovider.ConditionOptions{}) + assert.Contains(t, getErrorString(err), "error getting user") + + action.Options.HTTPConfig.Parts = nil + action.Options.HTTPConfig.Body = "{{.ObjectData}}" + // test disk and transfer quota reset + username1 := "user1" + username2 := "user2" + user1 := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username1, + HomeDir: filepath.Join(os.TempDir(), username1), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + user2 := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username2, + HomeDir: filepath.Join(os.TempDir(), username2), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + user2.Filters.PasswordExpiration = 10 + err = dataprovider.AddUser(&user1, "", "", "") + assert.NoError(t, err) + err = dataprovider.AddUser(&user2, "", "", "") + assert.NoError(t, err) + + err = executePwdExpirationCheckRuleAction(dataprovider.EventActionPasswordExpiration{ + Threshold: 20, + }, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user2.Username, + }, + }, + }, &EventParams{}) + // smtp not configured + assert.Error(t, err) + + action = dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeUserQuotaReset, + } + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.Error(t, err) // no home dir + // create the home dir + err = os.MkdirAll(user1.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(user1.GetHomeDir(), "file.txt"), []byte("user"), 0666) + assert.NoError(t, err) + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: 
[]dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.NoError(t, err) + userGet, err := dataprovider.UserExists(username1, "") + assert.NoError(t, err) + assert.Equal(t, 1, userGet.UsedQuotaFiles) + assert.Equal(t, int64(4), userGet.UsedQuotaSize) + // simulate another quota scan in progress + assert.True(t, QuotaScans.AddUserQuotaScan(username1, "")) + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.Error(t, err) + assert.True(t, QuotaScans.RemoveUserQuotaScan(username1)) + // non matching pattern + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "don't match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no user quota reset executed") + + action = dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeUserExpirationCheck, + } + + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "don't match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no user expiration check executed") + + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.NoError(t, err) + + dataRetentionAction := dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeDataRetentionCheck, + Options: dataprovider.BaseEventActionOptions{ + RetentionConfig: dataprovider.EventActionDataRetentionConfig{ + Folders: []dataprovider.FolderRetention{ + { + Path: "", + Retention: 24, + }, + }, + }, + }, + } + err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.Error(t, 
err) // invalid config, no folder path specified + retentionDir := "testretention" + dataRetentionAction = dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeDataRetentionCheck, + Options: dataprovider.BaseEventActionOptions{ + RetentionConfig: dataprovider.EventActionDataRetentionConfig{ + Folders: []dataprovider.FolderRetention{ + { + Path: path.Join("/", retentionDir), + Retention: 24, + DeleteEmptyDirs: true, + }, + }, + }, + }, + } + // create some test files + file1 := filepath.Join(user1.GetHomeDir(), "file1.txt") + file2 := filepath.Join(user1.GetHomeDir(), retentionDir, "file2.txt") + file3 := filepath.Join(user1.GetHomeDir(), retentionDir, "file3.txt") + file4 := filepath.Join(user1.GetHomeDir(), retentionDir, "sub", "file4.txt") + + err = os.MkdirAll(filepath.Dir(file4), os.ModePerm) + assert.NoError(t, err) + + for _, f := range []string{file1, file2, file3, file4} { + err = os.WriteFile(f, []byte(""), 0666) + assert.NoError(t, err) + } + timeBeforeRetention := time.Now().Add(-48 * time.Hour) + err = os.Chtimes(file1, timeBeforeRetention, timeBeforeRetention) + assert.NoError(t, err) + err = os.Chtimes(file2, timeBeforeRetention, timeBeforeRetention) + assert.NoError(t, err) + err = os.Chtimes(file4, timeBeforeRetention, timeBeforeRetention) + assert.NoError(t, err) + + err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.NoError(t, err) + assert.FileExists(t, file1) + assert.NoFileExists(t, file2) + assert.FileExists(t, file3) + assert.NoDirExists(t, filepath.Dir(file4)) + // simulate another check in progress + c := RetentionChecks.Add(RetentionCheck{}, &user1) + assert.NotNil(t, c) + err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.Error(t, err) + 
RetentionChecks.remove(user1.Username) + + err = executeRuleAction(dataRetentionAction, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "no match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no retention check executed") + + // test file exists action + action = dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionExist, + Exist: []string{"/file1.txt", path.Join("/", retentionDir, "file3.txt")}, + }, + }, + } + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "no match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no existence check executed") + + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.NoError(t, err) + action.Options.FsConfig.Exist = []string{"/file1.txt", path.Join("/", retentionDir, "file2.txt")} + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.Error(t, err) + + err = os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.UpdateUserTransferQuota(&user1, 100, 100, true) + assert.NoError(t, err) + + action.Type = dataprovider.ActionTypeTransferQuotaReset + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username1, + }, + }, + }) + assert.NoError(t, err) + userGet, err = dataprovider.UserExists(username1, "") + assert.NoError(t, err) + assert.Equal(t, int64(0), userGet.UsedDownloadDataTransfer) + assert.Equal(t, int64(0), 
userGet.UsedUploadDataTransfer) + + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "no match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no transfer quota reset executed") + + action.Type = dataprovider.ActionTypeFilesystem + action.Options = dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionRename, + Renames: []dataprovider.RenameConfig{ + { + KeyValue: dataprovider.KeyValue{ + Key: "/source", + Value: "/target", + }, + }, + }, + }, + } + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "no match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no rename executed") + + action.Options = dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionDelete, + Deletes: []string{"/dir1"}, + }, + } + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "no match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no delete executed") + + action.Options = dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionMkdirs, + Deletes: []string{"/dir1"}, + }, + } + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "no match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no mkdir executed") + + action.Options = dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionCompress, + Compress: 
dataprovider.EventActionFsCompress{ + Name: "test.zip", + Paths: []string{"/{{.VirtualPath}}"}, + }, + }, + } + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "no match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no file/folder compressed") + + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + GroupNames: []dataprovider.ConditionPattern{ + { + Pattern: "no match", + }, + }, + }) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "no file/folder compressed") + + err = dataprovider.DeleteUser(username1, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteUser(username2, "", "", "") + assert.NoError(t, err) + // test folder quota reset + foldername1 := "f1" + foldername2 := "f2" + folder1 := vfs.BaseVirtualFolder{ + Name: foldername1, + MappedPath: filepath.Join(os.TempDir(), foldername1), + } + folder2 := vfs.BaseVirtualFolder{ + Name: foldername2, + MappedPath: filepath.Join(os.TempDir(), foldername2), + } + err = dataprovider.AddFolder(&folder1, "", "", "") + assert.NoError(t, err) + err = dataprovider.AddFolder(&folder2, "", "", "") + assert.NoError(t, err) + action = dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeFolderQuotaReset, + } + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: foldername1, + }, + }, + }) + assert.Error(t, err) // no home dir + err = os.MkdirAll(folder1.MappedPath, os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(folder1.MappedPath, "file.txt"), []byte("folder"), 0666) + assert.NoError(t, err) + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: foldername1, + }, + }, + }) + assert.NoError(t, err) + folderGet, err := 
dataprovider.GetFolderByName(foldername1) + assert.NoError(t, err) + assert.Equal(t, 1, folderGet.UsedQuotaFiles) + assert.Equal(t, int64(6), folderGet.UsedQuotaSize) + // simulate another quota scan in progress + assert.True(t, QuotaScans.AddVFolderQuotaScan(foldername1)) + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: foldername1, + }, + }, + }) + assert.Error(t, err) + assert.True(t, QuotaScans.RemoveVFolderQuotaScan(foldername1)) + + err = executeRuleAction(action, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "no folder match", + }, + }, + }) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no folder quota reset executed") + } + + body, _, err := getHTTPRuleActionBody(&dataprovider.EventActionHTTPConfig{ + Method: http.MethodPost, + }, nil, nil, dataprovider.User{}, &EventParams{}, true) + assert.NoError(t, err) + assert.Nil(t, body) + body, _, err = getHTTPRuleActionBody(&dataprovider.EventActionHTTPConfig{ + Method: http.MethodPost, + Body: "test body", + }, nil, nil, dataprovider.User{}, &EventParams{}, false) + assert.NoError(t, err) + assert.NotNil(t, body) + + err = os.RemoveAll(folder1.MappedPath) + assert.NoError(t, err) + err = dataprovider.DeleteFolder(foldername1, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteFolder(foldername2, "", "", "") + assert.NoError(t, err) +} + +func TestIDPAccountCheckRule(t *testing.T) { + _, _, err := executeIDPAccountCheckRule(dataprovider.EventRule{}, EventParams{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no action executed") + } + _, _, err = executeIDPAccountCheckRule(dataprovider.EventRule{ + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "n", + Type: dataprovider.ActionTypeIDPAccountCheck, + }, + }, + }, + }, EventParams{Event: "invalid"}) + if assert.Error(t, err) 
{ + assert.Contains(t, err.Error(), "unsupported IDP login event") + } + // invalid json + _, err = executeAdminCheckAction(&dataprovider.EventActionIDPAccountCheck{TemplateAdmin: "{"}, &EventParams{Name: "missing admin"}) + assert.Error(t, err) + _, err = executeUserCheckAction(&dataprovider.EventActionIDPAccountCheck{TemplateUser: "["}, &EventParams{Name: "missing user"}) + assert.Error(t, err) + _, err = executeUserCheckAction(&dataprovider.EventActionIDPAccountCheck{TemplateUser: "{}"}, &EventParams{Name: "invalid user template"}) + assert.ErrorIs(t, err, util.ErrValidation) + username := "u" + c := &dataprovider.EventActionIDPAccountCheck{ + Mode: 1, + TemplateUser: `{"username":"` + username + `","status":1,"home_dir":"` + util.JSONEscape(filepath.Join(os.TempDir())) + `","permissions":{"/":["*"]}}`, + } + params := &EventParams{ + Name: username, + Event: IDPLoginUser, + } + user, err := executeUserCheckAction(c, params) + assert.NoError(t, err) + assert.Equal(t, username, user.Username) + assert.Equal(t, 1, user.Status) + user.Status = 0 + err = dataprovider.UpdateUser(user, "", "", "") + assert.NoError(t, err) + // the user is not changed + user, err = executeUserCheckAction(c, params) + assert.NoError(t, err) + assert.Equal(t, username, user.Username) + assert.Equal(t, 0, user.Status) + // change the mode, the user is now updated + c.Mode = 0 + user, err = executeUserCheckAction(c, params) + assert.NoError(t, err) + assert.Equal(t, username, user.Username) + assert.Equal(t, 1, user.Status) + assert.Empty(t, user.Password) + assert.Len(t, user.PublicKeys, 0) + assert.Len(t, user.Filters.TLSCerts, 0) + assert.Empty(t, user.Email) + assert.Empty(t, user.Description) + // Update the profile attribute and make sure they are preserved + user.Password = "secret" + user.Email = "example@example.com" + user.Filters.AdditionalEmails = []string{"alias@example.com"} + user.Description = "some desc" + user.Filters.TLSCerts = []string{serverCert} + user.PublicKeys = 
[]string{"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1"} + err = dataprovider.UpdateUser(user, "", "", "") + assert.NoError(t, err) + + user, err = executeUserCheckAction(c, params) + assert.NoError(t, err) + assert.Equal(t, username, user.Username) + assert.Equal(t, 1, user.Status) + assert.NotEmpty(t, user.Password) + assert.Len(t, user.PublicKeys, 1) + assert.Len(t, user.Filters.TLSCerts, 1) + assert.NotEmpty(t, user.Email) + assert.Len(t, user.Filters.AdditionalEmails, 1) + assert.NotEmpty(t, user.Description) + + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + // check rule consistency + r := dataprovider.EventRule{ + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeIDPAccountCheck, + }, + Order: 1, + }, + }, + } + err = r.CheckActionsConsistency("") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "IDP account check action is only supported for IDP login trigger") + } + r.Trigger = dataprovider.EventTriggerIDPLogin + err = r.CheckActionsConsistency("") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "IDP account check must be a sync action") + } + r.Actions[0].Options.ExecuteSync = true + err = r.CheckActionsConsistency("") + assert.NoError(t, err) + r.Actions = append(r.Actions, dataprovider.EventAction{ + BaseEventAction: dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeCommand, + }, + Options: 
dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + Order: 2, + }) + err = r.CheckActionsConsistency("") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "IDP account check must be the only sync action") + } +} + +func TestUserExpirationCheck(t *testing.T) { + username := "test_user_expiration_check" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + HomeDir: filepath.Join(os.TempDir(), username), + ExpirationDate: util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)), + }, + } + user.Filters.PasswordExpiration = 5 + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + + conditions := dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username, + }, + }, + } + err = executeUserExpirationCheckRuleAction(conditions, &EventParams{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "expired users") + } + // the check will be skipped, the user is expired + err = executePwdExpirationCheckRuleAction(dataprovider.EventActionPasswordExpiration{Threshold: 10}, conditions, &EventParams{}) + assert.NoError(t, err) + + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestEventRuleActionsNoGroupMatching(t *testing.T) { + username := "test_user_action_group_matching" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + HomeDir: filepath.Join(os.TempDir(), username), + }, + } + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + + conditions := dataprovider.ConditionOptions{ + GroupNames: []dataprovider.ConditionPattern{ + { + Pattern: "agroup", + }, + }, + } + err = executeDeleteFsRuleAction(nil, nil, conditions, &EventParams{}) + if assert.Error(t, err) { + 
assert.Contains(t, err.Error(), "no delete executed") + } + err = executeMkdirFsRuleAction(nil, nil, conditions, &EventParams{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no mkdir executed") + } + err = executeRenameFsRuleAction(nil, nil, conditions, &EventParams{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no rename executed") + } + err = executeExistFsRuleAction(nil, nil, conditions, &EventParams{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no existence check executed") + } + err = executeCopyFsRuleAction(nil, nil, conditions, &EventParams{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no copy executed") + } + err = executeUsersQuotaResetRuleAction(conditions, &EventParams{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no user quota reset executed") + } + err = executeTransferQuotaResetRuleAction(conditions, &EventParams{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no transfer quota reset executed") + } + err = executeDataRetentionCheckRuleAction(dataprovider.EventActionDataRetentionConfig{}, conditions, &EventParams{}, "") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no retention check executed") + } + + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestGetFileContent(t *testing.T) { + username := "test_user_get_file_content" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + HomeDir: filepath.Join(os.TempDir(), username), + }, + } + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + fileContent := []byte("test file content") + err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file.txt"), fileContent, 0666) + 
assert.NoError(t, err) + conn := NewBaseConnection(xid.New().String(), protocolEventAction, "", "", user) + replacer := strings.NewReplacer("old", "new") + files, err := getMailAttachments(conn, []string{"/file.txt"}, replacer) + assert.NoError(t, err) + if assert.Len(t, files, 1) { + var b bytes.Buffer + _, err = files[0].Writer(&b) + assert.NoError(t, err) + assert.Equal(t, fileContent, b.Bytes()) + } + // missing file + _, err = getMailAttachments(conn, []string{"/file1.txt"}, replacer) + assert.Error(t, err) + // directory + _, err = getMailAttachments(conn, []string{"/"}, replacer) + assert.Error(t, err) + // files too large + content := make([]byte, maxAttachmentsSize/2+1) + _, err = rand.Read(content) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file1.txt"), content, 0666) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(user.GetHomeDir(), "file2.txt"), content, 0666) + assert.NoError(t, err) + files, err = getMailAttachments(conn, []string{"/file1.txt"}, replacer) + assert.NoError(t, err) + if assert.Len(t, files, 1) { + var b bytes.Buffer + _, err = files[0].Writer(&b) + assert.NoError(t, err) + assert.Equal(t, content, b.Bytes()) + } + _, err = getMailAttachments(conn, []string{"/file1.txt", "/file2.txt"}, replacer) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "size too large") + } + // change the filesystem provider + user.FsConfig.Provider = sdk.CryptedFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("pwd") + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + conn = NewBaseConnection(xid.New().String(), protocolEventAction, "", "", user) + // the file is not encrypted so reading the encryption header will fail + files, err = getMailAttachments(conn, []string{"/file.txt"}, replacer) + assert.NoError(t, err) + if assert.Len(t, files, 1) { + var b bytes.Buffer + _, err = files[0].Writer(&b) + assert.Error(t, err) + } + + err = 
dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestFilesystemActionErrors(t *testing.T) { + err := executeFsRuleAction(dataprovider.EventActionFilesystemConfig{}, dataprovider.ConditionOptions{}, &EventParams{}) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unsupported filesystem action") + } + username := "test_user_for_actions" + testReplacer := strings.NewReplacer("old", "new") + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + HomeDir: filepath.Join(os.TempDir(), username), + }, + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: "127.0.0.1:4022", + Username: username, + }, + Password: kms.NewPlainSecret("pwd"), + }, + }, + } + err = executeEmailRuleAction(dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.net"}, + Subject: "subject", + Body: "body", + Attachments: []string{"/file.txt"}, + }, &EventParams{ + sender: username, + }) + assert.Error(t, err) + conn := NewBaseConnection("", protocolEventAction, "", "", user) + err = executeDeleteFileFsAction(conn, "", nil) + assert.Error(t, err) + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + // check root fs fails + err = executeDeleteFsActionForUser(nil, testReplacer, user) + assert.Error(t, err) + err = executeMkDirsFsActionForUser(nil, testReplacer, user) + assert.Error(t, err) + err = executeRenameFsActionForUser(nil, testReplacer, user) + assert.Error(t, err) + err = executeExistFsActionForUser(nil, testReplacer, user) + assert.Error(t, err) + err = executeCopyFsActionForUser(nil, testReplacer, user) + assert.Error(t, err) + err = executeCompressFsActionForUser(dataprovider.EventActionFsCompress{}, testReplacer, user) + assert.Error(t, err) + _, _, 
_, _, err = getFileWriter(conn, "/path.txt", -1) //nolint:dogsled + assert.Error(t, err) + err = executeEmailRuleAction(dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.net"}, + Subject: "subject", + Body: "body", + Attachments: []string{"/file1.txt"}, + }, &EventParams{ + sender: username, + }) + assert.Error(t, err) + fn := getFileContentFn(NewBaseConnection("", protocolEventAction, "", "", user), "/f.txt", 1234) + var b bytes.Buffer + _, err = fn(&b) + assert.Error(t, err) + err = executeHTTPRuleAction(dataprovider.EventActionHTTPConfig{ + Endpoint: "http://127.0.0.1:9999/", + Method: http.MethodPost, + Parts: []dataprovider.HTTPPart{ + { + Name: "p1", + Filepath: "/filepath", + }, + }, + }, &EventParams{ + sender: username, + }) + assert.Error(t, err) + user.FsConfig.Provider = sdk.LocalFilesystemProvider + user.Permissions["/"] = []string{dataprovider.PermUpload} + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + err = executeRenameFsActionForUser([]dataprovider.RenameConfig{ + { + KeyValue: dataprovider.KeyValue{ + Key: "/p1", + Value: "/p1", + }, + }, + }, testReplacer, user) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "the rename source and target cannot be the same") + } + err = executeRuleAction(dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionRename, + Renames: []dataprovider.RenameConfig{ + { + KeyValue: dataprovider.KeyValue{ + Key: "/p2", + Value: "/p2", + }, + }, + }, + }, + }, + }, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username, + }, + }, + }) + assert.Error(t, err) + + if runtime.GOOS != osWindows { + dirPath := filepath.Join(user.HomeDir, "adir", "sub") + err := 
os.MkdirAll(dirPath, os.ModePerm) + assert.NoError(t, err) + filePath := filepath.Join(dirPath, "f.dat") + err = os.WriteFile(filePath, []byte("test file content"), 0666) + assert.NoError(t, err) + err = os.Chmod(dirPath, 0001) + assert.NoError(t, err) + + err = executeDeleteFsActionForUser([]string{"/adir/sub"}, testReplacer, user) + assert.Error(t, err) + err = executeDeleteFsActionForUser([]string{"/adir/sub/f.dat"}, testReplacer, user) + assert.Error(t, err) + err = os.Chmod(dirPath, 0555) + assert.NoError(t, err) + err = executeDeleteFsActionForUser([]string{"/adir/sub/f.dat"}, testReplacer, user) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to remove file") + } + err = executeRuleAction(dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionDelete, + Deletes: []string{"/adir/sub/f.dat"}, + }, + }, + }, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username, + }, + }, + }) + assert.Error(t, err) + + err = executeMkDirsFsActionForUser([]string{"/adir/sub/sub"}, testReplacer, user) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to create dir") + } + err = executeMkDirsFsActionForUser([]string{"/adir/sub/sub/sub"}, testReplacer, user) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to check parent dirs") + } + + err = executeRuleAction(dataprovider.BaseEventAction{ + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionMkdirs, + MkDirs: []string{"/adir/sub/sub1"}, + }, + }, + }, &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username, + }, + }, + }) + assert.Error(t, err) + + err = os.Chmod(dirPath, 
os.ModePerm) + assert.NoError(t, err) + + conn = NewBaseConnection("", protocolEventAction, "", "", user) + wr := &zipWriterWrapper{ + Name: "test.zip", + Writer: zip.NewWriter(bytes.NewBuffer(nil)), + Entries: map[string]bool{}, + } + err = addZipEntry(wr, conn, "/adir/sub/f.dat", "/adir/sub/sub", nil, 0) + assert.Error(t, err) + assert.Contains(t, getErrorString(err), "is outside base dir") + } + + wr := &zipWriterWrapper{ + Name: xid.New().String() + ".zip", + Writer: zip.NewWriter(bytes.NewBuffer(nil)), + Entries: map[string]bool{}, + } + err = addZipEntry(wr, conn, "/p1", "/", nil, 2000) + assert.ErrorIs(t, err, util.ErrRecursionTooDeep) + + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestQuotaActionsWithQuotaTrackDisabled(t *testing.T) { + oldProviderConf := dataprovider.GetProviderConfig() + providerConf := dataprovider.GetProviderConfig() + providerConf.TrackQuota = 0 + err := dataprovider.Close() + assert.NoError(t, err) + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + username := "u1" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + FsConfig: vfs.Filesystem{ + Provider: sdk.LocalFilesystemProvider, + }, + } + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + + err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeUserQuotaReset}, + &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username, + }, + }, + }) + assert.Error(t, err) + + err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeTransferQuotaReset}, + &EventParams{}, 
dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: username, + }, + }, + }) + assert.Error(t, err) + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + + foldername := "f1" + folder := vfs.BaseVirtualFolder{ + Name: foldername, + MappedPath: filepath.Join(os.TempDir(), foldername), + } + err = dataprovider.AddFolder(&folder, "", "", "") + assert.NoError(t, err) + err = os.MkdirAll(folder.MappedPath, os.ModePerm) + assert.NoError(t, err) + + err = executeRuleAction(dataprovider.BaseEventAction{Type: dataprovider.ActionTypeFolderQuotaReset}, + &EventParams{}, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: foldername, + }, + }, + }) + assert.Error(t, err) + + err = os.RemoveAll(folder.MappedPath) + assert.NoError(t, err) + err = dataprovider.DeleteFolder(foldername, "", "", "") + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = dataprovider.Initialize(oldProviderConf, configDir, true) + assert.NoError(t, err) +} + +func TestScheduledActions(t *testing.T) { + startEventScheduler() + backupsPath := filepath.Join(os.TempDir(), "backups") + err := os.RemoveAll(backupsPath) + assert.NoError(t, err) + now := time.Now().UTC().Format(dateTimeMillisFormat) + // The backup action sets the home directory to the backup path. 
+ expectedDirPath := filepath.Join(backupsPath, fmt.Sprintf("%s_%s_%s", now[0:4], now[5:7], now[8:10])) + + action1 := &dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeBackup, + } + err = dataprovider.AddEventAction(action1, "", "", "") + assert.NoError(t, err) + action2 := &dataprovider.BaseEventAction{ + Name: "action2", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionMkdirs, + MkDirs: []string{"{{.Year}}_{{.Month}}_{{.Day}}"}, + }, + }, + } + err = dataprovider.AddEventAction(action2, "", "", "") + assert.NoError(t, err) + rule := &dataprovider.EventRule{ + Name: "rule", + Status: 1, + Trigger: dataprovider.EventTriggerSchedule, + Conditions: dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + Hours: "11", + DayOfWeek: "*", + DayOfMonth: "*", + Month: "*", + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + }, + } + + job := eventCronJob{ + ruleName: rule.Name, + } + job.Run() // rule not found + assert.NoDirExists(t, backupsPath) + + err = dataprovider.AddEventRule(rule, "", "", "") + assert.NoError(t, err) + + job.Run() + assert.DirExists(t, backupsPath) + assert.DirExists(t, expectedDirPath) + + action1.Type = dataprovider.ActionTypeEmail + action1.Options = dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"example@example.com"}, + Subject: "test with attachments", + Body: "body", + Attachments: []string{"/file1.txt"}, + }, + } + err = dataprovider.UpdateEventAction(action1, "", "", "") + assert.NoError(t, err) + job.Run() // action is not compatible with a scheduled rule + + err = 
dataprovider.DeleteEventRule(rule.Name, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteEventAction(action1.Name, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteEventAction(action2.Name, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(backupsPath) + assert.NoError(t, err) + + stopEventScheduler() +} + +func TestEventParamsCopy(t *testing.T) { + params := EventParams{ + Name: "name", + Event: "event", + Extension: "ext", + Status: 1, + errors: []string{"error1"}, + retentionChecks: []executedRetentionCheck{}, + } + paramsCopy := params.getACopy() + assert.Equal(t, params, *paramsCopy) + params.Name = "name mod" + paramsCopy.Event = "event mod" + paramsCopy.Status = 2 + params.errors = append(params.errors, "error2") + paramsCopy.errors = append(paramsCopy.errors, "error3") + assert.Equal(t, []string{"error1", "error3"}, paramsCopy.errors) + assert.Equal(t, []string{"error1", "error2"}, params.errors) + assert.Equal(t, "name mod", params.Name) + assert.Equal(t, "name", paramsCopy.Name) + assert.Equal(t, "event", params.Event) + assert.Equal(t, "event mod", paramsCopy.Event) + assert.Equal(t, 1, params.Status) + assert.Equal(t, 2, paramsCopy.Status) + params = EventParams{ + retentionChecks: []executedRetentionCheck{ + { + Username: "u", + ActionName: "a", + Results: []folderRetentionCheckResult{ + { + Path: "p", + Retention: 1, + }, + }, + }, + }, + } + paramsCopy = params.getACopy() + require.Len(t, paramsCopy.retentionChecks, 1) + paramsCopy.retentionChecks[0].Username = "u_copy" + paramsCopy.retentionChecks[0].ActionName = "a_copy" + require.Len(t, paramsCopy.retentionChecks[0].Results, 1) + paramsCopy.retentionChecks[0].Results[0].Path = "p_copy" + paramsCopy.retentionChecks[0].Results[0].Retention = 2 + assert.Equal(t, "u", params.retentionChecks[0].Username) + assert.Equal(t, "a", params.retentionChecks[0].ActionName) + assert.Equal(t, "p", params.retentionChecks[0].Results[0].Path) + assert.Equal(t, 1, 
params.retentionChecks[0].Results[0].Retention) + assert.Equal(t, "u_copy", paramsCopy.retentionChecks[0].Username) + assert.Equal(t, "a_copy", paramsCopy.retentionChecks[0].ActionName) + assert.Equal(t, "p_copy", paramsCopy.retentionChecks[0].Results[0].Path) + assert.Equal(t, 2, paramsCopy.retentionChecks[0].Results[0].Retention) + assert.Nil(t, params.IDPCustomFields) + params.addIDPCustomFields(nil) + assert.Nil(t, params.IDPCustomFields) + params.IDPCustomFields = &map[string]string{ + "field1": "val1", + } + paramsCopy = params.getACopy() + for k, v := range *paramsCopy.IDPCustomFields { + assert.Equal(t, "field1", k) + assert.Equal(t, "val1", v) + } + assert.Equal(t, params.IDPCustomFields, paramsCopy.IDPCustomFields) + (*paramsCopy.IDPCustomFields)["field1"] = "val2" + assert.NotEqual(t, params.IDPCustomFields, paramsCopy.IDPCustomFields) + params.Metadata = map[string]string{"key": "value"} + paramsCopy = params.getACopy() + params.Metadata["key1"] = "value1" + require.Equal(t, map[string]string{"key": "value"}, paramsCopy.Metadata) +} + +func TestEventParamsStatusFromError(t *testing.T) { + params := EventParams{Status: 1} + params.AddError(os.ErrNotExist) + assert.Equal(t, 1, params.Status) + + params = EventParams{Status: 1, updateStatusFromError: true} + params.AddError(os.ErrNotExist) + assert.Equal(t, 2, params.Status) +} + +type testWriter struct { + errTest error + sentinel string +} + +func (w *testWriter) Write(p []byte) (int, error) { + if w.errTest != nil { + return 0, w.errTest + } + if w.sentinel == string(p) { + return 0, io.ErrUnexpectedEOF + } + return len(p), nil +} + +func TestWriteHTTPPartsError(t *testing.T) { + m := multipart.NewWriter(&testWriter{ + errTest: io.ErrShortWrite, + }) + + err := writeHTTPPart(m, dataprovider.HTTPPart{}, nil, nil, nil, &EventParams{}, false) + assert.ErrorIs(t, err, io.ErrShortWrite) + + body := "test body" + m = multipart.NewWriter(&testWriter{sentinel: body}) + err = writeHTTPPart(m, 
dataprovider.HTTPPart{ + Body: body, + }, nil, nil, nil, &EventParams{}, false) + assert.ErrorIs(t, err, io.ErrUnexpectedEOF) +} + +func TestReplacePathsPlaceholders(t *testing.T) { + replacer := strings.NewReplacer("{{.VirtualPath}}", "/path1") + paths := []string{"{{.VirtualPath}}", "/path1"} + paths = replacePathsPlaceholders(paths, replacer) + assert.Equal(t, []string{"/path1"}, paths) + paths = []string{"{{.VirtualPath}}", "/path2"} + paths = replacePathsPlaceholders(paths, replacer) + assert.Equal(t, []string{"/path1", "/path2"}, paths) +} + +func TestEstimateZipSizeErrors(t *testing.T) { + u := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "u", + HomeDir: filepath.Join(os.TempDir(), "u"), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + QuotaSize: 1000, + }, + } + err := dataprovider.AddUser(&u, "", "", "") + assert.NoError(t, err) + err = os.MkdirAll(u.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + conn := NewBaseConnection("", ProtocolFTP, "", "", u) + _, _, _, _, err = getFileWriter(conn, "/missing/path/file.txt", -1) //nolint:dogsled + assert.Error(t, err) + _, err = getSizeForPath(conn, "/missing", vfs.NewFileInfo("missing", true, 0, time.Now(), false)) + assert.True(t, conn.IsNotExistError(err)) + if runtime.GOOS != osWindows { + err = os.MkdirAll(filepath.Join(u.HomeDir, "d1", "d2", "sub"), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(u.HomeDir, "d1", "d2", "sub", "file.txt"), []byte("data"), 0666) + assert.NoError(t, err) + err = os.Chmod(filepath.Join(u.HomeDir, "d1", "d2"), 0001) + assert.NoError(t, err) + size, err := estimateZipSize(conn, "/archive.zip", []string{"/d1"}) + assert.Error(t, err, "size %d", size) + err = os.Chmod(filepath.Join(u.HomeDir, "d1", "d2"), os.ModePerm) + assert.NoError(t, err) + } + err = dataprovider.DeleteUser(u.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(u.GetHomeDir()) + assert.NoError(t, err) +} + +func 
TestOnDemandRule(t *testing.T) { + a := &dataprovider.BaseEventAction{ + Name: "a", + Type: dataprovider.ActionTypeBackup, + Options: dataprovider.BaseEventActionOptions{}, + } + err := dataprovider.AddEventAction(a, "", "", "") + assert.NoError(t, err) + r := &dataprovider.EventRule{ + Name: "test on demand rule", + Status: 1, + Trigger: dataprovider.EventTriggerOnDemand, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: a.Name, + }, + }, + }, + } + err = dataprovider.AddEventRule(r, "", "", "") + assert.NoError(t, err) + + err = RunOnDemandRule(r.Name) + assert.NoError(t, err) + + r.Status = 0 + err = dataprovider.UpdateEventRule(r, "", "", "") + assert.NoError(t, err) + err = RunOnDemandRule(r.Name) + assert.ErrorIs(t, err, util.ErrValidation) + assert.Contains(t, err.Error(), "is inactive") + + r.Status = 1 + r.Trigger = dataprovider.EventTriggerCertificate + err = dataprovider.UpdateEventRule(r, "", "", "") + assert.NoError(t, err) + err = RunOnDemandRule(r.Name) + assert.ErrorIs(t, err, util.ErrValidation) + assert.Contains(t, err.Error(), "is not defined as on-demand") + + a1 := &dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"example@example.org"}, + Subject: "subject", + Body: "body", + Attachments: []string{"/{{.VirtualPath}}"}, + }, + }, + } + err = dataprovider.AddEventAction(a1, "", "", "") + assert.NoError(t, err) + + r.Trigger = dataprovider.EventTriggerOnDemand + r.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: a1.Name, + }, + }, + } + err = dataprovider.UpdateEventRule(r, "", "", "") + assert.NoError(t, err) + err = RunOnDemandRule(r.Name) + assert.ErrorIs(t, err, util.ErrValidation) + assert.Contains(t, err.Error(), "incosistent actions") + + err = dataprovider.DeleteEventRule(r.Name, "", "", 
"") + assert.NoError(t, err) + err = dataprovider.DeleteEventAction(a.Name, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteEventAction(a1.Name, "", "", "") + assert.NoError(t, err) + + err = RunOnDemandRule(r.Name) + assert.ErrorIs(t, err, util.ErrNotFound) +} + +func getErrorString(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func TestHTTPEndpointWithPlaceholders(t *testing.T) { + c := dataprovider.EventActionHTTPConfig{ + Endpoint: "http://127.0.0.1:8080/base/url/{{.Name}}/{{.VirtualPath}}/upload", + QueryParameters: []dataprovider.KeyValue{ + { + Key: "u", + Value: "{{.Name}}", + }, + { + Key: "p", + Value: "{{.VirtualPath}}", + }, + }, + } + name := "uname" + vPath := "/a dir/@ file.txt" + replacer := strings.NewReplacer("{{.Name}}", name, "{{.VirtualPath}}", vPath) + u, err := getHTTPRuleActionEndpoint(&c, replacer) + assert.NoError(t, err) + expected := "http://127.0.0.1:8080/base/url/" + url.PathEscape(name) + "/" + url.PathEscape(vPath) + + "/upload?" + "p=" + url.QueryEscape(vPath) + "&u=" + url.QueryEscape(name) + assert.Equal(t, expected, u) + + c.Endpoint = "http://127.0.0.1/upload" + u, err = getHTTPRuleActionEndpoint(&c, replacer) + assert.NoError(t, err) + expected = c.Endpoint + "?p=" + url.QueryEscape(vPath) + "&u=" + url.QueryEscape(name) + assert.Equal(t, expected, u) +} + +func TestMetadataReplacement(t *testing.T) { + params := &EventParams{ + Metadata: map[string]string{ + "key": "value", + }, + } + replacements := params.getStringReplacements(false, 0) + replacer := strings.NewReplacer(replacements...) 
+ reader, _, err := getHTTPRuleActionBody(&dataprovider.EventActionHTTPConfig{Body: "{{.Metadata}} {{.MetadataString}}"}, replacer, nil, dataprovider.User{}, params, false) + require.NoError(t, err) + data, err := io.ReadAll(reader) + require.NoError(t, err) + assert.Equal(t, `{"key":"value"} {\"key\":\"value\"}`, string(data)) +} + +func TestUserInactivityCheck(t *testing.T) { + username1 := "user1" + username2 := "user2" + user1 := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username1, + HomeDir: filepath.Join(os.TempDir(), username1), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + user2 := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username2, + HomeDir: filepath.Join(os.TempDir(), username2), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + days := user1.InactivityDays(time.Now().Add(10*24*time.Hour + 5*time.Second)) + assert.Equal(t, 0, days) + + user2.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) + err := executeInactivityCheckForUser(&user2, dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + }, time.Now().Add(12*24*time.Hour)) + assert.Error(t, err) + user2.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) + err = executeInactivityCheckForUser(&user2, dataprovider.EventActionUserInactivity{ + DeleteThreshold: 10, + }, time.Now().Add(12*24*time.Hour)) + assert.Error(t, err) + + err = dataprovider.AddUser(&user1, "", "", "") + assert.NoError(t, err) + err = dataprovider.AddUser(&user2, "", "", "") + assert.NoError(t, err) + user1, err = dataprovider.UserExists(username1, "") + assert.NoError(t, err) + assert.Equal(t, 1, user1.Status) + days = user1.InactivityDays(time.Now().Add(10*24*time.Hour + 5*time.Second)) + assert.Equal(t, 10, days) + days = user1.InactivityDays(time.Now().Add(-10*24*time.Hour + 5*time.Second)) + assert.Equal(t, -9, days) + + err = 
executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + }, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "not matching", + }, + }, + }, &EventParams{}, time.Now().Add(12*24*time.Hour)) + assert.NoError(t, err) + + err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + }, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user1.Username, + }, + }, + }, &EventParams{}, time.Now()) + assert.NoError(t, err) // no action + + err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + }, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user1.Username, + }, + }, + }, &EventParams{}, time.Now().Add(-12*24*time.Hour)) + assert.NoError(t, err) // no action + + err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + DeleteThreshold: 20, + }, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user1.Username, + }, + }, + }, &EventParams{}, time.Now().Add(30*24*time.Hour)) + // both thresholds exceeded, the user will be disabled + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "executed inactivity check actions for users") + } + user1, err = dataprovider.UserExists(username1, "") + assert.NoError(t, err) + assert.Equal(t, 0, user1.Status) + + err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + }, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user1.Username, + }, + }, + }, &EventParams{}, time.Now().Add(30*24*time.Hour)) + assert.NoError(t, err) // already disabled, no action + + err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + DeleteThreshold: 20, + }, 
dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user1.Username, + }, + }, + }, &EventParams{}, time.Now().Add(-30*24*time.Hour)) + assert.NoError(t, err) + err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + DeleteThreshold: 20, + }, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user1.Username, + }, + }, + }, &EventParams{}, time.Now()) + assert.NoError(t, err) + user1, err = dataprovider.UserExists(username1, "") + assert.NoError(t, err) + assert.Equal(t, 0, user1.Status) + + err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + DeleteThreshold: 20, + }, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user1.Username, + }, + }, + }, &EventParams{}, time.Now().Add(30*24*time.Hour)) // the user is disabled, will be now deleted + assert.Error(t, err) + _, err = dataprovider.UserExists(username1, "") + assert.ErrorIs(t, err, util.ErrNotFound) + + err = executeUserInactivityCheckRuleAction(dataprovider.EventActionUserInactivity{ + DeleteThreshold: 20, + }, dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user2.Username, + }, + }, + }, &EventParams{}, time.Now().Add(30*24*time.Hour)) // no disable threshold, user deleted + assert.Error(t, err) + _, err = dataprovider.UserExists(username2, "") + assert.ErrorIs(t, err, util.ErrNotFound) + + err = dataprovider.DeleteUser(username1, "", "", "") + assert.Error(t, err) + err = dataprovider.DeleteUser(username2, "", "", "") + assert.Error(t, err) +} diff --git a/internal/common/eventscheduler.go b/internal/common/eventscheduler.go new file mode 100644 index 00000000..762880f7 --- /dev/null +++ b/internal/common/eventscheduler.go @@ -0,0 +1,54 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify 
+// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "time" + + "github.com/robfig/cron/v3" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + eventScheduler *cron.Cron +) + +func stopEventScheduler() { + if eventScheduler != nil { + eventScheduler.Stop() + eventScheduler = nil + } +} + +func startEventScheduler() { + stopEventScheduler() + + options := []cron.Option{ + cron.WithLogger(cron.DiscardLogger), + } + if !dataprovider.UseLocalTime() { + eventManagerLog(logger.LevelDebug, "use UTC time for the scheduler") + options = append(options, cron.WithLocation(time.UTC)) + } + + eventScheduler = cron.New(options...) + eventManager.loadRules() + _, err := eventScheduler.AddFunc("@every 10m", eventManager.loadRules) + util.PanicOnError(err) + eventScheduler.Start() +} diff --git a/internal/common/httpauth.go b/internal/common/httpauth.go new file mode 100644 index 00000000..10aa1782 --- /dev/null +++ b/internal/common/httpauth.go @@ -0,0 +1,148 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "encoding/csv" + "os" + "strings" + "sync" + + "github.com/GehirnInc/crypt/apr1_crypt" + "github.com/GehirnInc/crypt/md5_crypt" + "golang.org/x/crypto/bcrypt" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + // HTTPAuthenticationHeader defines the HTTP authentication + HTTPAuthenticationHeader = "WWW-Authenticate" + md5CryptPwdPrefix = "$1$" + apr1CryptPwdPrefix = "$apr1$" +) + +var ( + bcryptPwdPrefixes = []string{"$2a$", "$2$", "$2x$", "$2y$", "$2b$"} +) + +// HTTPAuthProvider defines the interface for HTTP auth providers +type HTTPAuthProvider interface { + ValidateCredentials(username, password string) bool + IsEnabled() bool +} + +type basicAuthProvider struct { + Path string + sync.RWMutex + Info os.FileInfo + Users map[string]string +} + +// NewBasicAuthProvider returns an HTTPAuthProvider implementing Basic Auth +func NewBasicAuthProvider(authUserFile string) (HTTPAuthProvider, error) { + basicAuthProvider := basicAuthProvider{ + Path: authUserFile, + Info: nil, + Users: make(map[string]string), + } + return &basicAuthProvider, basicAuthProvider.loadUsers() +} + +func (p *basicAuthProvider) IsEnabled() bool { + return p.Path != "" +} + +func (p *basicAuthProvider) isReloadNeeded(info os.FileInfo) bool { + p.RLock() + defer p.RUnlock() + + return p.Info == nil || p.Info.ModTime() != info.ModTime() || p.Info.Size() != info.Size() +} + +func (p *basicAuthProvider) loadUsers() error { + if !p.IsEnabled() { + return nil + } + info, err := os.Stat(p.Path) + if err != nil { + logger.Debug(logSender, "", "unable to stat basic auth users file: %v", err) + return err + } + if p.isReloadNeeded(info) { + r, err := os.Open(p.Path) + if err != nil { + logger.Debug(logSender, "", "unable to 
open basic auth users file: %v", err) + return err + } + defer r.Close() + reader := csv.NewReader(r) + reader.Comma = ':' + reader.Comment = '#' + reader.TrimLeadingSpace = true + records, err := reader.ReadAll() + if err != nil { + logger.Debug(logSender, "", "unable to parse basic auth users file: %v", err) + return err + } + p.Lock() + defer p.Unlock() + + p.Users = make(map[string]string) + for _, record := range records { + if len(record) == 2 { + p.Users[record[0]] = record[1] + } + } + logger.Debug(logSender, "", "number of users loaded for httpd basic auth: %v", len(p.Users)) + p.Info = info + } + return nil +} + +func (p *basicAuthProvider) getHashedPassword(username string) (string, bool) { + err := p.loadUsers() + if err != nil { + return "", false + } + p.RLock() + defer p.RUnlock() + + pwd, ok := p.Users[username] + return pwd, ok +} + +// ValidateCredentials returns true if the credentials are valid +func (p *basicAuthProvider) ValidateCredentials(username, password string) bool { + if hashedPwd, ok := p.getHashedPassword(username); ok { + if util.IsStringPrefixInSlice(hashedPwd, bcryptPwdPrefixes) { + err := bcrypt.CompareHashAndPassword([]byte(hashedPwd), []byte(password)) + return err == nil + } + if strings.HasPrefix(hashedPwd, md5CryptPwdPrefix) { + crypter := md5_crypt.New() + err := crypter.Verify(hashedPwd, []byte(password)) + return err == nil + } + if strings.HasPrefix(hashedPwd, apr1CryptPwdPrefix) { + crypter := apr1_crypt.New() + err := crypter.Verify(hashedPwd, []byte(password)) + return err == nil + } + } + + return false +} diff --git a/internal/common/httpauth_test.go b/internal/common/httpauth_test.go new file mode 100644 index 00000000..5f6f4716 --- /dev/null +++ b/internal/common/httpauth_test.go @@ -0,0 +1,85 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software 
Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBasicAuth(t *testing.T) { + httpAuth, err := NewBasicAuthProvider("") + require.NoError(t, err) + require.False(t, httpAuth.IsEnabled()) + + _, err = NewBasicAuthProvider("missing path") + require.Error(t, err) + + authUserFile := filepath.Join(os.TempDir(), "http_users.txt") + authUserData := []byte("test1:$2y$05$bcHSED7aO1cfLto6ZdDBOOKzlwftslVhtpIkRhAtSa4GuLmk5mola\n") + err = os.WriteFile(authUserFile, authUserData, os.ModePerm) + require.NoError(t, err) + + httpAuth, err = NewBasicAuthProvider(authUserFile) + require.NoError(t, err) + require.True(t, httpAuth.IsEnabled()) + require.False(t, httpAuth.ValidateCredentials("test1", "wrong1")) + require.False(t, httpAuth.ValidateCredentials("test2", "password2")) + require.True(t, httpAuth.ValidateCredentials("test1", "password1")) + + authUserData = append(authUserData, []byte("test2:$1$OtSSTL8b$bmaCqEksI1e7rnZSjsIDR1\n")...) + err = os.WriteFile(authUserFile, authUserData, os.ModePerm) + require.NoError(t, err) + require.False(t, httpAuth.ValidateCredentials("test2", "wrong2")) + require.True(t, httpAuth.ValidateCredentials("test2", "password2")) + + authUserData = append(authUserData, []byte("test2:$apr1$gLnIkRIf$Xr/6aJfmIrihP4b2N2tcs/\n")...) 
+ err = os.WriteFile(authUserFile, authUserData, os.ModePerm) + require.NoError(t, err) + require.False(t, httpAuth.ValidateCredentials("test2", "wrong2")) + require.True(t, httpAuth.ValidateCredentials("test2", "password2")) + + authUserData = append(authUserData, []byte("test3:$apr1$gLnIkRIf$Xr/6aJfmIrihP4b2N2tcs/\n")...) + err = os.WriteFile(authUserFile, authUserData, os.ModePerm) + require.NoError(t, err) + require.False(t, httpAuth.ValidateCredentials("test3", "password3")) + + authUserData = append(authUserData, []byte("test4:$invalid$gLnIkRIf$Xr/6$aJfmIr$ihP4b2N2tcs/\n")...) + err = os.WriteFile(authUserFile, authUserData, os.ModePerm) + require.NoError(t, err) + require.False(t, httpAuth.ValidateCredentials("test4", "password3")) + + if runtime.GOOS != "windows" { + authUserData = append(authUserData, []byte("test5:$apr1$gLnIkRIf$Xr/6aJfmIrihP4b2N2tcs/\n")...) + err = os.WriteFile(authUserFile, authUserData, os.ModePerm) + require.NoError(t, err) + err = os.Chmod(authUserFile, 0001) + require.NoError(t, err) + require.False(t, httpAuth.ValidateCredentials("test5", "password2")) + err = os.Chmod(authUserFile, os.ModePerm) + require.NoError(t, err) + } + authUserData = append(authUserData, []byte("\"foo\"bar\"\r\n")...) + err = os.WriteFile(authUserFile, authUserData, os.ModePerm) + require.NoError(t, err) + require.False(t, httpAuth.ValidateCredentials("test2", "password2")) + + err = os.Remove(authUserFile) + require.NoError(t, err) +} diff --git a/internal/common/protocol_test.go b/internal/common/protocol_test.go new file mode 100644 index 00000000..8147e246 --- /dev/null +++ b/internal/common/protocol_test.go @@ -0,0 +1,10027 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common_test + +import ( + "bufio" + "bytes" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "math" + "net" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "runtime" + "slices" + "strings" + "sync" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/jackc/pgx/v5/stdlib" + _ "github.com/mattn/go-sqlite3" + "github.com/mhale/smtpd" + "github.com/minio/sio" + "github.com/pkg/sftp" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + "github.com/rs/xid" + "github.com/rs/zerolog" + "github.com/sftpgo/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/studio-b12/gowebdav" + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/httpdtest" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +const ( + httpAddr = "127.0.0.1:9999" + httpProxyAddr = "127.0.0.1:7777" + sftpServerAddr = "127.0.0.1:4022" + smtpServerAddr = "127.0.0.1:2525" + webDavServerPort = 9191 + httpFsPort = 34567 + defaultUsername = "test_common_sftp" + 
defaultPassword = "test_password" + defaultSFTPUsername = "test_common_sftpfs_user" + defaultHTTPFsUsername = "httpfs_ftp_user" + httpFsWellKnowDir = "/wellknow" + osWindows = "windows" + testFileName = "test_file_common_sftp.dat" + testDir = "test_dir_common" + testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" + testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn +NhAAAAAwEAAQAAAYEAtN449A/nY5O6cSH/9Doa8a3ISU0WZJaHydTaCLuO+dkqtNpnV5mq +zFbKidXAI1eSwVctw9ReVOl1uK6aZF3lbXdOD8W9PXobR9KUUT2qBx5QC4ibfAqDKWymDA +PG9ylzz64hsYBqJr7VNk9kTFEUsDmWzLabLoH42Elnp8mF/lTkWIcpVp0ly/etS08gttXo +XenekJ1vRuxOYWDCEzGPU7kGc920TmM14k7IDdPoOh5+3sRUKedKeOUrVDH1f0n7QjHQsZ +cbshp8tgqzf734zu8cTqNrr+6taptdEOOij1iUL/qYGfzny/hA48tO5+UFUih5W8ftp0+E +NBIDkkGgk2MJ92I7QAXyMVsIABXco+mJT7pQi9tqlODGIQ3AOj0gcA3X/Ib8QX77Ih3TPi +XEh77/P1XiYZOgpp2cRmNH8QbqaL9u898hDvJwIPJPuj2lIltTElH7hjBf5LQfCzrLV7BD +10rM7sl4jr+A2q8jl1Ikp+25kainBBZSbrDummT9AAAFgDU/VLk1P1S5AAAAB3NzaC1yc2 +EAAAGBALTeOPQP52OTunEh//Q6GvGtyElNFmSWh8nU2gi7jvnZKrTaZ1eZqsxWyonVwCNX +ksFXLcPUXlTpdbiummRd5W13Tg/FvT16G0fSlFE9qgceUAuIm3wKgylspgwDxvcpc8+uIb +GAaia+1TZPZExRFLA5lsy2my6B+NhJZ6fJhf5U5FiHKVadJcv3rUtPILbV6F3p3pCdb0bs +TmFgwhMxj1O5BnPdtE5jNeJOyA3T6Doeft7EVCnnSnjlK1Qx9X9J+0Ix0LGXG7IafLYKs3 ++9+M7vHE6ja6/urWqbXRDjoo9YlC/6mBn858v4QOPLTuflBVIoeVvH7adPhDQSA5JBoJNj +CfdiO0AF8jFbCAAV3KPpiU+6UIvbapTgxiENwDo9IHAN1/yG/EF++yId0z4lxIe+/z9V4m 
+GToKadnEZjR/EG6mi/bvPfIQ7ycCDyT7o9pSJbUxJR+4YwX+S0Hws6y1ewQ9dKzO7JeI6/ +gNqvI5dSJKftuZGopwQWUm6w7ppk/QAAAAMBAAEAAAGAHKnC+Nq0XtGAkIFE4N18e6SAwy +0WSWaZqmCzFQM0S2AhJnweOIG/0ZZHjsRzKKauOTmppQk40dgVsejpytIek9R+aH172gxJ +2n4Cx0UwduRU5x8FFQlNc/kl722B0JWfJuB/snOZXv6LJ4o5aObIkozt2w9tVFeAqjYn2S +1UsNOfRHBXGsTYwpRDwFWP56nKo2d2wBBTHDhCy6fb2dLW1fvSi/YspueOGIlHpvlYKi2/ +CWqvs9xVrwcScMtiDoQYq0khhO0efLCxvg/o+W9CLMVM2ms4G1zoSUQKN0oYWWQJyW4+VI +YneWO8UpN0J3ElXKi7bhgAat7dBaM1g9IrAzk153DiEFZNsPxGOgL/+YdQN7zUBx/z7EkI +jyv80RV7fpUXvcq2p+qNl6UVig3VSzRrnsaJkUWu/A0u59ha7ocv6NxDIXjxpIDJme16GF +quiGVBQNnYJymS/vFEbGf6bgf7iRmMCRUMG4nqLA6fPYP9uAtch+CmDfVLZC/fIdC5AAAA +wQCDissV4zH6bfqgxJSuYNk8Vbb+19cF3b7gH1rVlB3zxpCAgcRgMHC+dP1z2NRx7UW9MR +nye6kjpkzZZ0OigLqo7TtEq8uTglD9o6W7mRXqhy5A/ySOmqPL3ernHHQhGuoNODYAHkOU +u2Rh8HXi+VLwKZcLInPOYJvcuLG4DxN8WfeVvlMHwhAOaTNNOtL4XZDHQeIPc4qHmJymmv +sV7GuyQ6yW5C10uoGdxRPd90Bh4z4h2bKfZFjvEBbSBVkqrlAAAADBAN/zNtNayd/dX7Cr +Nb4sZuzCh+CW4BH8GOePZWNCATwBbNXBVb5cR+dmuTqYm+Ekz0VxVQRA1TvKncluJOQpoa +Xj8r0xdIgqkehnfDPMKtYVor06B9Fl1jrXtXU0Vrr6QcBWruSVyK1ZxqcmcNK/+KolVepe +A6vcl/iKaG4U7su166nxLST06M2EgcSVsFJHpKn5+WAXC+X0Gx8kNjWIIb3GpiChdc0xZD +mq02xZthVJrTCVw/e7gfDoB2QRsNV8HwAAAMEAzsCghZVp+0YsYg9oOrw4tEqcbEXEMhwY +0jW8JNL8Spr1Ibp5Dw6bRSk5azARjmJtnMJhJ3oeHfF0eoISqcNuQXGndGQbVM9YzzAzc1 +NbbCNsVroqKlChT5wyPNGS+phi2bPARBno7WSDvshTZ7dAVEP2c9MJW0XwoSevwKlhgSdt +RLFFQ/5nclJSdzPBOmQouC0OBcMFSrYtMeknJ4VvueVvve5HcHFaEsaMc7ABAGaLYaBQOm +iixITGvaNZh/tjAAAACW5pY29sYUBwMQE= +-----END OPENSSH PRIVATE KEY-----` +) + +var ( + configDir = filepath.Join(".", "..", "..") + allPerms = []string{dataprovider.PermAny} + homeBasePath string + logFilePath string + backupsPath string + testFileContent = []byte("test data") + lastReceivedEmail receivedEmail +) + +func TestMain(m *testing.M) { + homeBasePath = os.TempDir() + logFilePath = filepath.Join(configDir, "common_test.log") + backupsPath = filepath.Join(os.TempDir(), "backups") + logger.InitLogger(logFilePath, 5, 1, 28, false, false, 
zerolog.DebugLevel) + + os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") + os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1") + os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin") + os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password") + err := config.LoadConfig(configDir, "") + if err != nil { + logger.ErrorToConsole("error loading configuration: %v", err) + os.Exit(1) + } + providerConf := config.GetProviderConf() + providerConf.BackupsPath = backupsPath + logger.InfoToConsole("Starting COMMON tests, provider: %v", providerConf.Driver) + + err = dataprovider.Initialize(providerConf, configDir, true) + if err != nil { + logger.ErrorToConsole("error initializing data provider: %v", err) + os.Exit(1) + } + + err = common.Initialize(config.GetCommonConfig(), 0) + if err != nil { + logger.WarnToConsole("error initializing common: %v", err) + os.Exit(1) + } + + httpConfig := config.GetHTTPConfig() + httpConfig.Timeout = 5 + httpConfig.RetryMax = 0 + httpConfig.Initialize(configDir) //nolint:errcheck + kmsConfig := config.GetKMSConfig() + err = kmsConfig.Initialize() + if err != nil { + logger.ErrorToConsole("error initializing kms: %v", err) + os.Exit(1) + } + mfaConfig := config.GetMFAConfig() + err = mfaConfig.Initialize() + if err != nil { + logger.ErrorToConsole("error initializing MFA: %v", err) + os.Exit(1) + } + + sftpdConf := config.GetSFTPDConfig() + sftpdConf.Bindings[0].Port = 4022 + sftpdConf.EnabledSSHCommands = []string{"*"} + sftpdConf.Bindings = append(sftpdConf.Bindings, sftpd.Binding{ + Port: 4024, + }) + sftpdConf.KeyboardInteractiveAuthentication = true + + httpdConf := config.GetHTTPDConfig() + httpdConf.Bindings[0].Port = 4080 + httpdtest.SetBaseURL("http://127.0.0.1:4080") + + webDavConf := config.GetWebDAVDConfig() + webDavConf.Bindings = []webdavd.Binding{ + { + Port: webDavServerPort, + }, + } + + go func() { + if err := sftpdConf.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start SFTP server: %v", 
err) + os.Exit(1) + } + }() + + go func() { + if err := httpdConf.Initialize(configDir, 0); err != nil { + logger.ErrorToConsole("could not start HTTP server: %v", err) + os.Exit(1) + } + }() + + go func() { + if err := webDavConf.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start WebDAV server: %v", err) + os.Exit(1) + } + }() + + waitTCPListening(sftpdConf.Bindings[0].GetAddress()) + waitTCPListening(httpdConf.Bindings[0].GetAddress()) + waitTCPListening(webDavConf.Bindings[0].GetAddress()) + startHTTPFs() + + go func() { + // start a test HTTP server to receive action notifications + http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprintf(w, "OK\n") + }) + http.HandleFunc("/404", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, "Not found\n") + }) + http.HandleFunc("/multipart", func(w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(1048576) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "KO\n") + return + } + defer r.MultipartForm.RemoveAll() //nolint:errcheck + fmt.Fprintf(w, "OK\n") + }) + if err := http.ListenAndServe(httpAddr, nil); err != nil { + logger.ErrorToConsole("could not start HTTP notification server: %v", err) + os.Exit(1) + } + }() + + go func() { + common.Config.ProxyProtocol = 2 + listener, err := net.Listen("tcp", httpProxyAddr) + if err != nil { + logger.ErrorToConsole("error creating listener for proxy protocol server: %v", err) + os.Exit(1) + } + proxyListener, err := common.Config.GetProxyListener(listener) + if err != nil { + logger.ErrorToConsole("error creating proxy protocol listener: %v", err) + os.Exit(1) + } + common.Config.ProxyProtocol = 0 + + s := &http.Server{} + if err := s.Serve(proxyListener); err != nil { + logger.ErrorToConsole("could not start HTTP proxy protocol server: %v", err) + os.Exit(1) + } + }() + + go func() { + if err := 
smtpd.ListenAndServe(smtpServerAddr, func(_ net.Addr, from string, to []string, data []byte) error { + lastReceivedEmail.set(from, to, data) + return nil + }, "SFTPGo test", "localhost"); err != nil { + logger.ErrorToConsole("could not start SMTP server: %v", err) + os.Exit(1) + } + }() + + waitTCPListening(httpAddr) + waitTCPListening(httpProxyAddr) + waitTCPListening(smtpServerAddr) + + exitCode := m.Run() + os.Remove(logFilePath) + os.RemoveAll(backupsPath) + os.Exit(exitCode) +} + +func TestBaseConnection(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + _, err = client.ReadDir(testDir) + assert.ErrorIs(t, err, os.ErrNotExist) + err = client.RemoveDirectory(testDir) + assert.ErrorIs(t, err, os.ErrNotExist) + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = client.Mkdir(testDir) + assert.Error(t, err) + info, err := client.Stat(testDir) + if assert.NoError(t, err) { + assert.True(t, info.IsDir()) + } + err = client.Rename(testDir, testDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "the rename source and target cannot be the same") + } + err = client.Rename(testDir, path.Join(testDir, "sub")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.RemoveDirectory(testDir) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.ErrorIs(t, err, os.ErrNotExist) + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + linkName := testFileName + ".link" //nolint:goconst + err = client.Rename(testFileName, testFileName) + if assert.Error(t, err) { + assert.Contains(t, 
err.Error(), "the rename source and target cannot be the same") + } + err = client.Symlink(testFileName, linkName) + assert.NoError(t, err) + err = client.Symlink(testFileName, testFileName) + assert.Error(t, err) + info, err = client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, int64(len(testFileContent)), info.Size()) + assert.False(t, info.IsDir()) + } + info, err = client.Lstat(linkName) + if assert.NoError(t, err) { + assert.NotEqual(t, int64(7), info.Size()) + assert.True(t, info.Mode()&os.ModeSymlink != 0) + assert.False(t, info.IsDir()) + } + err = client.RemoveDirectory(linkName) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + } + err = client.Remove(testFileName) + assert.NoError(t, err) + err = client.Remove(linkName) + assert.NoError(t, err) + err = client.Rename(testFileName, "test") + assert.ErrorIs(t, err, os.ErrNotExist) + f, err = client.Create(testFileName) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + err = client.Rename(testFileName, testFileName+"1") + assert.NoError(t, err) + err = client.Remove(testFileName + "1") + assert.NoError(t, err) + err = client.RemoveDirectory("missing") + assert.Error(t, err) + } else { + printLatestLogs(10) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestRemoveAll(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + webDavClient := getWebDavClient(user) + err = webDavClient.RemoveAll("/") + if assert.Error(t, err) { + assert.True(t, gowebdav.IsErrCode(err, http.StatusForbidden)) + } + + testDir := "baseDir" + err = webDavClient.RemoveAll(testDir) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = 
writeSFTPFile(path.Join(testDir, testFileName), 1234, client) + assert.NoError(t, err) + + err = webDavClient.RemoveAll(path.Join(testDir, testFileName)) + assert.NoError(t, err) + _, err = client.Stat(path.Join(testDir, testFileName)) + assert.Error(t, err) + + err = writeSFTPFile(path.Join(testDir, testFileName), 1234, client) + assert.NoError(t, err) + err = webDavClient.RemoveAll(testDir) + assert.NoError(t, err) + _, err = client.Stat(testDir) + assert.Error(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestRelativeSymlinks(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + linkName := testFileName + "_link" //nolint:goconst + err = client.Symlink("non-existent-file", linkName) + assert.NoError(t, err) + err = client.Remove(linkName) + assert.NoError(t, err) + testDir := "sub" + err = client.Mkdir(testDir) + assert.NoError(t, err) + f, err := client.Create(path.Join(testDir, testFileName)) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + err = client.Symlink(path.Join(testDir, testFileName), linkName) + assert.NoError(t, err) + _, err = client.Stat(linkName) + assert.NoError(t, err) + p, err := client.ReadLink(linkName) + assert.NoError(t, err) + assert.Equal(t, path.Join("/", testDir, testFileName), p) + err = client.Remove(linkName) + assert.NoError(t, err) + + err = client.Symlink(testFileName, path.Join(testDir, linkName)) + assert.NoError(t, err) + _, err = client.Stat(path.Join(testDir, linkName)) + assert.NoError(t, err) + p, err = client.ReadLink(path.Join(testDir, linkName)) + assert.NoError(t, err) 
+ assert.Equal(t, path.Join("/", testDir, testFileName), p) + + f, err = client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + + err = client.Symlink(testFileName, linkName) + assert.NoError(t, err) + _, err = client.Stat(linkName) + assert.NoError(t, err) + p, err = client.ReadLink(linkName) + assert.NoError(t, err) + assert.Equal(t, path.Join("/", testFileName), p) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestCheckFsAfterUpdate(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + // remove the home dir, it will not be re-created + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.Error(t, err) + } else { + printLatestLogs(10) + } + // update the user and login again, this time the home dir will be created + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginAccessTime(t *testing.T) { + u := getTestUser() + u.Filters.AccessTime = []sdk.TimePeriod{ + { + DayOfWeek: int(time.Now().Add(-25 * time.Hour).UTC().Weekday()), + From: "00:00", + To: "23:59", + }, + } + user, _, 
err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + _, _, err = getSftpClient(user) + assert.Error(t, err) + + user.Filters.AccessTime = []sdk.TimePeriod{ + { + DayOfWeek: int(time.Now().UTC().Weekday()), + From: "00:00", + To: "23:59", + }, + } + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err := checkBasicSFTP(client) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSetStat(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + acmodTime := time.Now().Add(36 * time.Hour) + err = client.Chtimes(testFileName, acmodTime, acmodTime) + assert.NoError(t, err) + newFi, err := client.Lstat(testFileName) + assert.NoError(t, err) + diff := math.Abs(newFi.ModTime().Sub(acmodTime).Seconds()) + assert.LessOrEqual(t, diff, float64(1)) + if runtime.GOOS != osWindows { + err = client.Chown(testFileName, os.Getuid(), os.Getgid()) + assert.NoError(t, err) + } + newPerm := os.FileMode(0666) + err = client.Chmod(testFileName, newPerm) + assert.NoError(t, err) + newFi, err = client.Lstat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, newPerm, newFi.Mode().Perm()) + } + err = client.Truncate(testFileName, 2) + assert.NoError(t, err) + info, err := client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, int64(2), 
info.Size()) + } + err = client.Remove(testFileName) + assert.NoError(t, err) + + err = client.Truncate(testFileName, 0) + assert.ErrorIs(t, err, os.ErrNotExist) + err = client.Chtimes(testFileName, acmodTime, acmodTime) + assert.ErrorIs(t, err, os.ErrNotExist) + if runtime.GOOS != osWindows { + err = client.Chown(testFileName, os.Getuid(), os.Getgid()) + assert.ErrorIs(t, err, os.ErrNotExist) + } + err = client.Chmod(testFileName, newPerm) + assert.ErrorIs(t, err, os.ErrNotExist) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestCryptFsUserUploadErrorOverwrite(t *testing.T) { + u := getCryptFsUser() + u.QuotaSize = 6000 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + var buf []byte + for i := 0; i < 4000; i++ { + buf = append(buf, []byte("a")...) + } + bufSize := int64(len(buf)) + reader := bytes.NewReader(buf) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + f, err := client.Create(testFileName + "_big") + assert.NoError(t, err) + n, err := io.Copy(f, reader) + assert.NoError(t, err) + assert.Equal(t, bufSize, n) + err = f.Close() + assert.NoError(t, err) + encryptedSize, err := getEncryptedFileSize(bufSize) + assert.NoError(t, err) + expectedSize := encryptedSize + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, expectedSize, user.UsedQuotaSize) + // now write a small file + f, err = client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + encryptedSize, err = getEncryptedFileSize(int64(len(testFileContent))) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + 
assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, expectedSize+encryptedSize, user.UsedQuotaSize) + // try to overwrite this file with a big one, this cause an overquota error + // the partial file is deleted and the quota updated + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + f, err = client.Create(testFileName) + assert.NoError(t, err) + _, err = io.Copy(f, reader) + assert.Error(t, err) + err = f.Close() + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, expectedSize, user.UsedQuotaSize) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestChtimesOpenHandle(t *testing.T) { + localUser, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + u := getCryptFsUser() + cryptFsUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + for _, user := range []dataprovider.User{localUser, sftpUser, cryptFsUser} { + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + f, err := client.Create(testFileName) + assert.NoError(t, err, "user %v", user.Username) + f1, err := client.Create(testFileName + "1") + assert.NoError(t, err, "user %v", user.Username) + acmodTime := time.Now().Add(36 * time.Hour) + err = client.Chtimes(testFileName, acmodTime, acmodTime) + assert.NoError(t, err, "user %v", user.Username) + _, err = f.Write(testFileContent) + assert.NoError(t, err, "user %v", user.Username) + err = f.Close() + assert.NoError(t, err, "user %v", user.Username) + err = f1.Close() + assert.NoError(t, err, "user %v", user.Username) + info, err := client.Lstat(testFileName) + 
assert.NoError(t, err, "user %v", user.Username) + diff := math.Abs(info.ModTime().Sub(acmodTime).Seconds()) + assert.LessOrEqual(t, diff, float64(1), "user %v", user.Username) + info1, err := client.Lstat(testFileName + "1") + assert.NoError(t, err, "user %v", user.Username) + diff = math.Abs(info1.ModTime().Sub(acmodTime).Seconds()) + assert.Greater(t, diff, float64(86400), "user %v", user.Username) + } + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(cryptFsUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(cryptFsUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWaitForConnections(t *testing.T) { + u := getTestUser() + u.UploadBandwidth = 128 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + testFileSize := int64(524288) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = common.CheckClosing() + assert.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + time.Sleep(1 * time.Second) + common.WaitForTransfers(10) + common.WaitForTransfers(0) + common.WaitForTransfers(10) + }() + + err = writeSFTPFileNoCheck(testFileName, testFileSize, client) + assert.NoError(t, err) + wg.Wait() + + err = common.CheckClosing() + assert.EqualError(t, err, common.ErrShuttingDown.Error()) + + _, err = client.Stat(testFileName) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrShuttingDown.Error()) + } + } + + _, _, err = getSftpClient(user) + assert.Error(t, err) + + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) + + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + info, err := 
client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, info.Size()) + } + err = client.Remove(testFileName) + assert.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + time.Sleep(1 * time.Second) + common.WaitForTransfers(1) + }() + + err = writeSFTPFileNoCheck(testFileName, testFileSize, client) + // we don't have an error here because the service won't really stop + assert.NoError(t, err) + wg.Wait() + } + + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) + + common.WaitForTransfers(1) + + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestCheckParentDirs(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + testDir := "/path/to/sub/dir" + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + _, err = client.Stat(testDir) + assert.ErrorIs(t, err, os.ErrNotExist) + c := common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", user) + err = c.CheckParentDirs(testDir) + assert.NoError(t, err) + _, err = client.Stat(testDir) + assert.NoError(t, err) + err = c.CheckParentDirs(testDir) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + u := getTestUser() + u.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermListItems, dataprovider.PermDownload} + user, _, err = httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + c := common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", 
user) + err = c.CheckParentDirs(testDir) + assert.ErrorIs(t, err, sftp.ErrSSHFxPermissionDenied) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermissionErrors(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + u := getTestSFTPUser() + subDir := "/sub" + u.Permissions[subDir] = nil + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = client.MkdirAll(path.Join(subDir, subDir)) + assert.NoError(t, err) + f, err := client.Create(path.Join(subDir, subDir, testFileName)) + if assert.NoError(t, err) { + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + } + } + conn, client, err = getSftpClient(sftpUser) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + _, err = client.ReadDir(subDir) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Mkdir(path.Join(subDir, subDir)) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.RemoveDirectory(path.Join(subDir, subDir)) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Symlink("test", path.Join(subDir, subDir)) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Chmod(path.Join(subDir, subDir), os.ModePerm) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Chown(path.Join(subDir, subDir), os.Getuid(), os.Getgid()) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Chtimes(path.Join(subDir, subDir), time.Now(), time.Now()) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Truncate(path.Join(subDir, subDir), 0) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Remove(path.Join(subDir, subDir, testFileName)) + 
assert.ErrorIs(t, err, os.ErrPermission) + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestHiddenPatternFilter(t *testing.T) { + deniedDir := "/denied_hidden" + u := getTestUser() + u.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: deniedDir, + DeniedPatterns: []string{"*.txt", "beta*"}, + DenyPolicy: sdk.DenyPolicyHide, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + dirName := "beta" + subDirName := "testDir" + testFile := filepath.Join(u.GetHomeDir(), deniedDir, "file.txt") + testFile1 := filepath.Join(u.GetHomeDir(), deniedDir, "beta.txt") + testHiddenFile := filepath.Join(u.GetHomeDir(), deniedDir, dirName, subDirName, "hidden.jpg") + err = os.MkdirAll(filepath.Join(u.GetHomeDir(), deniedDir), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testFile, testFileContent, os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testFile1, testFileContent, os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(filepath.Join(u.GetHomeDir(), deniedDir, dirName, subDirName), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testHiddenFile, testFileContent, os.ModePerm) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + files, err := client.ReadDir(deniedDir) + assert.NoError(t, err) + assert.Len(t, files, 0) + err = client.Remove(path.Join(deniedDir, filepath.Base(testFile))) + assert.ErrorIs(t, err, os.ErrNotExist) + err = client.Chtimes(path.Join(deniedDir, filepath.Base(testFile)), time.Now(), time.Now()) + assert.ErrorIs(t, err, os.ErrNotExist) + _, err = client.Stat(path.Join(deniedDir, filepath.Base(testFile1))) + assert.ErrorIs(t, err, os.ErrNotExist) + err = 
client.RemoveDirectory(path.Join(deniedDir, dirName)) + assert.ErrorIs(t, err, os.ErrNotExist) + err = client.Rename(path.Join(deniedDir, dirName), path.Join(deniedDir, "newname")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Mkdir(path.Join(deniedDir, "beta1")) + assert.ErrorIs(t, err, os.ErrPermission) + err = writeSFTPFile(path.Join(deniedDir, "afile.txt"), 1024, client) + assert.ErrorIs(t, err, os.ErrPermission) + err = writeSFTPFile(path.Join(deniedDir, dirName, subDirName, "afile.jpg"), 1024, client) + assert.ErrorIs(t, err, os.ErrPermission) + _, err = client.Open(path.Join(deniedDir, dirName, subDirName, filepath.Base(testHiddenFile))) + assert.ErrorIs(t, err, os.ErrNotExist) + err = client.Symlink(path.Join(deniedDir, dirName), dirName) + assert.ErrorIs(t, err, os.ErrNotExist) + err = writeSFTPFile(path.Join(deniedDir, testFileName), 1024, client) + assert.NoError(t, err) + err = client.Symlink(path.Join(deniedDir, testFileName), path.Join(deniedDir, "symlink.txt")) + assert.ErrorIs(t, err, os.ErrPermission) + files, err = client.ReadDir(deniedDir) + assert.NoError(t, err) + assert.Len(t, files, 1) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + u.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: deniedDir, + DeniedPatterns: []string{"*.txt", "beta*"}, + DenyPolicy: sdk.DenyPolicyDefault, + }, + } + user, _, err = httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + files, err := client.ReadDir(deniedDir) + assert.NoError(t, err) + assert.Len(t, files, 4) + _, err = client.Stat(path.Join(deniedDir, filepath.Base(testFile))) + assert.NoError(t, err) + err = client.Chtimes(path.Join(deniedDir, filepath.Base(testFile)), time.Now(), time.Now()) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Mkdir(path.Join(deniedDir, "beta2")) + assert.ErrorIs(t, err, 
os.ErrPermission) + err = writeSFTPFile(path.Join(deniedDir, "afile2.txt"), 1024, client) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Symlink(path.Join(deniedDir, testFileName), path.Join(deniedDir, "link.txt")) + assert.ErrorIs(t, err, os.ErrPermission) + err = writeSFTPFile(path.Join(deniedDir, dirName, subDirName, "afile.jpg"), 1024, client) + assert.NoError(t, err) + f, err := client.Open(path.Join(deniedDir, dirName, subDirName, filepath.Base(testHiddenFile))) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestHiddenRoot(t *testing.T) { + // only the "/ftp" directory is allowed and visibile in the "/" path + // within /ftp any file/directory is allowed and visibile + u := getTestUser() + u.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/", + AllowedPatterns: []string{"ftp"}, + DenyPolicy: sdk.DenyPolicyHide, + }, + { + Path: "/ftp", + AllowedPatterns: []string{"*"}, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + for i := 0; i < 10; i++ { + err = os.MkdirAll(filepath.Join(user.HomeDir, fmt.Sprintf("ftp%d", i)), os.ModePerm) + assert.NoError(t, err) + } + err = os.WriteFile(filepath.Join(user.HomeDir, testFileName), []byte(""), 0666) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(user.HomeDir, "ftp.txt"), []byte(""), 0666) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = client.Mkdir("ftp") + assert.NoError(t, err) + entries, err := client.ReadDir("/") + assert.NoError(t, err) + if assert.Len(t, entries, 1) { + assert.Equal(t, "ftp", entries[0].Name()) + } + _, err = client.Stat(".") + assert.NoError(t, err) + for _, name := range []string{testFileName, "ftp.txt"} { + _, err = client.Stat(name) + 
assert.ErrorIs(t, err, os.ErrNotExist) + } + for i := 0; i < 10; i++ { + _, err = client.Stat(fmt.Sprintf("ftp%d", i)) + assert.ErrorIs(t, err, os.ErrNotExist) + } + err = writeSFTPFile(testFileName, 4096, client) + assert.ErrorIs(t, err, os.ErrPermission) + err = writeSFTPFile("ftp123", 4096, client) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(testFileName, testFileName+"_rename") //nolint:goconst + assert.ErrorIs(t, err, os.ErrPermission) + err = writeSFTPFile(path.Join("/ftp", testFileName), 4096, client) + assert.NoError(t, err) + err = client.Mkdir("/ftp/dir") + assert.NoError(t, err) + err = client.Rename(path.Join("/ftp", testFileName), path.Join("/ftp/dir", testFileName)) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestFileNotAllowedErrors(t *testing.T) { + deniedDir := "/denied" + u := getTestUser() + u.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: deniedDir, + DeniedPatterns: []string{"*.txt"}, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFile := filepath.Join(u.GetHomeDir(), deniedDir, "file.txt") + err = os.MkdirAll(filepath.Join(u.GetHomeDir(), deniedDir), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testFile, testFileContent, os.ModePerm) + assert.NoError(t, err) + err = client.Remove(path.Join(deniedDir, "file.txt")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(path.Join(deniedDir, "file.txt"), path.Join(deniedDir, "file1.txt")) + assert.ErrorIs(t, err, os.ErrPermission) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestRootDirVirtualFolder(t *testing.T) { + 
mappedPath1 := filepath.Join(os.TempDir(), "mapped1") + f1 := vfs.BaseVirtualFolder{ + Name: filepath.Base(mappedPath1), + MappedPath: mappedPath1, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret("cryptsecret"), + }, + }, + } + mappedPath2 := filepath.Join(os.TempDir(), "mapped2") + f2 := vfs.BaseVirtualFolder{ + Name: filepath.Base(mappedPath2), + MappedPath: mappedPath2, + } + folder1, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + folder2, _, err := httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + u.QuotaFiles = 1000 + u.UploadDataTransfer = 1000 + u.DownloadDataTransfer = 5000 + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder1.Name, + }, + VirtualPath: "/", + QuotaFiles: 1000, + }) + vdirPath2 := "/vmapped" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder2.Name, + }, + VirtualPath: vdirPath2, + QuotaFiles: -1, + QuotaSize: -1, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + f, err := user.GetVirtualFolderForPath("/") + assert.NoError(t, err) + assert.Equal(t, "/", f.VirtualPath) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = checkBasicSFTP(client) + assert.NoError(t, err) + f, err := client.Create(testFileName) + if assert.NoError(t, err) { + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + } + assert.NoFileExists(t, filepath.Join(user.HomeDir, testFileName)) + assert.FileExists(t, filepath.Join(mappedPath1, testFileName)) + entries, err := client.ReadDir(".") + if assert.NoError(t, err) { + assert.Len(t, entries, 2) + } + + user, _, err := 
httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + folder, _, err := httpdtest.GetFolderByName(folder1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, folder.UsedQuotaFiles) + + f, err = client.Create(path.Join(vdirPath2, testFileName)) + if assert.NoError(t, err) { + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + } + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + folder, _, err = httpdtest.GetFolderByName(folder1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, folder.UsedQuotaFiles) + + err = client.Rename(testFileName, path.Join(vdirPath2, testFileName+"_rename")) + assert.Error(t, err) + err = client.Rename(path.Join(vdirPath2, testFileName), testFileName+"_rename") + assert.Error(t, err) + err = client.Rename(testFileName, testFileName+"_rename") + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testFileName+"_rename")) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folder1.Name}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folder2.Name}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestTruncateQuotaLimits(t *testing.T) { + mappedPath1 := filepath.Join(os.TempDir(), "mapped1") + f1 := vfs.BaseVirtualFolder{ + Name: filepath.Base(mappedPath1), + MappedPath: mappedPath1, + } + folder1, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + mappedPath2 := 
filepath.Join(os.TempDir(), "mapped2") + f2 := vfs.BaseVirtualFolder{ + Name: filepath.Base(mappedPath2), + MappedPath: mappedPath2, + } + folder2, _, err := httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + u := getTestUser() + u.QuotaSize = 20 + u.UploadDataTransfer = 1000 + u.DownloadDataTransfer = 5000 + vdirPath1 := "/vmapped1" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder1.Name, + }, + VirtualPath: vdirPath1, + QuotaFiles: 10, + }) + vdirPath2 := "/vmapped2" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder2.Name, + }, + VirtualPath: vdirPath2, + QuotaFiles: -1, + QuotaSize: -1, + }) + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaSize = 20 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + f, err := client.OpenFile(testFileName, os.O_WRONLY|os.O_CREATE) + if assert.NoError(t, err) { + n, err := f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(2) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(2) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = f.Seek(expectedQuotaSize, io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(5) + assert.NoError(t, err) + expectedQuotaSize = int64(5) + user, _, err = 
httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = f.Seek(expectedQuotaSize, io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + expectedQuotaSize = int64(5) + int64(len(testFileContent)) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + } + // now truncate by path + err = client.Truncate(testFileName, 5) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + // now open an existing file without truncate it, quota should not change + f, err = client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + } + // open the file truncating it + f, err = client.OpenFile(testFileName, os.O_WRONLY|os.O_TRUNC) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + } + // now test max write size + f, err = client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + n, err := f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = 
f.Truncate(11) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(11), user.UsedQuotaSize) + _, err = f.Seek(int64(11), io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(5) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + _, err = f.Seek(int64(5), io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(12) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(12), user.UsedQuotaSize) + _, err = f.Seek(int64(12), io.SeekStart) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + err = f.Close() + assert.Error(t, err) + // the file is deleted + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + } + + if user.Username == defaultUsername { + // basic test inside a virtual folder + vfileName1 := path.Join(vdirPath1, testFileName) + f, err = client.OpenFile(vfileName1, os.O_WRONLY|os.O_CREATE) + if assert.NoError(t, err) { + n, err := f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(2) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(2) + fold, _, err := 
httpdtest.GetFolderByName(folder1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + fold, _, err = httpdtest.GetFolderByName(folder1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + } + err = client.Truncate(vfileName1, 1) + assert.NoError(t, err) + fold, _, err := httpdtest.GetFolderByName(folder1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(1), fold.UsedQuotaSize) + assert.Equal(t, 1, fold.UsedQuotaFiles) + // now test on vdirPath2, the folder quota is included in the user's quota + vfileName2 := path.Join(vdirPath2, testFileName) + f, err = client.OpenFile(vfileName2, os.O_WRONLY|os.O_CREATE) + if assert.NoError(t, err) { + n, err := f.Write(testFileContent) + assert.NoError(t, err) + assert.Equal(t, len(testFileContent), n) + err = f.Truncate(3) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(3) + fold, _, err := httpdtest.GetFolderByName(folder2.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), fold.UsedQuotaSize) + assert.Equal(t, 0, fold.UsedQuotaFiles) + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + fold, _, err = httpdtest.GetFolderByName(folder2.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), fold.UsedQuotaSize) + assert.Equal(t, 0, fold.UsedQuotaFiles) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + } + + // cleanup + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + if user.Username == defaultUsername { + _, err = httpdtest.RemoveUser(user, 
http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.QuotaSize = 0 + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(folder1, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(folder2, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) { + testFileSize := int64(131072) + testFileSize1 := int64(65537) + testFileName1 := "test_file1.dat" //nolint:goconst + u := getTestUser() + u.QuotaFiles = 0 + u.QuotaSize = 0 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" //nolint:goconst + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" //nolint:goconst + mappedPath3 := filepath.Join(os.TempDir(), "vdir3") + folderName3 := filepath.Base(mappedPath3) + vdirPath3 := "/vdir3" + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + f3 := vfs.BaseVirtualFolder{ + Name: folderName3, + MappedPath: mappedPath3, + } + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: 
vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + QuotaFiles: 2, + QuotaSize: 0, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + QuotaFiles: 0, + QuotaSize: testFileSize + testFileSize1 + 1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName3, + }, + VirtualPath: vdirPath3, + QuotaFiles: 2, + QuotaSize: testFileSize * 2, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + f, err := client.Open(path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + contents, err := io.ReadAll(f) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + assert.Len(t, contents, int(testFileSize)) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath3, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath3, testFileName+"1"), testFileSize, client) + assert.NoError(t, err) + err = client.Rename(testFileName, path.Join(vdirPath1, testFileName+".rename")) //nolint:goconst + assert.Error(t, err) + // we overwrite an existing file and we have unlimited size + err = 
client.Rename(testFileName, path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + // we have no space and we try to overwrite a bigger file with a smaller one, this should succeed + err = client.Rename(testFileName1, path.Join(vdirPath2, testFileName)) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + // we have no space and we try to overwrite a smaller file with a bigger one, this should fail + err = client.Rename(testFileName, path.Join(vdirPath2, testFileName1)) + assert.Error(t, err) + fi, err := client.Stat(path.Join(vdirPath1, testFileName1)) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize1, fi.Size()) + } + // we are overquota inside vdir3 size 2/2 and size 262144/262144 + err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName1+".rename")) + assert.Error(t, err) + // we overwrite an existing file and we have enough size + err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName)) + assert.NoError(t, err) + testFileName2 := "test_file2.dat" + err = writeSFTPFile(testFileName2, testFileSize+testFileSize1, client) + assert.NoError(t, err) + // we overwrite an existing file and we haven't enough size + err = client.Rename(testFileName2, path.Join(vdirPath3, testFileName)) + assert.Error(t, err) + // now remove a file from vdir3, create a dir with 2 files and try to rename it in vdir3 + // this will fail since the rename will result in 3 files inside vdir3 and quota limits only + // allow 2 total files there + err = client.Remove(path.Join(vdirPath3, testFileName+"1")) + assert.NoError(t, err) + aDir := "a dir" + err = client.Mkdir(aDir) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(aDir, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(aDir, 
testFileName1+"1"), testFileSize1, client) + assert.NoError(t, err) + err = client.Rename(aDir, path.Join(vdirPath3, aDir)) + assert.Error(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName3}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath3) + assert.NoError(t, err) +} + +func TestQuotaRenameOverwrite(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileSize := int64(131072) + testFileSize1 := int64(65537) + testFileName1 := "test_file1.dat" + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + f, err := client.Open(testFileName) + assert.NoError(t, err) + contents := make([]byte, testFileSize) + n, err := io.ReadFull(f, contents) + assert.NoError(t, err) + assert.Equal(t, int(testFileSize), n) + err = f.Close() + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), user.UsedDownloadDataTransfer) + assert.Equal(t, int64(0), user.UsedUploadDataTransfer) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + err = client.Rename(testFileName, 
testFileName1) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), user.UsedDownloadDataTransfer) + assert.Equal(t, int64(0), user.UsedUploadDataTransfer) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + err = client.Remove(testFileName1) + assert.NoError(t, err) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + err = client.Rename(testFileName1, testFileName) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestVirtualFoldersQuotaValues(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + vdirPath1 := "/vdir1" + folderName1 := filepath.Base(mappedPath1) + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + vdirPath2 := "/vdir2" + folderName2 := filepath.Base(mappedPath2) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + 
BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileSize := int64(131072) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + // we copy the same file two times to test quota update on file overwrite + err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + expectedQuotaFiles := 2 + expectedQuotaSize := testFileSize * 2 + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + + err = client.Remove(path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + err = client.Remove(path.Join(vdirPath2, testFileName)) + assert.NoError(t, err) + + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = 
httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + vdirPath1 := "/vdir1" + folderName1 := filepath.Base(mappedPath1) + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + vdirPath2 := "/vdir2" + folderName2 := filepath.Base(mappedPath2) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + 
defer conn.Close() + defer client.Close() + testFileName1 := "test_file1.dat" + testFileSize := int64(131072) + testFileSize1 := int64(65535) + dir1 := "dir1" //nolint:goconst + dir2 := "dir2" //nolint:goconst + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // initial files: + // - vdir1/dir1/testFileName + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + // + // rename a file inside vdir1 it is included inside user quota, so we have: + // - vdir1/dir1/testFileName.rename + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + err = client.Rename(path.Join(vdirPath1, dir1, testFileName), 
path.Join(vdirPath1, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // rename a file inside vdir2, it isn't included inside user quota, so we have: + // - vdir1/dir1/testFileName.rename + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName.rename + // - vdir2/dir2/testFileName1 + err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(vdirPath2, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // rename a file inside vdir2 overwriting an existing, we now have: + // - vdir1/dir1/testFileName.rename + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName.rename (initial testFileName1) + err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = 
httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // rename a file inside vdir1 overwriting an existing, we now have: + // - vdir1/dir1/testFileName.rename (initial testFileName1) + // - vdir2/dir1/testFileName.rename (initial testFileName1) + err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath1, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a directory inside the same virtual folder, quota should not change + err = client.RemoveDirectory(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.RemoveDirectory(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath1, dir1), path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath2, dir1), path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 
int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestQuotaRenameBetweenVirtualFolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, 
err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileName1 := "test_file1.dat" + testFileSize := int64(131072) + testFileSize1 := int64(65535) + dir1 := "dir1" + dir2 := "dir2" + err = client.Mkdir(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // initial files: + // - vdir1/dir1/testFileName + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + // + // rename a file from vdir1 to vdir2, vdir1 is included inside user quota, so we have: + // - vdir1/dir1/testFileName + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + // - vdir2/dir1/testFileName1.rename + err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName1+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 
testFileSize+testFileSize1+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 3, f.UsedQuotaFiles) + // rename a file from vdir2 to vdir1, vdir2 is not included inside user quota, so we have: + // - vdir1/dir1/testFileName + // - vdir1/dir2/testFileName.rename + // - vdir2/dir2/testFileName1 + // - vdir2/dir1/testFileName1.rename + err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(vdirPath1, dir2, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1*2, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // rename a file from vdir1 to vdir2 overwriting an existing file, vdir1 is included inside user quota, so we have: + // - vdir1/dir2/testFileName.rename + // - vdir2/dir2/testFileName1 (is the initial testFileName) + // - vdir2/dir1/testFileName1.rename + err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(vdirPath2, dir2, testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1+testFileSize, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // rename a 
file from vdir2 to vdir1 overwriting an existing file, vdir2 is not included inside user quota, so we have: + // - vdir1/dir2/testFileName.rename (is the initial testFileName1) + // - vdir2/dir2/testFileName1 (is the initial testFileName) + err = client.Rename(path.Join(vdirPath2, dir1, testFileName1+".rename"), path.Join(vdirPath1, dir2, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + + err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName+"1.dupl"), testFileSize1, client) + assert.NoError(t, err) + err = client.RemoveDirectory(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.RemoveDirectory(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + // - vdir1/dir2/testFileName.rename (initial testFileName1) + // - vdir1/dir2/testFileName + // - vdir2/dir2/testFileName1 (initial testFileName) + // - vdir2/dir2/testFileName (initial testFileName1) + // - vdir2/dir2/testFileName1.dupl + // rename directories between the two virtual folders + err = client.Rename(path.Join(vdirPath2, dir2), path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 5, user.UsedQuotaFiles) 
+ assert.Equal(t, testFileSize1*3+testFileSize*2, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // now move on vpath2 + err = client.Rename(path.Join(vdirPath1, dir2), path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1*2+testFileSize, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestQuotaRenameFromVirtualFolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" + f1 := 
vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileName1 := "test_file1.dat" + testFileSize := int64(131072) + testFileSize1 := int64(65535) + dir1 := "dir1" + dir2 := "dir2" + err = client.Mkdir(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // initial files: + // - vdir1/dir1/testFileName + // - vdir1/dir2/testFileName1 + // - 
vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + // + // rename a file from vdir1 to the user home dir, vdir1 is included in user quota so we have: + // - testFileName + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(testFileName)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // rename a file from vdir2 to the user home dir, vdir2 is not included in user quota so we have: + // - testFileName + // - testFileName1 + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a file from vdir1 to the user home dir overwriting an existing file, vdir1 is included in user quota so we have: + // - 
testFileName (initial testFileName1) + // - testFileName1 + // - vdir2/dir1/testFileName + err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(testFileName)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a file from vdir2 to the user home dir overwriting an existing file, vdir2 is not included in user quota so we have: + // - testFileName (initial testFileName1) + // - testFileName1 (initial testFileName) + err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // dir rename + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, dir1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName), testFileSize, client) + 
assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, dir1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // - testFileName (initial testFileName1) + // - testFileName1 (initial testFileName) + // - vdir1/dir1/testFileName + // - vdir1/dir1/testFileName1 + // - dir1/testFileName + // - dir1/testFileName1 + err = client.Rename(path.Join(vdirPath2, dir1), dir1) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 6, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // - testFileName (initial testFileName1) + // - testFileName1 (initial testFileName) + // - dir2/testFileName + // - dir2/testFileName1 + // - dir1/testFileName + // - dir1/testFileName1 + err = client.Rename(path.Join(vdirPath1, dir1), dir2) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 6, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + 
assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestQuotaRenameToVirtualFolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + u.Permissions[vdirPath1] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, + dataprovider.PermOverwrite, dataprovider.PermDelete, dataprovider.PermCreateSymlinks, dataprovider.PermCreateDirs, + dataprovider.PermRename} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileName1 := "test_file1.dat" + 
testFileSize := int64(131072) + testFileSize1 := int64(65535) + dir1 := "dir1" + dir2 := "dir2" + err = client.Mkdir(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + // initial files: + // - testFileName + // - testFileName1 + // + // rename a file from user home dir to vdir1, vdir1 is included in user quota so we have: + // - testFileName + // - /vdir1/dir1/testFileName1 + err = client.Rename(testFileName1, path.Join(vdirPath1, dir1, testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // rename a file from user home dir to vdir2, vdir2 is not included in user quota so we have: + // - /vdir2/dir1/testFileName + // - /vdir1/dir1/testFileName1 + err = client.Rename(testFileName, path.Join(vdirPath2, dir1, testFileName)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // upload two new files to the user home dir so we have: + // - testFileName + // - 
testFileName1 + // - /vdir1/dir1/testFileName1 + // - /vdir2/dir1/testFileName + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(testFileName1, testFileSize1, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize) + // rename a file from user home dir to vdir1 overwriting an existing file, vdir1 is included in user quota so we have: + // - testFileName1 + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName + err = client.Rename(testFileName, path.Join(vdirPath1, dir1, testFileName1)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a file from user home dir to vdir2 overwriting an existing file, vdir2 is not included in user quota so we have: + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName (initial testFileName1) + err = client.Rename(testFileName1, path.Join(vdirPath2, dir1, testFileName)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + 
assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + + err = client.Mkdir(dir1) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(dir1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // - /dir1/testFileName + // - /dir1/testFileName1 + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName (initial testFileName1) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // - /vdir1/adir/testFileName + // - /vdir1/adir/testFileName1 + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName (initial testFileName1) + err = client.Rename(dir1, path.Join(vdirPath1, "adir")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + 
assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + err = client.Mkdir(dir1) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(dir1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + // - /vdir1/adir/testFileName + // - /vdir1/adir/testFileName1 + // - /vdir1/dir1/testFileName1 (initial testFileName) + // - /vdir2/dir1/testFileName (initial testFileName1) + // - /vdir2/adir/testFileName + // - /vdir2/adir/testFileName1 + err = client.Rename(dir1, path.Join(vdirPath2, "adir")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize) + assert.Equal(t, 3, f.UsedQuotaFiles) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestTransferQuotaLimits(t *testing.T) { + u := getTestUser() + u.TotalDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) 
{ + defer conn.Close() + defer client.Close() + + testFileSize := int64(524288) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + f, err := client.Open(testFileName) + assert.NoError(t, err) + contents := make([]byte, testFileSize) + n, err := io.ReadFull(f, contents) + assert.NoError(t, err) + assert.Equal(t, int(testFileSize), n) + assert.Len(t, contents, int(testFileSize)) + err = f.Close() + assert.NoError(t, err) + _, err = client.Open(testFileName) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + err = writeSFTPFile(testFileName, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + } + // test the limit while uploading/downloading + user.TotalDataTransfer = 0 + user.UploadDataTransfer = 1 + user.DownloadDataTransfer = 1 + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + testFileSize := int64(450000) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + f, err := client.Open(testFileName) + if assert.NoError(t, err) { + _, err = io.Copy(io.Discard, f) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + } + f, err = client.Open(testFileName) + if assert.NoError(t, err) { + _, err = io.Copy(io.Discard, f) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + err = f.Close() + assert.Error(t, err) + } + + err = writeSFTPFile(testFileName, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), 
common.ErrQuotaExceeded.Error()) + } + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestVirtualFoldersLink(t *testing.T) { + u := getTestUser() + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileSize := int64(131072) + testDir := "adir" + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, testDir)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, 
testDir)) + assert.NoError(t, err) + err = client.Symlink(testFileName, testFileName+".link") + assert.NoError(t, err) + err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testFileName+".link")) + assert.NoError(t, err) + err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testDir, testFileName+".link")) + assert.NoError(t, err) + err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testFileName+".link")) + assert.NoError(t, err) + err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testDir, testFileName+".link")) + assert.NoError(t, err) + err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath1, testFileName+".link1")) //nolint:goconst + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath1, testDir, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath2, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath2, testDir, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.Symlink(path.Join(vdirPath1, testFileName), testFileName+".link1") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.Symlink(path.Join(vdirPath2, testFileName), testFileName+".link1") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath2, testDir, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = 
client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath1, testFileName+".link1")) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.Symlink("/", "/roolink") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Symlink(testFileName, "/") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Symlink(testFileName, vdirPath1) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.Symlink(vdirPath1, testFileName+".link2") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestCrossFolderRename(t *testing.T) { + folder1 := "folder1" + folder2 := "folder2" + folder3 := "folder3" + folder4 := "folder4" + folder5 := "folder5" + folder6 := "folder6" + folder7 := "folder7" + + baseUser, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err, string(resp)) + + f1 := vfs.BaseVirtualFolder{ + Name: folder1, + MappedPath: filepath.Join(os.TempDir(), folder1), + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folder2, + MappedPath: filepath.Join(os.TempDir(), folder2), + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + 
CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + f3 := vfs.BaseVirtualFolder{ + Name: folder3, + MappedPath: filepath.Join(os.TempDir(), folder3), + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword + "mod"), + }, + }, + } + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + f4 := vfs.BaseVirtualFolder{ + Name: folder4, + MappedPath: filepath.Join(os.TempDir(), folder4), + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: baseUser.Username, + Prefix: path.Join("/", folder4), + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f4, http.StatusCreated) + assert.NoError(t, err) + f5 := vfs.BaseVirtualFolder{ + Name: folder5, + MappedPath: filepath.Join(os.TempDir(), folder5), + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: baseUser.Username, + Prefix: path.Join("/", folder5), + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f5, http.StatusCreated) + assert.NoError(t, err) + f6 := vfs.BaseVirtualFolder{ + Name: folder6, + MappedPath: filepath.Join(os.TempDir(), folder6), + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: "127.0.0.1:4024", + Username: baseUser.Username, + Prefix: path.Join("/", folder6), + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f6, http.StatusCreated) + assert.NoError(t, 
err) + f7 := vfs.BaseVirtualFolder{ + Name: folder7, + MappedPath: filepath.Join(os.TempDir(), folder7), + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: baseUser.Username, + Prefix: path.Join("/", folder4), + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f7, http.StatusCreated) + assert.NoError(t, err) + + u := getCryptFsUser() + u.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder1, + }, + VirtualPath: path.Join("/", folder1), + QuotaSize: -1, + QuotaFiles: -1, + }, + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder2, + }, + VirtualPath: path.Join("/", folder2), + QuotaSize: -1, + QuotaFiles: -1, + }, + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder3, + }, + VirtualPath: path.Join("/", folder3), + QuotaSize: -1, + QuotaFiles: -1, + }, + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder4, + }, + VirtualPath: path.Join("/", folder4), + QuotaSize: -1, + QuotaFiles: -1, + }, + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder5, + }, + VirtualPath: path.Join("/", folder5), + QuotaSize: -1, + QuotaFiles: -1, + }, + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder6, + }, + VirtualPath: path.Join("/", folder6), + QuotaSize: -1, + QuotaFiles: -1, + }, + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder7, + }, + VirtualPath: path.Join("/", folder7), + QuotaSize: -1, + QuotaFiles: -1, + }, + } + + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + subDir := "testSubDir" + err = client.Mkdir(subDir) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(subDir, "afile.bin"), 64, client) + 
assert.NoError(t, err) + err = client.Rename(subDir, path.Join("/", folder1, subDir)) + assert.NoError(t, err) + _, err = client.Stat(path.Join("/", folder1, subDir)) + assert.NoError(t, err) + _, err = client.Stat(path.Join("/", folder1, subDir, "afile.bin")) + assert.NoError(t, err) + err = client.Rename(path.Join("/", folder1, subDir), path.Join("/", folder2, subDir)) + assert.NoError(t, err) + _, err = client.Stat(path.Join("/", folder2, subDir)) + assert.NoError(t, err) + _, err = client.Stat(path.Join("/", folder2, subDir, "afile.bin")) + assert.NoError(t, err) + err = client.Rename(path.Join("/", folder2, subDir), path.Join("/", folder3, subDir)) + assert.ErrorIs(t, err, os.ErrPermission) + err = writeSFTPFile(path.Join("/", folder3, "file.bin"), 64, client) + assert.NoError(t, err) + err = client.Rename(path.Join("/", folder3, "file.bin"), "/renamed.bin") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(path.Join("/", folder3, "file.bin"), path.Join("/", folder2, "/renamed.bin")) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(path.Join("/", folder3, "file.bin"), path.Join("/", folder3, "/renamed.bin")) + assert.NoError(t, err) + err = writeSFTPFile("/afile.bin", 64, client) + assert.NoError(t, err) + err = client.Rename("afile.bin", path.Join("/", folder4, "afile_renamed.bin")) + assert.ErrorIs(t, err, os.ErrPermission) + err = writeSFTPFile(path.Join("/", folder4, "afile.bin"), 64, client) + assert.NoError(t, err) + err = client.Rename(path.Join("/", folder4, "afile.bin"), path.Join("/", folder5, "afile_renamed.bin")) + assert.NoError(t, err) + err = client.Rename(path.Join("/", folder5, "afile_renamed.bin"), path.Join("/", folder6, "afile_renamed.bin")) + assert.ErrorIs(t, err, os.ErrPermission) + err = writeSFTPFile(path.Join("/", folder4, "afile.bin"), 64, client) + assert.NoError(t, err) + _, err = client.Stat(path.Join("/", folder7, "afile.bin")) + assert.NoError(t, err) + err = client.Rename(path.Join("/", folder4, 
"afile.bin"), path.Join("/", folder7, "afile.bin")) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) + for _, folderName := range []string{folder1, folder2, folder3, folder4, folder5, folder6, folder7} { + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(filepath.Join(os.TempDir(), folderName)) + assert.NoError(t, err) + } +} + +func TestDirs(t *testing.T) { + u := getTestUser() + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + vdirPath := "/path/vdir" + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + }) + u.Permissions["/subdir"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, + dataprovider.PermDelete, dataprovider.PermCreateDirs, dataprovider.PermRename, dataprovider.PermListItems} + + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + info, err := client.ReadDir("/") + if assert.NoError(t, err) { + if assert.Len(t, info, 1) { + assert.Equal(t, "path", info[0].Name()) + } + } + fi, err := client.Stat(path.Dir(vdirPath)) + if assert.NoError(t, err) { + assert.True(t, fi.IsDir()) + } + err = client.RemoveDirectory("/") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.RemoveDirectory(vdirPath) + assert.ErrorIs(t, err, 
os.ErrPermission) + err = client.RemoveDirectory(path.Dir(vdirPath)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.Mkdir(vdirPath) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Mkdir("adir") + assert.NoError(t, err) + err = client.Rename("/adir", path.Dir(vdirPath)) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = client.MkdirAll("/subdir/adir") + assert.NoError(t, err) + err = client.Rename("adir", "subdir/adir") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + err = writeSFTPFile("/subdir/afile.bin", 64, client) + assert.NoError(t, err) + err = writeSFTPFile("/afile.bin", 32, client) + assert.NoError(t, err) + err = client.Rename("afile.bin", "subdir/afile.bin") + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename("afile.bin", "subdir/afile1.bin") + assert.NoError(t, err) + err = client.Rename(path.Dir(vdirPath), "renamed_vdir") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED") + } + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestCryptFsStat(t *testing.T) { + user, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileSize := int64(4096) + err = writeSFTPFile(testFileName, testFileSize, client) + assert.NoError(t, err) + info, err := client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, info.Size()) + } + info, err = os.Stat(filepath.Join(user.HomeDir, testFileName)) + if 
assert.NoError(t, err) { + assert.Greater(t, info.Size(), testFileSize) + } + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestFsPermissionErrors(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + user, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testDir := "tDir" + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = os.Chmod(user.GetHomeDir(), 0111) + assert.NoError(t, err) + + err = client.RemoveDirectory(testDir) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Rename(testDir, testDir+"1") + assert.ErrorIs(t, err, os.ErrPermission) + + err = os.Chmod(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestRenameErrorOutsideHomeDir(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oldUploadMode := common.Config.UploadMode + oldTempPath := common.Config.TempPath + + common.Config.UploadMode = common.UploadModeAtomicWithResume + common.Config.TempPath = filepath.Clean(os.TempDir()) + vfs.SetTempPath(common.Config.TempPath) + + u := getTestUser() + u.QuotaFiles = 1000 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = os.Chmod(user.GetHomeDir(), 0555) + assert.NoError(t, err) + + err = checkBasicSFTP(client) + assert.NoError(t, err) + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, 
err) + err = f.Close() + assert.ErrorIs(t, err, os.ErrPermission) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + + err = os.Chmod(user.GetHomeDir(), os.ModeDir) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.UploadMode = oldUploadMode + common.Config.TempPath = oldTempPath + vfs.SetTempPath(oldTempPath) +} + +func TestResolvePathError(t *testing.T) { + u := getTestUser() + u.HomeDir = "relative_path" + conn := common.NewBaseConnection("", common.ProtocolFTP, "", "", u) + testPath := "apath" + _, err := conn.ListDir(testPath) + assert.Error(t, err) + err = conn.CreateDir(testPath, true) + assert.Error(t, err) + err = conn.RemoveDir(testPath) + assert.Error(t, err) + err = conn.Rename(testPath, testPath+"1") + assert.Error(t, err) + err = conn.CreateSymlink(testPath, testPath+".sym") + assert.Error(t, err) + _, err = conn.DoStat(testPath, 0, false) + assert.Error(t, err) + err = conn.RemoveAll(testPath) + assert.Error(t, err) + err = conn.SetStat(testPath, &common.StatAttributes{ + Atime: time.Now(), + Mtime: time.Now(), + }) + assert.Error(t, err) + + u = getTestUser() + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: "relative_mapped_path", + }, + VirtualPath: "/vpath", + }) + err = os.MkdirAll(u.HomeDir, os.ModePerm) + assert.NoError(t, err) + conn.User = u + err = conn.Rename(testPath, "/vpath/subpath") + assert.Error(t, err) + + outHomePath := filepath.Join(os.TempDir(), testFileName) + err = os.WriteFile(outHomePath, testFileContent, os.ModePerm) + assert.NoError(t, err) + err = os.Symlink(outHomePath, filepath.Join(u.HomeDir, testFileName+".link")) + assert.NoError(t, err) + err = 
os.WriteFile(filepath.Join(u.HomeDir, testFileName), testFileContent, os.ModePerm) + assert.NoError(t, err) + err = conn.CreateSymlink(testFileName, testFileName+".link") + assert.Error(t, err) + + err = os.RemoveAll(u.GetHomeDir()) + assert.NoError(t, err) + err = os.Remove(outHomePath) + assert.NoError(t, err) +} + +func TestUserPasswordHashing(t *testing.T) { + if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName { + t.Skip("this test is not supported with the memory provider") + } + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.PasswordHashing.Algo = dataprovider.HashingAlgoArgon2ID + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + currentUser, err := dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(currentUser.Password, "$2a$")) + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + u = getTestUser() + user, _, err = httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + currentUser, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(currentUser.Password, "$argon2id$")) + + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + 
assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestAllowList(t *testing.T) { + configCopy := common.Config + + entries := []dataprovider.IPListEntry{ + { + IPOrNet: "172.18.1.1/32", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 0, + }, + { + IPOrNet: "172.18.1.2/32", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 0, + }, + { + IPOrNet: "10.8.7.0/24", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 5, + }, + { + IPOrNet: "0.0.0.0/0", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 8, + }, + { + IPOrNet: "::/0", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 8, + }, + } + + for _, e := range entries { + _, resp, err := httpdtest.AddIPListEntry(e, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + + common.Config.AllowListStatus = 1 + err := common.Initialize(common.Config, 0) + assert.NoError(t, err) + assert.True(t, common.Config.IsAllowListEnabled()) + + testIP := "172.18.1.1" + assert.NoError(t, common.Connections.IsNewConnectionAllowed(testIP, common.ProtocolFTP)) + entry := entries[0] + entry.Protocols = 1 + _, _, err = httpdtest.UpdateIPListEntry(entry, http.StatusOK) + assert.NoError(t, err) + assert.Error(t, common.Connections.IsNewConnectionAllowed(testIP, common.ProtocolFTP)) + assert.NoError(t, common.Connections.IsNewConnectionAllowed(testIP, common.ProtocolSSH)) + _, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK) + assert.NoError(t, err) + entries = entries[1:] + assert.Error(t, common.Connections.IsNewConnectionAllowed(testIP, common.ProtocolSSH)) + assert.Error(t, 
common.Connections.IsNewConnectionAllowed("172.18.1.3", common.ProtocolSSH)) + assert.NoError(t, common.Connections.IsNewConnectionAllowed("172.18.1.3", common.ProtocolHTTP)) + + assert.NoError(t, common.Connections.IsNewConnectionAllowed("10.8.7.3", common.ProtocolWebDAV)) + assert.NoError(t, common.Connections.IsNewConnectionAllowed("10.8.7.4", common.ProtocolSSH)) + assert.Error(t, common.Connections.IsNewConnectionAllowed("10.8.7.4", common.ProtocolFTP)) + assert.NoError(t, common.Connections.IsNewConnectionAllowed("10.8.7.4", common.ProtocolHTTP)) + assert.NoError(t, common.Connections.IsNewConnectionAllowed("2001:0db8::1428:57ab", common.ProtocolHTTP)) + assert.Error(t, common.Connections.IsNewConnectionAllowed("2001:0db8::1428:57ab", common.ProtocolSSH)) + assert.Error(t, common.Connections.IsNewConnectionAllowed("10.8.8.2", common.ProtocolWebDAV)) + assert.Error(t, common.Connections.IsNewConnectionAllowed("invalid IP", common.ProtocolHTTP)) + + common.Config = configCopy + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) + assert.False(t, common.Config.IsAllowListEnabled()) + + for _, e := range entries { + _, err := httpdtest.RemoveIPListEntry(e, http.StatusOK) + assert.NoError(t, err) + } +} + +func TestDbDefenderErrors(t *testing.T) { + if !isDbDefenderSupported() { + t.Skip("this test is not supported with the current database provider") + } + configCopy := common.Config + common.Config.DefenderConfig.Enabled = true + common.Config.DefenderConfig.Driver = common.DefenderDriverProvider + err := common.Initialize(common.Config, 0) + assert.NoError(t, err) + + testIP := "127.1.1.1" + hosts, err := common.GetDefenderHosts() + assert.NoError(t, err) + assert.Len(t, hosts, 0) + common.AddDefenderEvent(testIP, common.ProtocolSSH, common.HostEventLimitExceeded) + hosts, err = common.GetDefenderHosts() + assert.NoError(t, err) + assert.Len(t, hosts, 1) + score, err := common.GetDefenderScore(testIP) + assert.NoError(t, err) + assert.Equal(t, 
3, score) + banTime, err := common.GetDefenderBanTime(testIP) + assert.NoError(t, err) + assert.Nil(t, banTime) + + err = dataprovider.Close() + assert.NoError(t, err) + + common.AddDefenderEvent(testIP, common.ProtocolFTP, common.HostEventLimitExceeded) + _, err = common.GetDefenderHosts() + assert.Error(t, err) + _, err = common.GetDefenderHost(testIP) + assert.Error(t, err) + _, err = common.GetDefenderBanTime(testIP) + assert.Error(t, err) + _, err = common.GetDefenderScore(testIP) + assert.Error(t, err) + + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Hour))) + assert.NoError(t, err) + + common.Config = configCopy + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) +} + +func TestDelayedQuotaUpdater(t *testing.T) { + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.DelayedQuotaUpdate = 120 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + u := getTestUser() + u.QuotaFiles = 100 + u.TotalDataTransfer = 2000 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + err = dataprovider.UpdateUserQuota(&user, 10, 6000, false) + assert.NoError(t, err) + err = dataprovider.UpdateUserTransferQuota(&user, 100, 200, false) + assert.NoError(t, err) + files, size, ulSize, dlSize, err := dataprovider.GetUsedQuota(user.Username) + assert.NoError(t, err) + assert.Equal(t, 10, files) + assert.Equal(t, int64(6000), size) + assert.Equal(t, int64(100), ulSize) + assert.Equal(t, int64(200), dlSize) + + userGet, err := dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + assert.Equal(t, 0, userGet.UsedQuotaFiles) + 
assert.Equal(t, int64(0), userGet.UsedQuotaSize) + assert.Equal(t, int64(0), userGet.UsedUploadDataTransfer) + assert.Equal(t, int64(0), userGet.UsedDownloadDataTransfer) + + err = dataprovider.UpdateUserQuota(&user, 10, 6000, true) + assert.NoError(t, err) + err = dataprovider.UpdateUserTransferQuota(&user, 100, 200, true) + assert.NoError(t, err) + files, size, ulSize, dlSize, err = dataprovider.GetUsedQuota(user.Username) + assert.NoError(t, err) + assert.Equal(t, 10, files) + assert.Equal(t, int64(6000), size) + assert.Equal(t, int64(100), ulSize) + assert.Equal(t, int64(200), dlSize) + + userGet, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + assert.Equal(t, 10, userGet.UsedQuotaFiles) + assert.Equal(t, int64(6000), userGet.UsedQuotaSize) + assert.Equal(t, int64(100), userGet.UsedUploadDataTransfer) + assert.Equal(t, int64(200), userGet.UsedDownloadDataTransfer) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + folder := vfs.BaseVirtualFolder{ + Name: "folder", + MappedPath: filepath.Join(os.TempDir(), "p"), + } + err = dataprovider.AddFolder(&folder, "", "", "") + assert.NoError(t, err) + + err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 6000, false) + assert.NoError(t, err) + files, size, err = dataprovider.GetUsedVirtualFolderQuota(folder.Name) + assert.NoError(t, err) + assert.Equal(t, 10, files) + assert.Equal(t, int64(6000), size) + + folderGet, err := dataprovider.GetFolderByName(folder.Name) + assert.NoError(t, err) + assert.Equal(t, 0, folderGet.UsedQuotaFiles) + assert.Equal(t, int64(0), folderGet.UsedQuotaSize) + + err = dataprovider.UpdateVirtualFolderQuota(&folder, 10, 6000, true) + assert.NoError(t, err) + files, size, err = dataprovider.GetUsedVirtualFolderQuota(folder.Name) + assert.NoError(t, err) + assert.Equal(t, 10, files) + assert.Equal(t, int64(6000), size) + + folderGet, err = 
dataprovider.GetFolderByName(folder.Name) + assert.NoError(t, err) + assert.Equal(t, 10, folderGet.UsedQuotaFiles) + assert.Equal(t, int64(6000), folderGet.UsedQuotaSize) + + err = dataprovider.DeleteFolder(folder.Name, "", "", "") + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestPasswordCaching(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + dbUser, err := dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + found, match := dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password) + assert.False(t, found) + assert.False(t, match) + + user.Password = "wrong" + _, _, err = getSftpClient(user) + assert.Error(t, err) + found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password) + assert.False(t, found) + assert.False(t, match) + user.Password = "" + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password) + assert.True(t, found) + assert.True(t, match) + + found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword+"_", dbUser.Password) + assert.True(t, found) + assert.False(t, match) + + found, match = dataprovider.CheckCachedUserPassword(user.Username+"_", defaultPassword, dbUser.Password) + assert.False(t, found) + assert.False(t, match) + + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + // the password was not changed + found, match = dataprovider.CheckCachedUserPassword(user.Username, 
defaultPassword, dbUser.Password) + assert.True(t, found) + assert.True(t, match) + // the password hash will change + user.Password = defaultPassword + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + dbUser, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password) + assert.False(t, found) + assert.False(t, match) + + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password) + assert.True(t, found) + assert.True(t, match) + //change password + newPassword := defaultPassword + "mod" + user.Password = newPassword + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + dbUser, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + found, match = dataprovider.CheckCachedUserPassword(user.Username, newPassword, dbUser.Password) + assert.False(t, found) + assert.False(t, match) + + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password) + assert.True(t, found) + assert.False(t, match) + found, match = dataprovider.CheckCachedUserPassword(user.Username, newPassword, dbUser.Password) + assert.True(t, found) + assert.True(t, match) + // update the password + err = dataprovider.UpdateUserPassword(user.Username, defaultPassword, "", "", "") + assert.NoError(t, err) + dbUser, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + // the stored hash does not match + found, match = 
dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password) + assert.False(t, found) + assert.False(t, match) + + user.Password = defaultPassword + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password) + assert.True(t, found) + assert.True(t, match) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + found, match = dataprovider.CheckCachedUserPassword(user.Username, defaultPassword, dbUser.Password) + assert.False(t, found) + assert.False(t, match) +} + +func TestEventRule(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeHTTP, + Options: dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: "http://localhost", + Timeout: 20, + Method: http.MethodGet, + }, + }, + } + a2 := dataprovider.BaseEventAction{ + Name: "action2", + Type: dataprovider.ActionTypeBackup, + } + a3 := dataprovider.BaseEventAction{ + Name: "action3", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test1@example.com", "test2@example.com"}, + Bcc: []string{"test3@example.com"}, + Subject: `New "{{.Event}}" from "{{.Name}}" status {{.StatusString}}`, + Body: "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}} Data: {{.ObjectData}} 
{{.ErrorString}}", + }, + }, + } + a4 := dataprovider.BaseEventAction{ + Name: "action4", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"failure@example.com"}, + Subject: `Failed "{{.Event}}" from "{{.Name}}"`, + Body: "Fs path {{.FsPath}}, protocol: {{.Protocol}}, IP: {{.IP}} {{.ErrorString}}", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) + assert.NoError(t, err) + action4, _, err := httpdtest.AddEventAction(a4, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "test rule1", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + Options: dataprovider.ConditionOptions{ + EventStatuses: []int{1}, + FsPaths: []dataprovider.ConditionPattern{ + { + Pattern: "/subdir/*.dat", + }, + { + Pattern: "/**/*.txt", + }, + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + StopOnFailure: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 3, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action4.Name, + }, + Order: 4, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + + r2 := dataprovider.EventRule{ + Name: "test rule2", + Status: 1, + Trigger: 
dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"download"}, + Options: dataprovider.ConditionOptions{ + FsPaths: []dataprovider.ConditionPattern{ + { + Pattern: "/**/*.dat", + }, + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action4.Name, + }, + Order: 2, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + }, + }, + } + rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) + assert.NoError(t, err) + + r3 := dataprovider.EventRule{ + Name: "test rule3", + Status: 1, + Trigger: dataprovider.EventTriggerProviderEvent, + Conditions: dataprovider.EventConditions{ + ProviderEvents: []string{"delete"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 1, + }, + }, + } + rule3, _, err := httpdtest.AddEventRule(r3, http.StatusCreated) + assert.NoError(t, err) + + uploadScriptPath := filepath.Join(os.TempDir(), "upload.sh") + u := getTestUser() + u.DownloadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + movedFileName := "moved.dat" + movedPath := filepath.Join(user.HomeDir, movedFileName) + err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, "", 0), 0755) + assert.NoError(t, err) + + dataprovider.EnabledActionCommands = []string{uploadScriptPath} + defer func() { + dataprovider.EnabledActionCommands = nil + }() + + action1.Type = dataprovider.ActionTypeCommand + action1.Options = dataprovider.BaseEventActionOptions{ + CmdConfig: dataprovider.EventActionCommandConfig{ + Cmd: uploadScriptPath, + Timeout: 10, + EnvVars: []dataprovider.KeyValue{ + { + Key: "SFTPGO_ACTION_PATH", + Value: "{{.FsPath}}", + }, + { + Key: "CUSTOM_ENV_VAR", + Value: "value", + }, + }, + }, + } + 
action1, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK) + assert.NoError(t, err) + + dirName := "subdir" + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + size := int64(32768) + // rule conditions does not match + err = writeSFTPFileNoCheck(testFileName, size, client) + assert.NoError(t, err) + info, err := client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, size, info.Size()) + } + err = client.Mkdir(dirName) + assert.NoError(t, err) + err = client.Mkdir("subdir1") + assert.NoError(t, err) + // rule conditions match + lastReceivedEmail.reset() + err = writeSFTPFileNoCheck(path.Join(dirName, testFileName), size, client) + assert.NoError(t, err) + _, err = client.Stat(path.Join(dirName, testFileName)) + assert.Error(t, err) + info, err = client.Stat(movedFileName) + if assert.NoError(t, err) { + assert.Equal(t, size, info.Size()) + } + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 3) + assert.True(t, slices.Contains(email.To, "test1@example.com")) + assert.True(t, slices.Contains(email.To, "test2@example.com")) + assert.True(t, slices.Contains(email.To, "test3@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "upload" from "%s" status OK`, user.Username)) + // test the failure action, we download a file that exceeds the transfer quota limit + err = writeSFTPFileNoCheck(path.Join("subdir1", testFileName), 1*1024*1024+65535, client) + assert.NoError(t, err) + lastReceivedEmail.reset() + f, err := client.Open(path.Join("subdir1", testFileName)) + assert.NoError(t, err) + _, err = io.ReadAll(f) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + err = f.Close() + assert.Error(t, err) + assert.Eventually(t, func() bool { + return 
lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 3) + assert.True(t, slices.Contains(email.To, "test1@example.com")) + assert.True(t, slices.Contains(email.To, "test2@example.com")) + assert.True(t, slices.Contains(email.To, "test3@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "download" from "%s" status KO`, user.Username)) + assert.Contains(t, email.Data, `"download" failed`) + assert.Contains(t, email.Data, common.ErrReadQuotaExceeded.Error()) + _, err = httpdtest.UpdateTransferQuotaUsage(user, "", http.StatusOK) + assert.NoError(t, err) + + // remove the upload script to test the failure action + err = os.Remove(uploadScriptPath) + assert.NoError(t, err) + lastReceivedEmail.reset() + err = writeSFTPFileNoCheck(path.Join(dirName, testFileName), size, client) + assert.Error(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "failure@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: Failed "upload" from "%s"`, user.Username)) + assert.Contains(t, email.Data, fmt.Sprintf(`action %q failed`, action1.Name)) + // now test the download rule + lastReceivedEmail.reset() + f, err = client.Open(movedFileName) + assert.NoError(t, err) + contents, err := io.ReadAll(f) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + assert.Len(t, contents, int(size)) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 3) + assert.True(t, slices.Contains(email.To, "test1@example.com")) + assert.True(t, slices.Contains(email.To, "test2@example.com")) + assert.True(t, slices.Contains(email.To, 
"test3@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: New "download" from "%s"`, user.Username)) + } + // test upload action command with arguments + action1.Options.CmdConfig.Args = []string{"{{.Event}}", "{{.VirtualPath}}", "custom_arg"} + action1, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK) + assert.NoError(t, err) + uploadLogFilePath := filepath.Join(os.TempDir(), "upload.log") + err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, uploadLogFilePath, 0), 0755) + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = writeSFTPFileNoCheck(path.Join(dirName, testFileName), 123, client) + assert.NoError(t, err) + + logContent, err := os.ReadFile(uploadLogFilePath) + assert.NoError(t, err) + assert.Equal(t, fmt.Sprintf("upload %s custom_arg", util.CleanPath(path.Join(dirName, testFileName))), + strings.TrimSpace(string(logContent))) + + err = os.Remove(uploadLogFilePath) + assert.NoError(t, err) + lastReceivedEmail.reset() + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + } + + lastReceivedEmail.reset() + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 3) + assert.True(t, slices.Contains(email.To, "test1@example.com")) + assert.True(t, slices.Contains(email.To, "test2@example.com")) + assert.True(t, slices.Contains(email.To, "test3@example.com")) + assert.Contains(t, email.Data, `Subject: New "delete" from "admin"`) + _, err = httpdtest.RemoveEventRule(rule3, http.StatusOK) + assert.NoError(t, err) + _, err = 
httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action3, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action4, http.StatusOK)
	assert.NoError(t, err)
	lastReceivedEmail.reset()
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	// restore the default (disabled) SMTP configuration
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestEventRuleStatues verifies that a rule with an EventStatuses condition
// only fires for uploads that end in the matching status: a successful
// upload must not produce a notification, while an upload that fails
// because the user's data transfer quota is exhausted must trigger the
// email action with the quota error in the body.
// NOTE(review): "Statues" looks like a typo for "Statuses"; kept as-is
// since the function name is the test's identifier.
func TestEventRuleStatues(t *testing.T) {
	// point the mailer at the in-process SMTP test server
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)

	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"test6@example.com"},
				Subject:    `New "{{.Event}}" error`,
				Body:       "{{.ErrorString}}",
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)

	r := dataprovider.EventRule{
		Name:    "rule",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
			Options: dataprovider.ConditionOptions{
				// fire only for uploads ending with status 3 — presumably the
				// "failed / quota exceeded" outcome; confirm against the
				// dataprovider status constants
				EventStatuses: []int{3},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
			},
		},
	}
	rule, resp, err := httpdtest.AddEventRule(r, http.StatusCreated)
	assert.NoError(t, err, string(resp))

	u := getTestUser()
	// minimal transfer quotas so the second upload exceeds them
	// (units presumably MB — confirm against the User field docs)
	u.UploadDataTransfer = 1
	u.DownloadDataTransfer = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		// first upload/download fit within the quota and must succeed
		testFileSize := int64(999999)
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		f, err := client.Open(testFileName)
		assert.NoError(t, err)
		contents := make([]byte, testFileSize)
		n, err := io.ReadFull(f, contents)
		assert.NoError(t, err)
		assert.Equal(t, int(testFileSize), n)
		assert.Len(t, contents, int(testFileSize))
		err = f.Close()
		assert.NoError(t, err)

		// negative probe: give the mailer ~500ms and check that the
		// successful upload produced no email (From stays empty);
		// timeout < 2 ticks so the condition is sampled once
		lastReceivedEmail.reset()
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From == ""
		}, 600*time.Millisecond, 500*time.Millisecond)

		// this upload exceeds the transfer quota and must fail ...
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.Error(t, err)
		// ... which matches the rule's status condition and sends the email
		lastReceivedEmail.reset()
		assert.Eventually(t, func() bool {
			return lastReceivedEmail.get().From != ""
		}, 3000*time.Millisecond, 100*time.Millisecond)
		email := lastReceivedEmail.get()
		assert.Len(t, email.To, 1)
		assert.True(t, slices.Contains(email.To, "test6@example.com"))
		assert.Contains(t, email.Data, `Subject: New "upload" error`)
		assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error())
	}

	// cleanup: rule, action, user, home dir, SMTP config
	_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

func TestEventRuleDisabledCommand(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)

	saveObjectScriptPath := filepath.Join(os.TempDir(), "provider.sh")
	outPath :=
filepath.Join(os.TempDir(), "provider_out.json") + err = os.WriteFile(saveObjectScriptPath, getSaveProviderObjectScriptContent(outPath, 0), 0755) + assert.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeCommand, + Options: dataprovider.BaseEventActionOptions{ + CmdConfig: dataprovider.EventActionCommandConfig{ + Cmd: saveObjectScriptPath, + Timeout: 10, + EnvVars: []dataprovider.KeyValue{ + { + Key: "SFTPGO_OBJECT_DATA", + Value: "{{.ObjectData}}", + }, + }, + }, + }, + } + a2 := dataprovider.BaseEventAction{ + Name: "a2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test3@example.com"}, + Subject: `New "{{.Event}}" from "{{.Name}}"`, + Body: "Object name: {{.ObjectName}} object type: {{.ObjectType}} Data: {{.ObjectData}}", + }, + }, + } + + a3 := dataprovider.BaseEventAction{ + Name: "a3", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"failure@example.com"}, + Subject: `Failed "{{.Event}}" from "{{.Name}}"`, + Body: "Object name: {{.ObjectName}} object type: {{.ObjectType}}, IP: {{.IP}}", + }, + }, + } + _, _, err = httpdtest.AddEventAction(a1, http.StatusBadRequest) + assert.NoError(t, err) + // Enable the command to allow saving + dataprovider.EnabledActionCommands = []string{a1.Options.CmdConfig.Cmd} + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) + assert.NoError(t, err) + + r := dataprovider.EventRule{ + Name: "rule", + Status: 1, + Trigger: dataprovider.EventTriggerProviderEvent, + Conditions: dataprovider.EventConditions{ + ProviderEvents: []string{"add"}, + Options: 
dataprovider.ConditionOptions{ + ProviderObjects: []string{"folder"}, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + StopOnFailure: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 3, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + StopOnFailure: true, + }, + }, + }, + } + rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) + assert.NoError(t, err) + // restrict command execution + dataprovider.EnabledActionCommands = nil + + lastReceivedEmail.reset() + // create a folder to trigger the rule + folder := vfs.BaseVirtualFolder{ + Name: "ftest failed command", + MappedPath: filepath.Join(os.TempDir(), "p"), + } + folder, _, err = httpdtest.AddFolder(folder, http.StatusCreated) + assert.NoError(t, err) + + assert.NoFileExists(t, outPath) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "failure@example.com")) + assert.Contains(t, email.Data, `Subject: Failed "add" from "admin"`) + assert.Contains(t, email.Data, fmt.Sprintf("Object name: %s object type: folder", folder.Name)) + lastReceivedEmail.reset() + + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + + _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) + assert.NoError(t, err) +} + +func 
TestEventRuleProviderEvents(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + saveObjectScriptPath := filepath.Join(os.TempDir(), "provider.sh") + outPath := filepath.Join(os.TempDir(), "provider_out.json") + err = os.WriteFile(saveObjectScriptPath, getSaveProviderObjectScriptContent(outPath, 0), 0755) + assert.NoError(t, err) + + dataprovider.EnabledActionCommands = []string{saveObjectScriptPath} + defer func() { + dataprovider.EnabledActionCommands = nil + }() + + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeCommand, + Options: dataprovider.BaseEventActionOptions{ + CmdConfig: dataprovider.EventActionCommandConfig{ + Cmd: saveObjectScriptPath, + Timeout: 10, + EnvVars: []dataprovider.KeyValue{ + { + Key: "SFTPGO_OBJECT_DATA", + Value: "{{.ObjectData}}", + }, + }, + }, + }, + } + a2 := dataprovider.BaseEventAction{ + Name: "a2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test3@example.com"}, + Subject: `New "{{.Event}}" from "{{.Name}}"`, + Body: "Object name: {{.ObjectName}} object type: {{.ObjectType}} Data: {{.ObjectData}}", + }, + }, + } + + a3 := dataprovider.BaseEventAction{ + Name: "a3", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"failure@example.com"}, + Subject: `Failed "{{.Event}}" from "{{.Name}}"`, + Body: "Object name: {{.ObjectName}} object type: {{.ObjectType}}, IP: {{.IP}}", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + action2, _, err := 
httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) + assert.NoError(t, err) + + r := dataprovider.EventRule{ + Name: "rule", + Status: 1, + Trigger: dataprovider.EventTriggerProviderEvent, + Conditions: dataprovider.EventConditions{ + ProviderEvents: []string{"update"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + StopOnFailure: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 3, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + StopOnFailure: true, + }, + }, + }, + } + rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) + assert.NoError(t, err) + + lastReceivedEmail.reset() + // create and update a folder to trigger the rule + folder := vfs.BaseVirtualFolder{ + Name: "ftest rule", + MappedPath: filepath.Join(os.TempDir(), "p"), + } + folder, _, err = httpdtest.AddFolder(folder, http.StatusCreated) + assert.NoError(t, err) + // no action is triggered on add + assert.NoFileExists(t, outPath) + // update the folder + _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) + assert.NoError(t, err) + if assert.Eventually(t, func() bool { + _, err := os.Stat(outPath) + return err == nil + }, 2*time.Second, 100*time.Millisecond) { + content, err := os.ReadFile(outPath) + assert.NoError(t, err) + var folderGet vfs.BaseVirtualFolder + err = json.Unmarshal(content, &folderGet) + assert.NoError(t, err) + assert.Equal(t, folder, folderGet) + err = os.Remove(outPath) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + 
assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test3@example.com")) + assert.Contains(t, email.Data, `Subject: New "update" from "admin"`) + } + // now delete the script to generate an error + lastReceivedEmail.reset() + err = os.Remove(saveObjectScriptPath) + assert.NoError(t, err) + _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) + assert.NoError(t, err) + assert.NoFileExists(t, outPath) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "failure@example.com")) + assert.Contains(t, email.Data, `Subject: Failed "update" from "admin"`) + assert.Contains(t, email.Data, fmt.Sprintf("Object name: %s object type: folder", folder.Name)) + lastReceivedEmail.reset() + // generate an error for the failure action + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) + assert.NoError(t, err) + assert.NoFileExists(t, outPath) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 0) + + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + + _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) + assert.NoError(t, err) +} + +func TestEventRuleFsActions(t *testing.T) { + dirsToCreate := []string{ + "/basedir/1", + "/basedir/sub/2", + "/basedir/3", + } + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: 
dataprovider.FilesystemActionMkdirs, + MkDirs: dirsToCreate, + }, + }, + } + a2 := dataprovider.BaseEventAction{ + Name: "a2", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionRename, + Renames: []dataprovider.RenameConfig{ + { + KeyValue: dataprovider.KeyValue{ + Key: "/{{.VirtualDirPath}}/{{.ObjectName}}", + Value: "/{{.ObjectName}}_renamed", + }, + }, + }, + }, + }, + } + a3 := dataprovider.BaseEventAction{ + Name: "a3", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionDelete, + Deletes: []string{"/{{.ObjectName}}_renamed"}, + }, + }, + } + a4 := dataprovider.BaseEventAction{ + Name: "a4", + Type: dataprovider.ActionTypeFolderQuotaReset, + } + a5 := dataprovider.BaseEventAction{ + Name: "a5", + Type: dataprovider.ActionTypeUserQuotaReset, + } + a6 := dataprovider.BaseEventAction{ + Name: "a6", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionExist, + Exist: []string{"/{{.VirtualPath}}"}, + }, + }, + } + action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + action2, resp, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err, string(resp)) + action3, resp, err := httpdtest.AddEventAction(a3, http.StatusCreated) + assert.NoError(t, err, string(resp)) + action4, resp, err := httpdtest.AddEventAction(a4, http.StatusCreated) + assert.NoError(t, err, string(resp)) + action5, resp, err := httpdtest.AddEventAction(a5, http.StatusCreated) + assert.NoError(t, err, string(resp)) + action6, resp, err := httpdtest.AddEventAction(a6, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + r1 := 
dataprovider.EventRule{ + Name: "r1", + Status: 1, + Trigger: dataprovider.EventTriggerProviderEvent, + Conditions: dataprovider.EventConditions{ + ProviderEvents: []string{"add"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + }, + } + r2 := dataprovider.EventRule{ + Name: "r2", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action5.Name, + }, + Order: 2, + }, + }, + } + r3 := dataprovider.EventRule{ + Name: "r3", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"mkdir"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action6.Name, + }, + Order: 2, + }, + }, + } + r4 := dataprovider.EventRule{ + Name: "r4", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"rmdir"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action4.Name, + }, + Order: 1, + }, + }, + } + r5 := dataprovider.EventRule{ + Name: "r5", + Status: 1, + Trigger: dataprovider.EventTriggerProviderEvent, + Conditions: dataprovider.EventConditions{ + ProviderEvents: []string{"add"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action4.Name, + }, + Order: 1, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + 
assert.NoError(t, err) + rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) + assert.NoError(t, err) + rule3, _, err := httpdtest.AddEventRule(r3, http.StatusCreated) + assert.NoError(t, err) + rule4, _, err := httpdtest.AddEventRule(r4, http.StatusCreated) + assert.NoError(t, err) + rule5, _, err := httpdtest.AddEventRule(r5, http.StatusCreated) + assert.NoError(t, err) + + folderMappedPath := filepath.Join(os.TempDir(), "folder") + err = os.MkdirAll(folderMappedPath, os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(folderMappedPath, "file.txt"), []byte("1"), 0666) + assert.NoError(t, err) + + folder, _, err := httpdtest.AddFolder(vfs.BaseVirtualFolder{ + Name: "test folder", + MappedPath: folderMappedPath, + }, http.StatusCreated) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + folderGet, _, err := httpdtest.GetFolderByName(folder.Name, http.StatusOK) + if err != nil { + return false + } + return folderGet.UsedQuotaFiles == 1 && folderGet.UsedQuotaSize == 1 + }, 2*time.Second, 100*time.Millisecond) + + u := getTestUser() + u.Filters.DisableFsChecks = true + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + // check initial directories creation + for _, dir := range dirsToCreate { + assert.Eventually(t, func() bool { + _, err := client.Stat(dir) + return err == nil + }, 2*time.Second, 100*time.Millisecond) + } + // upload a file and check the sync rename + size := int64(32768) + err = writeSFTPFileNoCheck(path.Join("basedir", testFileName), size, client) + assert.NoError(t, err) + _, err = client.Stat(path.Join("basedir", testFileName)) + assert.Error(t, err) + info, err := client.Stat(testFileName + "_renamed") //nolint:goconst + if assert.NoError(t, err) { + assert.Equal(t, size, info.Size()) + } + assert.NoError(t, err) + assert.Eventually(t, func() bool { + 
userGet, _, err := httpdtest.GetUserByUsername(user.Username, http.StatusOK) + if err != nil { + return false + } + return userGet.UsedQuotaFiles == 1 && userGet.UsedQuotaSize == size + }, 2*time.Second, 100*time.Millisecond) + + for i := 0; i < 2; i++ { + err = client.Mkdir(testFileName) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + _, err = client.Stat(testFileName + "_renamed") + return err != nil + }, 2*time.Second, 100*time.Millisecond) + err = client.RemoveDirectory(testFileName) + assert.NoError(t, err) + } + err = client.Mkdir(testFileName + "_renamed") + assert.NoError(t, err) + err = client.Mkdir(testFileName) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + _, err = client.Stat(testFileName + "_renamed") + return err != nil + }, 2*time.Second, 100*time.Millisecond) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(folderMappedPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule3, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule4, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule5, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action4, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action5, http.StatusOK) + assert.NoError(t, err) + _, err = 
httpdtest.RemoveEventAction(action6, http.StatusOK)
	assert.NoError(t, err)
}

// TestEventActionObjectBaseName verifies the {{.ObjectBaseName}} placeholder
// in a filesystem rename action: a file uploaded inside a subdirectory is
// synchronously moved to the root, renamed to its base name (file name
// without extension).
func TestEventActionObjectBaseName(t *testing.T) {
	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionRename,
				Renames: []dataprovider.RenameConfig{
					{
						// move the just-uploaded object to the root using
						// its extension-less name as the target
						KeyValue: dataprovider.KeyValue{
							Key:   "/{{.VirtualDirPath}}/{{.ObjectName}}",
							Value: "/{{.ObjectBaseName}}",
						},
					},
				},
			},
		},
	}
	action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err, string(resp))

	r1 := dataprovider.EventRule{
		Name:    "r2",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				// synchronous so the rename is done before the upload returns
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		testDir := "test dir name"
		err = client.Mkdir(testDir)
		fileSize := int64(32768)
		assert.NoError(t, err)
		err = writeSFTPFileNoCheck(path.Join(testDir, testFileName), fileSize, client)
		assert.NoError(t, err)

		// the uploaded file was moved away from the subdirectory ...
		_, err = client.Stat(path.Join(testDir, testFileName))
		assert.ErrorIs(t, err, os.ErrNotExist)

		// ... and now lives at the root under its base name
		_, err = client.Stat(strings.TrimSuffix(testFileName, path.Ext(testFileName)))
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
}

func TestUploadEventRule(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)

	a1 := dataprovider.BaseEventAction{
		Name: "action1",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"test1@example.com"},
				Subject:    `New "{{.Event}}" from "{{.Name}}" status {{.StatusString}}`,
				Body:       "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}} Data: {{.ObjectData}} {{.ErrorString}}",
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)

	r1 := dataprovider.EventRule{
		Name:    "test rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
			Options: dataprovider.ConditionOptions{
				FsPaths: []dataprovider.ConditionPattern{
					{
						Pattern:      "/**/*.filepart",
						InverseMatch: true,
					},
				},
			},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		lastReceivedEmail.reset()
		err = writeSFTPFileNoCheck("/test.filepart", 32768, client)
		assert.NoError(t, err)
		email := lastReceivedEmail.get()
assert.Empty(t, email.From) + + lastReceivedEmail.reset() + err = writeSFTPFileNoCheck(testFileName, 32768, client) + assert.NoError(t, err) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.Contains(t, email.Data, `Subject: New "upload"`) + } + + r2 := dataprovider.EventRule{ + Name: "test rule2", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"rename"}, + Options: dataprovider.ConditionOptions{ + FsPaths: []dataprovider.ConditionPattern{ + { + Pattern: "/**/*.filepart", + }, + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + }, + } + rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) + assert.NoError(t, err) + + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + tempName := "file.filepart" + lastReceivedEmail.reset() + err = writeSFTPFileNoCheck(tempName, 32768, client) + assert.NoError(t, err) + email := lastReceivedEmail.get() + assert.Empty(t, email.From) + + lastReceivedEmail.reset() + err = client.Rename(tempName, testFileName) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.Contains(t, email.Data, `Subject: New "rename"`) + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + 
require.NoError(t, err) +} + +func TestEventRulePreDelete(t *testing.T) { + movePath := "recycle bin" + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionRename, + Renames: []dataprovider.RenameConfig{ + { + KeyValue: dataprovider.KeyValue{ + Key: "/{{.VirtualPath}}", + Value: fmt.Sprintf("/%s/{{.VirtualPath}}", movePath), + }, + UpdateModTime: true, + }, + }, + }, + }, + } + action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + r1 := dataprovider.EventRule{ + Name: "rule1", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"pre-delete"}, + Options: dataprovider.ConditionOptions{ + FsPaths: []dataprovider.ConditionPattern{ + { + Pattern: fmt.Sprintf("/%s/**", movePath), + InverseMatch: true, + }, + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + f := vfs.BaseVirtualFolder{ + Name: movePath, + MappedPath: filepath.Join(os.TempDir(), movePath), + } + _, _, err = httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + u := getTestUser() + u.QuotaFiles = 1000 + u.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: movePath, + }, + VirtualPath: "/" + movePath, + QuotaFiles: 1000, + }, + } + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaFiles = 1000 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + for _, user := range 
[]dataprovider.User{localUser, sftpUser} { + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + testDir := "sub dir" + err = client.MkdirAll(testDir) + assert.NoError(t, err) + err = writeSFTPFile(testFileName, 100, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, testFileName), 100, client) + assert.NoError(t, err) + modTime := time.Now().Add(-36 * time.Hour) + err = client.Chtimes(testFileName, modTime, modTime) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.NoError(t, err) + err = client.Remove(path.Join(testDir, testFileName)) + assert.NoError(t, err) + // check files + _, err = client.Stat(testFileName) + assert.ErrorIs(t, err, os.ErrNotExist) + _, err = client.Stat(path.Join(testDir, testFileName)) + assert.ErrorIs(t, err, os.ErrNotExist) + info, err := client.Stat(path.Join("/", movePath, testFileName)) + assert.NoError(t, err) + diff := math.Abs(time.Until(info.ModTime()).Seconds()) + assert.LessOrEqual(t, diff, float64(2)) + + _, err = client.Stat(path.Join("/", movePath, testDir, testFileName)) + assert.NoError(t, err) + // check quota + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if user.Username == localUser.Username { + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + folder, _, err := httpdtest.GetFolderByName(movePath, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, folder.UsedQuotaFiles) + assert.Equal(t, int64(200), folder.UsedQuotaSize) + } else { + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(100), user.UsedQuotaSize) + } + // pre-delete action is not executed in movePath + err = client.Remove(path.Join("/", movePath, testFileName)) + assert.NoError(t, err) + if user.Username == localUser.Username { + // check quota + folder, _, err := httpdtest.GetFolderByName(movePath, http.StatusOK) + assert.NoError(t, 
err)
			assert.Equal(t, 1, folder.UsedQuotaFiles)
			assert.Equal(t, int64(100), folder.UsedQuotaSize)
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
		}
	}
}

	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: movePath}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(filepath.Join(os.TempDir(), movePath))
	assert.NoError(t, err)
}

// TestEventRulePreDownloadUpload verifies synchronous rules on the
// "pre-download"/"pre-upload" events: while the bound action succeeds
// transfers proceed (and the action's mkdir side effect is visible),
// a disabled rule has no effect, and a failing action (rename of a
// missing path) denies both uploads and downloads with a permission
// error.
func TestEventRulePreDownloadUpload(t *testing.T) {
	testDir := "/d"
	// action that always succeeds: create testDir
	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type:   dataprovider.FilesystemActionMkdirs,
				MkDirs: []string{testDir},
			},
		},
	}
	action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	// action that always fails: rename a path that does not exist
	a2 := dataprovider.BaseEventAction{
		Name: "a2",
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionRename,
				Renames: []dataprovider.RenameConfig{
					{
						KeyValue: dataprovider.KeyValue{
							Key:   "/missing source",
							Value: "/missing target",
						},
					},
				},
			},
		},
	}
	action2, resp, err := httpdtest.AddEventAction(a2, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	r1 := dataprovider.EventRule{
		Name:    "rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"pre-download", "pre-upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				// synchronous: the transfer outcome depends on the action result
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// the rule will always succeed, so uploads/downloads will work
		err = writeSFTPFile(testFileName, 100, client)
		assert.NoError(t, err)
		_, err = client.Stat(testDir)
		assert.NoError(t, err)
		err = client.RemoveDirectory(testDir)
		assert.NoError(t, err)
		// downloading re-triggers pre-download, which recreates testDir
		f, err := client.Open(testFileName)
		assert.NoError(t, err)
		contents := make([]byte, 100)
		n, err := io.ReadFull(f, contents)
		assert.NoError(t, err)
		assert.Equal(t, int(100), n)
		err = f.Close()
		assert.NoError(t, err)
		// disable the rule
		rule1.Status = 0
		_, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK)
		assert.NoError(t, err)
		err = client.RemoveDirectory(testDir)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		// with the rule disabled the upload works but no mkdir happens
		err = writeSFTPFile(testFileName, 100, client)
		assert.NoError(t, err)
		_, err = client.Stat(testDir)
		assert.ErrorIs(t, err, fs.ErrNotExist)
		// now update the rule so that it will always fail
		rule1.Status = 1
		rule1.Actions = []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action2.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		}
		_, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK)
		assert.NoError(t, err)
		// the failing sync action now denies downloads ...
		_, err = client.Open(testFileName)
		assert.ErrorIs(t, err, os.ErrPermission)
		// ... Remove is unaffected (no pre-delete condition on this rule) ...
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		// ... and uploads are denied too
		err = writeSFTPFile(testFileName, 100, client)
		assert.ErrorIs(t, err, os.ErrPermission)
	}

	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestEventActionCommandEnvVars(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	envName := "MY_ENV"
	uploadScriptPath := filepath.Join(os.TempDir(), "upload.sh")

	dataprovider.EnabledActionCommands = []string{uploadScriptPath}
	defer func() {
		dataprovider.EnabledActionCommands = nil
	}()

	err := os.WriteFile(uploadScriptPath, getUploadScriptEnvContent(envName), 0755)
	assert.NoError(t, err)
	a1 := dataprovider.BaseEventAction{
		Name: "action1",
		Type: dataprovider.ActionTypeCommand,
		Options: dataprovider.BaseEventActionOptions{
			CmdConfig: dataprovider.EventActionCommandConfig{
				Cmd:     uploadScriptPath,
				Timeout: 10,
				EnvVars: []dataprovider.KeyValue{
					{
						Key:   envName,
						Value: "$",
					},
				},
			},
		},
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)

	r1 := dataprovider.EventRule{
		Name:    "test rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"upload"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
				Options: dataprovider.EventActionOptions{
					ExecuteSync: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
+ defer conn.Close() + defer client.Close() + + err = writeSFTPFileNoCheck(testFileName, 100, client) + assert.Error(t, err) + } + + os.Setenv(envName, "1") + defer os.Unsetenv(envName) + + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = writeSFTPFileNoCheck(testFileName, 100, client) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.Remove(uploadScriptPath) + assert.NoError(t, err) +} + +func TestFsActionCopy(t *testing.T) { + dirCopy := "/dircopy" + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionCopy, + Copy: []dataprovider.KeyValue{ + { + Key: "/{{.VirtualPath}}/", + Value: dirCopy + "/", + }, + }, + }, + }, + } + action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + r1 := dataprovider.EventRule{ + Name: "rule1", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + g1 := dataprovider.Group{ + BaseGroup: sdk.BaseGroup{ + Name: "group1", + }, + UserSettings: dataprovider.GroupUserSettings{ + BaseGroupUserSettings: sdk.BaseGroupUserSettings{ + Permissions: 
map[string][]string{ + // Restrict permissions in copyPath to check that action + // will have full permissions anyway. + dirCopy: {dataprovider.PermListItems, dataprovider.PermDelete}, + }, + }, + }, + } + group1, resp, err := httpdtest.AddGroup(g1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: group1.Name, + Type: sdk.GroupTypePrimary, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = writeSFTPFile(testFileName, 100, client) + assert.NoError(t, err) + _, err = client.Stat(path.Join(dirCopy, testFileName)) + assert.NoError(t, err) + + action1.Options.FsConfig.Copy = []dataprovider.KeyValue{ + { + Key: "/missing path", + Value: "/copied path", + }, + } + _, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK) + assert.NoError(t, err) + // copy a missing path will fail + err = writeSFTPFile(testFileName, 100, client) + assert.Error(t, err) + } + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group1, http.StatusOK) + assert.NoError(t, err) +} + +func TestEventFsActionsGroupFilters(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: 
[]string{"example@example.net"}, + Subject: `New "{{.Event}}" from "{{.Name}}" status {{.StatusString}}`, + Body: "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}} {{.ErrorString}}", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "rule1", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + Options: dataprovider.ConditionOptions{ + GroupNames: []dataprovider.ConditionPattern{ + { + Pattern: "group*", + }, + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + // the user has no group, so the rule does not match + lastReceivedEmail.reset() + err = writeSFTPFile(testFileName, 32, client) + assert.NoError(t, err) + assert.Empty(t, lastReceivedEmail.get().From) + } + g1 := dataprovider.Group{ + BaseGroup: sdk.BaseGroup{ + Name: "agroup1", + }, + } + group1, _, err := httpdtest.AddGroup(g1, http.StatusCreated) + assert.NoError(t, err) + + g2 := dataprovider.Group{ + BaseGroup: sdk.BaseGroup{ + Name: "group2", + }, + } + group2, _, err := httpdtest.AddGroup(g2, http.StatusCreated) + assert.NoError(t, err) + user.Groups = []sdk.GroupMapping{ + { + Name: group1.Name, + Type: sdk.GroupTypePrimary, + }, + } + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer 
client.Close() + // the group does not match + lastReceivedEmail.reset() + err = writeSFTPFile(testFileName, 32, client) + assert.NoError(t, err) + assert.Empty(t, lastReceivedEmail.get().From) + } + user.Groups = append(user.Groups, sdk.GroupMapping{ + Name: group2.Name, + Type: sdk.GroupTypeSecondary, + }) + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + // the group matches + lastReceivedEmail.reset() + err = writeSFTPFile(testFileName, 32, client) + assert.NoError(t, err) + assert.NotEmpty(t, lastReceivedEmail.get().From) + } + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group2, http.StatusOK) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventProviderActionGroupFilters(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"example@example.net"}, + Subject: `New "{{.Event}}" from "{{.Name}}"`, + Body: "IP: {{.IP}}", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "rule1", 
+ Status: 1, + Trigger: dataprovider.EventTriggerProviderEvent, + Conditions: dataprovider.EventConditions{ + ProviderEvents: []string{"add", "update"}, + Options: dataprovider.ConditionOptions{ + GroupNames: []dataprovider.ConditionPattern{ + { + Pattern: "group_*", + }, + }, + ProviderObjects: []string{"user"}, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + + g1 := dataprovider.Group{ + BaseGroup: sdk.BaseGroup{ + Name: "agroup_1", + }, + } + group1, _, err := httpdtest.AddGroup(g1, http.StatusCreated) + assert.NoError(t, err) + + g2 := dataprovider.Group{ + BaseGroup: sdk.BaseGroup{ + Name: "group_2", + }, + } + group2, _, err := httpdtest.AddGroup(g2, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: group2.Name, + Type: sdk.GroupTypePrimary, + }, + } + + lastReceivedEmail.reset() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + + user.Groups = []sdk.GroupMapping{ + { + Name: group1.Name, + Type: sdk.GroupTypePrimary, + }, + } + + lastReceivedEmail.reset() + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + time.Sleep(300 * time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 0) + + user.Groups = []sdk.GroupMapping{ + { + Name: group2.Name, + Type: sdk.GroupTypePrimary, + }, + } + + lastReceivedEmail.reset() + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 
100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group2, http.StatusOK) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestBackupAsAttachment(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "a1 with space", + Type: dataprovider.ActionTypeBackup, + } + a2 := dataprovider.BaseEventAction{ + Name: "a2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.com"}, + Subject: `"{{.Event}} {{.StatusString}}"`, + Body: "Domain: {{.Name}}", + Attachments: []string{"/{{.VirtualPath}}"}, + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "test rule certificate", + Status: 1, + Trigger: dataprovider.EventTriggerCertificate, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, 
http.StatusCreated) + assert.NoError(t, err) + + lastReceivedEmail.reset() + renewalEvent := "Certificate renewal" + + common.HandleCertificateEvent(common.EventParams{ + Name: "example.com", + Timestamp: time.Now(), + Status: 1, + Event: renewalEvent, + }) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, renewalEvent)) + assert.Contains(t, email.Data, `Domain: example.com`) + assert.Contains(t, email.Data, "Content-Type: application/json") + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventActionHTTPMultipart(t *testing.T) { + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeHTTP, + Options: dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: fmt.Sprintf("http://%s/multipart", httpAddr), + Method: http.MethodPut, + Parts: []dataprovider.HTTPPart{ + { + Name: "part1", + Headers: []dataprovider.KeyValue{ + { + Key: "Content-Type", + Value: "application/json", + }, + }, + Body: `{"FilePath": "{{.VirtualPath}}"}`, + }, + { + Name: "file", + Filepath: "/{{.VirtualPath}}", + }, + }, + }, + }, + } + action1, resp, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + r1 := dataprovider.EventRule{ + Name: "test http multipart", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: 
[]string{"upload"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + Order: 1, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + // now add an missing file to the http multipart action + action1.Options.HTTPConfig.Parts = append(action1.Options.HTTPConfig.Parts, dataprovider.HTTPPart{ + Name: "file1", + Filepath: "/missing", + }) + _, resp, err = httpdtest.UpdateEventAction(action1, http.StatusOK) + assert.NoError(t, err, string(resp)) + + f, err = client.Create("testfile.txt") + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.Error(t, err) + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestEventActionCompress(t *testing.T) { + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionCompress, + Compress: dataprovider.EventActionFsCompress{ + Name: "/{{.VirtualPath}}.zip", + Paths: []string{"/{{.VirtualPath}}"}, + }, + }, + }, + } + action1, _, err := 
httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "test compress", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + u.QuotaFiles = 1000 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.FsConfig.SFTPConfig.BufferSize = 1 + u.QuotaFiles = 1000 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getCryptFsUser() + u.QuotaFiles = 1000 + cryptFsUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser, cryptFsUser} { + // cleanup home dir + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + rule1.Conditions.Options.Names = []dataprovider.ConditionPattern{ + { + Pattern: user.Username, + }, + } + _, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + expectedQuotaSize := int64(len(testFileContent)) + expectedQuotaFiles := 1 + if user.Username == cryptFsUser.Username { + encryptedFileSize, err := getEncryptedFileSize(expectedQuotaSize) + assert.NoError(t, err) + expectedQuotaSize = encryptedFileSize + } + + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + info, err := client.Stat(testFileName + ".zip") //nolint:goconst + 
if assert.NoError(t, err) { + assert.Greater(t, info.Size(), int64(0)) + // check quota + archiveSize := info.Size() + if user.Username == cryptFsUser.Username { + encryptedFileSize, err := getEncryptedFileSize(archiveSize) + assert.NoError(t, err) + archiveSize = encryptedFileSize + } + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles+1, user.UsedQuotaFiles, + "quota file does no match for user %q", user.Username) + assert.Equal(t, expectedQuotaSize+archiveSize, user.UsedQuotaSize, + "quota size does no match for user %q", user.Username) + } + // now overwrite the same file + f, err = client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + info, err = client.Stat(testFileName + ".zip") + if assert.NoError(t, err) { + assert.Greater(t, info.Size(), int64(0)) + archiveSize := info.Size() + if user.Username == cryptFsUser.Username { + encryptedFileSize, err := getEncryptedFileSize(archiveSize) + assert.NoError(t, err) + archiveSize = encryptedFileSize + } + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles+1, user.UsedQuotaFiles, + "quota file after overwrite does no match for user %q", user.Username) + assert.Equal(t, expectedQuotaSize+archiveSize, user.UsedQuotaSize, + "quota size after overwrite does no match for user %q", user.Username) + } + } + if user.Username == localUser.Username { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + } + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + 
err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(cryptFsUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(cryptFsUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestEventActionCompressQuotaErrors(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notify@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + testDir := "archiveDir" + zipPath := "/archive.zip" + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionCompress, + Compress: dataprovider.EventActionFsCompress{ + Name: zipPath, + Paths: []string{"/" + testDir}, + }, + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + a2 := dataprovider.BaseEventAction{ + Name: "action2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.com"}, + Subject: `"Compress failed"`, + Body: "Error: {{.ErrorString}}", + }, + }, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "test compress", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"rename"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + Order: 2, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + 
assert.NoError(t, err) + fileSize := int64(100) + u := getTestUser() + u.QuotaSize = 10 * fileSize + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = client.MkdirAll(path.Join(testDir, "1", "1")) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, "1", testFileName), fileSize, client) + assert.NoError(t, err) + err = client.MkdirAll(path.Join(testDir, "2", "2")) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, "2", testFileName), fileSize, client) + assert.NoError(t, err) + err = client.Symlink(path.Join(testDir, "2", testFileName), path.Join(testDir, "2", testFileName+"_link")) + assert.NoError(t, err) + // trigger the compress action + err = client.Mkdir("a") + assert.NoError(t, err) + err = client.Rename("a", "b") + assert.NoError(t, err) + assert.Eventually(t, func() bool { + _, err := client.Stat(zipPath) + return err == nil + }, 3*time.Second, 100*time.Millisecond) + err = client.Remove(zipPath) + assert.NoError(t, err) + // add other 6 file, the compress action should fail with a quota error + err = writeSFTPFile(path.Join(testDir, "1", "1", testFileName), fileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, "2", "2", testFileName), fileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, "1", "1", testFileName+"1"), fileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, "2", "2", testFileName+"2"), fileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, "1", testFileName+"1"), fileSize, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(testDir, "2", testFileName+"2"), fileSize, client) + assert.NoError(t, err) + lastReceivedEmail.reset() + err = client.Rename("b", "a") + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return 
lastReceivedEmail.get().From != "" + }, 3*time.Second, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, `Subject: "Compress failed"`) + assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error()) + // update quota size so the user is already overquota + user.QuotaSize = 7 * fileSize + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + lastReceivedEmail.reset() + err = client.Rename("a", "b") + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3*time.Second, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, `Subject: "Compress failed"`) + assert.Contains(t, email.Data, common.ErrQuotaExceeded.Error()) + // remove the path to compress to trigger an error for size estimation + out, err := runSSHCommand(fmt.Sprintf("sftpgo-remove %s", testDir), user) + assert.NoError(t, err, string(out)) + lastReceivedEmail.reset() + err = client.Rename("b", "a") + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3*time.Second, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, `Subject: "Compress failed"`) + assert.Contains(t, email.Data, "unable to estimate archive size") + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = 
os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventActionCompressQuotaFolder(t *testing.T) { + testDir := "/folder" + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionCompress, + Compress: dataprovider.EventActionFsCompress{ + Name: "/{{.VirtualPath}}.zip", + Paths: []string{"/{{.VirtualPath}}", testDir}, + }, + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "test compress", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + u := getTestUser() + u.QuotaFiles = 1000 + mappedPath := filepath.Join(os.TempDir(), "virtualpath") + folderName := filepath.Base(mappedPath) + vdirPath := "/virtualpath" + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err = httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + QuotaSize: -1, + QuotaFiles: -1, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = 
client.Mkdir(testDir) + assert.NoError(t, err) + expectedQuotaSize := int64(len(testFileContent)) + expectedQuotaFiles := 1 + err = client.Symlink(path.Join(testDir, testFileName), path.Join(testDir, testFileName+"_link")) + assert.NoError(t, err) + f, err := client.Create(path.Join(testDir, testFileName)) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + info, err := client.Stat(path.Join(testDir, testFileName) + ".zip") + if assert.NoError(t, err) { + assert.Greater(t, info.Size(), int64(0)) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + expectedQuotaFiles++ + expectedQuotaSize += info.Size() + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + } + vfolder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, vfolder.UsedQuotaFiles) + assert.Equal(t, int64(0), vfolder.UsedQuotaSize) + // upload in the virtual path + f, err = client.Create(path.Join(vdirPath, testFileName)) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + info, err = client.Stat(path.Join(vdirPath, testFileName) + ".zip") + if assert.NoError(t, err) { + assert.Greater(t, info.Size(), int64(0)) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + expectedQuotaFiles += 2 + expectedQuotaSize += info.Size() + int64(len(testFileContent)) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + vfolder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, vfolder.UsedQuotaFiles) + assert.Equal(t, int64(0), vfolder.UsedQuotaSize) + } + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + 
assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestEventActionCompressErrors(t *testing.T) { + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionCompress, + Compress: dataprovider.EventActionFsCompress{ + Name: "/{{.VirtualPath}}.zip", + Paths: []string{"/{{.VirtualPath}}.zip"}, // cannot compress itself + }, + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "test compress", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.Error(t, err) + } + // try to compress a missing file + action1.Options.FsConfig.Compress.Paths = []string{"/missing file"} + _, 
_, err = httpdtest.UpdateEventAction(action1, http.StatusOK) + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.Error(t, err) + } + // try to overwrite a directory + testDir := "/adir" + action1.Options.FsConfig.Compress.Name = testDir + action1.Options.FsConfig.Compress.Paths = []string{"/{{.VirtualPath}}"} + _, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK) + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = client.Mkdir(testDir) + assert.NoError(t, err) + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.Error(t, err) + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestEventActionEmailAttachments(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notify@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionCompress, + Compress: dataprovider.EventActionFsCompress{ + Name: "/archive/{{.VirtualPath}}.zip", + Paths: []string{"/{{.VirtualPath}}"}, + }, + }, + }, + } + action1, _, err := 
httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + a2 := dataprovider.BaseEventAction{ + Name: "action2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.com"}, + Subject: `"{{.Event}}" from "{{.Name}}"`, + Body: "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}} {{.EscapedVirtualPath}}", + Attachments: []string{"/archive/{{.VirtualPath}}.zip"}, + }, + }, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "test email with attachment", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + localUser, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + u := getTestSFTPUser() + u.FsConfig.SFTPConfig.BufferSize = 1 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + cryptFsUser, _, err := httpdtest.AddUser(getCryptFsUser(), http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser, cryptFsUser} { + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + lastReceivedEmail.reset() + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return 
lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, `Subject: "upload" from`) + assert.Contains(t, email.Data, url.QueryEscape("/"+testFileName)) + assert.Contains(t, email.Data, "Content-Disposition: attachment") + } + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(cryptFsUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(cryptFsUser.GetHomeDir()) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventActionsRetentionReports(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notify@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + testDir := "/d" + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeDataRetentionCheck, + Options: dataprovider.BaseEventActionOptions{ + RetentionConfig: dataprovider.EventActionDataRetentionConfig{ + Folders: []dataprovider.FolderRetention{ + { + Path: testDir, + Retention: 1, + DeleteEmptyDirs: true, + }, + }, + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + a2 := dataprovider.BaseEventAction{ + Name: "action2", + Type: dataprovider.ActionTypeEmail, + Options: 
dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.com"}, + Subject: `"{{.Event}}" from "{{.Name}}"`, + Body: "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}}", + Attachments: []string{dataprovider.RetentionReportPlaceHolder}, + }, + }, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + a3 := dataprovider.BaseEventAction{ + Name: "action3", + Type: dataprovider.ActionTypeHTTP, + Options: dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: fmt.Sprintf("http://%s/", httpAddr), + Timeout: 20, + Method: http.MethodPost, + Body: dataprovider.RetentionReportPlaceHolder, + }, + }, + } + action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) + assert.NoError(t, err) + a4 := dataprovider.BaseEventAction{ + Name: "action4", + Type: dataprovider.ActionTypeHTTP, + Options: dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: fmt.Sprintf("http://%s/multipart", httpAddr), + Timeout: 20, + Method: http.MethodPost, + Parts: []dataprovider.HTTPPart{ + { + Name: "reports.zip", + Filepath: dataprovider.RetentionReportPlaceHolder, + }, + }, + }, + }, + } + action4, _, err := httpdtest.AddEventAction(a4, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "test rule1", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + StopOnFailure: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + 
StopOnFailure: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 3, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + StopOnFailure: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action4.Name, + }, + Order: 4, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + StopOnFailure: true, + }, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + subdir := path.Join(testDir, "sub") + err = client.MkdirAll(subdir) + assert.NoError(t, err) + + lastReceivedEmail.reset() + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "upload" from "%s"`, user.Username)) + assert.Contains(t, email.Data, "Content-Disposition: attachment") + _, err = client.Stat(testDir) + assert.NoError(t, err) + _, err = client.Stat(subdir) + assert.ErrorIs(t, err, os.ErrNotExist) + + err = client.Mkdir(subdir) + assert.NoError(t, err) + newName := path.Join(testDir, testFileName) + err = client.Rename(testFileName, newName) + assert.NoError(t, err) + err = client.Chtimes(newName, time.Now().Add(-24*time.Hour), time.Now().Add(-24*time.Hour)) + assert.NoError(t, err) + + lastReceivedEmail.reset() + f, err = client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + _, 
err = client.Stat(subdir) + assert.ErrorIs(t, err, os.ErrNotExist) + _, err = client.Stat(subdir) + assert.ErrorIs(t, err, os.ErrNotExist) + } + // now remove the retention check to test errors + rule1.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + StopOnFailure: false, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 3, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + StopOnFailure: false, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action4.Name, + }, + Order: 4, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + StopOnFailure: false, + }, + }, + } + _, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + f, err := client.Create(testFileName) + assert.NoError(t, err) + _, err = f.Write(testFileContent) + assert.NoError(t, err) + err = f.Close() + assert.Error(t, err) + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action4, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventRuleFirstUploadDownloadActions(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, 
+ From: "notify@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.com"}, + Subject: `"{{.Event}}" from "{{.Name}}"`, + Body: "Fs path {{.FsPath}}, size: {{.FileSize}}, protocol: {{.Protocol}}, IP: {{.IP}}", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "test first upload rule", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"first-upload"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + r2 := dataprovider.EventRule{ + Name: "test first download rule", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"first-download"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + }, + } + rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + testFileSize := int64(32768) + lastReceivedEmail.reset() + err = writeSFTPFileNoCheck(testFileName, testFileSize, client) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 
100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "first-upload" from "%s"`, user.Username)) + lastReceivedEmail.reset() + // a new upload will not produce a new notification + err = writeSFTPFileNoCheck(testFileName+"_1", 32768, client) + assert.NoError(t, err) + assert.Never(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1000*time.Millisecond, 100*time.Millisecond) + // the same for download + f, err := client.Open(testFileName) + assert.NoError(t, err) + contents := make([]byte, testFileSize) + n, err := io.ReadFull(f, contents) + assert.NoError(t, err) + assert.Equal(t, int(testFileSize), n) + err = f.Close() + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "first-download" from "%s"`, user.Username)) + // download again + lastReceivedEmail.reset() + f, err = client.Open(testFileName) + assert.NoError(t, err) + contents = make([]byte, testFileSize) + n, err = io.ReadFull(f, contents) + assert.NoError(t, err) + assert.Equal(t, int(testFileSize), n) + err = f.Close() + assert.NoError(t, err) + assert.Never(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1000*time.Millisecond, 100*time.Millisecond) + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, 
err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventRuleRenameEvent(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notify@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.com"}, + Subject: `"{{.Event}}" from "{{.Name}}"`, + ContentType: 1, + Body: `

Fs path {{.FsPath}}, Name: {{.Name}}, Target path "{{.VirtualTargetDirPath}}/{{.TargetName}}", size: {{.FileSize}}

`, + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "test rename rule", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"rename"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + u.Username = "test & chars" + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + testFileSize := int64(32768) + lastReceivedEmail.reset() + err = writeSFTPFileNoCheck(testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Mkdir("subdir") + assert.NoError(t, err) + err = client.Rename(testFileName, path.Join("/subdir", testFileName)) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "rename" from "%s"`, user.Username)) + assert.Contains(t, email.Data, "Content-Type: text/html") + assert.Contains(t, email.Data, fmt.Sprintf("Target path %q", path.Join("/subdir", testFileName))) + assert.Contains(t, email.Data, "Name: test & chars,") + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + 
smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventRuleIDPLogin(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notify@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + lastReceivedEmail.reset() + + username := `test_'idp_'login` + custom1 := `cust"oa"1` + u := map[string]any{ + "username": "{{.Name}}", + "status": 1, + "home_dir": filepath.Join(os.TempDir(), "{{.IDPFieldcustom1}}"), + "permissions": map[string][]string{ + "/": {dataprovider.PermAny}, + }, + } + userTmpl, err := json.Marshal(u) + require.NoError(t, err) + a := map[string]any{ + "username": "{{.Name}}", + "status": 1, + "permissions": []string{dataprovider.PermAdminAny}, + } + adminTmpl, err := json.Marshal(a) + require.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeIDPAccountCheck, + Options: dataprovider.BaseEventActionOptions{ + IDPConfig: dataprovider.EventActionIDPAccountCheck{ + Mode: 1, // create if not exists + TemplateUser: string(userTmpl), + TemplateAdmin: string(adminTmpl), + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + a2 := dataprovider.BaseEventAction{ + Name: "a2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.com"}, + Subject: `"{{.Event}} {{.StatusString}}"`, + Body: "{{.Name}} Custom field: {{.IDPFieldcustom1}}", + }, + }, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "test rule IDP login", + Status: 1, + Trigger: dataprovider.EventTriggerIDPLogin, + Conditions: dataprovider.EventConditions{ + IDPLoginEvent: dataprovider.IDPLoginUser, + }, + Actions: []dataprovider.EventAction{ 
+ { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, // the rule is not sync and will be skipped + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + }, + } + rule1, resp, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + customFields := map[string]any{ + "custom1": custom1, + } + user, admin, err := common.HandleIDPLoginEvent(common.EventParams{ + Name: username, + Event: common.IDPLoginUser, + Status: 1, + }, &customFields) + assert.Nil(t, user) + assert.Nil(t, admin) + assert.NoError(t, err) + + rule1.Actions[0].Options.ExecuteSync = true + rule1, resp, err = httpdtest.UpdateEventRule(rule1, http.StatusOK) + assert.NoError(t, err, string(resp)) + user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ + Name: username, + Event: common.IDPLoginUser, + Status: 1, + }, &customFields) + if assert.NotNil(t, user) { + assert.Equal(t, filepath.Join(os.TempDir(), custom1), user.GetHomeDir()) + _, err = httpdtest.RemoveUser(*user, http.StatusOK) + assert.NoError(t, err) + } + assert.Nil(t, admin) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, common.IDPLoginUser)) + assert.Contains(t, email.Data, username) + assert.Contains(t, email.Data, custom1) + + user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ + Name: username, + Event: common.IDPLoginAdmin, + Status: 1, + }, &customFields) + assert.Nil(t, user) + assert.Nil(t, admin) + assert.NoError(t, err) + + rule1.Conditions.IDPLoginEvent = dataprovider.IDPLoginAny + rule1.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: 
action1.Name, + }, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + Order: 1, + }, + } + rule1, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + + r2 := dataprovider.EventRule{ + Name: "test email on IDP login", + Status: 1, + Trigger: dataprovider.EventTriggerIDPLogin, + Conditions: dataprovider.EventConditions{ + IDPLoginEvent: dataprovider.IDPLoginAdmin, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 1, + }, + }, + } + rule2, resp, err := httpdtest.AddEventRule(r2, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + lastReceivedEmail.reset() + user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ + Name: username, + Event: common.IDPLoginAdmin, + Status: 1, + }, &customFields) + assert.Nil(t, user) + if assert.NotNil(t, admin) { + assert.Equal(t, 1, admin.Status) + } + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, common.IDPLoginAdmin)) + assert.Contains(t, email.Data, username) + assert.Contains(t, email.Data, custom1) + admin.Status = 0 + _, _, err = httpdtest.UpdateAdmin(*admin, http.StatusOK) + assert.NoError(t, err) + user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ + Name: username, + Event: common.IDPLoginAdmin, + Status: 1, + }, &customFields) + assert.Nil(t, user) + if assert.NotNil(t, admin) { + assert.Equal(t, 0, admin.Status) + } + assert.NoError(t, err) + action1.Options.IDPConfig.Mode = 0 + action1, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK) + assert.NoError(t, err) + user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ + Name: username, + 
Event: common.IDPLoginAdmin, + Status: 1, + }, &customFields) + assert.Nil(t, user) + if assert.NotNil(t, admin) { + assert.Equal(t, 1, admin.Status) + } + assert.NoError(t, err) + _, err = httpdtest.RemoveAdmin(*admin, http.StatusOK) + assert.NoError(t, err) + + r3 := dataprovider.EventRule{ + Name: "test rule2 IDP login", + Status: 1, + Trigger: dataprovider.EventTriggerIDPLogin, + Conditions: dataprovider.EventConditions{ + IDPLoginEvent: dataprovider.IDPLoginAny, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + }, + } + rule3, resp, err := httpdtest.AddEventRule(r3, http.StatusCreated) + assert.NoError(t, err, string(resp)) + user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ + Name: username, + Event: common.IDPLoginAdmin, + Status: 1, + }, &customFields) + assert.Nil(t, user) + assert.Nil(t, admin) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "more than one account check action rules matches") + } + + _, err = httpdtest.RemoveEventRule(rule3, http.StatusOK) + assert.NoError(t, err) + + action1.Options.IDPConfig.TemplateAdmin = `{}` + action1, _, err = httpdtest.UpdateEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, _, err = common.HandleIDPLoginEvent(common.EventParams{ + Name: username, + Event: common.IDPLoginAdmin, + Status: 1, + }, &customFields) + assert.ErrorIs(t, err, util.ErrValidation) + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + + user, admin, err = common.HandleIDPLoginEvent(common.EventParams{ + Name: username, + Event: common.IDPLoginAdmin, + Status: 1, + }, &customFields) + assert.Nil(t, user) + assert.Nil(t, admin) + assert.NoError(t, err) + + _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + 
assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventRuleEmailField(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notify@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + lastReceivedEmail.reset() + + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"{{.Email}}"}, + Subject: `"{{.Event}}" from "{{.Name}}"`, + Body: "Sample email body", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + a2 := dataprovider.BaseEventAction{ + Name: "action2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"failure@example.com"}, + Subject: `"Failure`, + Body: "{{.ErrorString}}", + }, + }, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + r1 := dataprovider.EventRule{ + Name: "r1", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"mkdir"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + }, + }, + } + r2 := dataprovider.EventRule{ + Name: "test rule2", + Status: 1, + Trigger: dataprovider.EventTriggerProviderEvent, + Conditions: dataprovider.EventConditions{ + ProviderEvents: []string{"add"}, + Options: dataprovider.ConditionOptions{ + ProviderObjects: []string{"user"}, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + 
Name: action1.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) + assert.NoError(t, err) + u := getTestUser() + u.Email = "user@example.com" + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, user.Email)) + assert.Contains(t, email.Data, `Subject: "add" from "admin"`) + + // if we add a user without email the notification will fail + lastReceivedEmail.reset() + u1 := getTestUser() + u1.Username += "_1" + user1, _, err := httpdtest.AddUser(u1, http.StatusCreated) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "failure@example.com")) + assert.Contains(t, email.Data, `no recipient addresses set`) + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + lastReceivedEmail.reset() + err = client.Mkdir(testFileName) + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, user.Email)) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "mkdir" from "%s"`, user.Username)) + } + + _, err = httpdtest.RemoveEventRule(rule1, 
http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventRuleCertificate(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notify@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + lastReceivedEmail.reset() + + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.com"}, + Subject: `"{{.Event}} {{.StatusString}}"`, + ContentType: 0, + Body: "Domain: {{.Name}} Timestamp: {{.Timestamp}} {{.ErrorString}} Date time: {{.DateTime}}", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + + a2 := dataprovider.BaseEventAction{ + Name: "action2", + Type: dataprovider.ActionTypeFolderQuotaReset, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "test rule certificate", + Status: 1, + Trigger: dataprovider.EventTriggerCertificate, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + }, + } + rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + r2 := 
dataprovider.EventRule{ + Name: "test rule 2", + Status: 1, + Trigger: dataprovider.EventTriggerCertificate, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + }, + } + rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) + assert.NoError(t, err) + + renewalEvent := "Certificate renewal" + + common.HandleCertificateEvent(common.EventParams{ + Name: "example.com", + Timestamp: time.Now(), + Status: 1, + Event: renewalEvent, + }) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s OK"`, renewalEvent)) + assert.Contains(t, email.Data, "Content-Type: text/plain") + assert.Contains(t, email.Data, `Domain: example.com Timestamp`) + + lastReceivedEmail.reset() + dateTime := time.Now() + params := common.EventParams{ + Name: "example.com", + Timestamp: dateTime, + Status: 2, + Event: renewalEvent, + } + errRenew := errors.New("generic renew error") + params.AddError(errRenew) + common.HandleCertificateEvent(params) + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email = lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.True(t, slices.Contains(email.To, "test@example.com")) + assert.Contains(t, email.Data, fmt.Sprintf(`Subject: "%s KO"`, renewalEvent)) + assert.Contains(t, email.Data, `Domain: example.com Timestamp`) + assert.Contains(t, email.Data, dateTime.UTC().Format("2006-01-02T15:04:05.000")) + assert.Contains(t, email.Data, errRenew.Error()) + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + 
assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + // ignored no more certificate rules + common.HandleCertificateEvent(common.EventParams{ + Name: "example.com", + Timestamp: time.Now(), + Status: 1, + Event: renewalEvent, + }) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventRuleIPBlocked(t *testing.T) { + oldConfig := config.GetCommonConfig() + + cfg := config.GetCommonConfig() + cfg.DefenderConfig.Enabled = true + cfg.DefenderConfig.Threshold = 3 + cfg.DefenderConfig.ScoreLimitExceeded = 2 + + err := common.Initialize(cfg, 0) + assert.NoError(t, err) + + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test3@example.com", "test4@example.com"}, + Subject: `New "{{.Event}}"`, + Body: "IP: {{.IP}} Timestamp: {{.Timestamp}}", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + + a2 := dataprovider.BaseEventAction{ + Name: "action2", + Type: dataprovider.ActionTypeFolderQuotaReset, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "test rule ip blocked", + Status: 1, + Trigger: dataprovider.EventTriggerIPBlocked, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + }, + } + 
rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + r2 := dataprovider.EventRule{ + Name: "test rule 2", + Status: 1, + Trigger: dataprovider.EventTriggerIPBlocked, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + }, + } + rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + lastReceivedEmail.reset() + time.Sleep(300 * time.Millisecond) + assert.Empty(t, lastReceivedEmail.get().From, lastReceivedEmail.get().Data) + + for i := 0; i < 3; i++ { + user.Password = "wrong_pwd" + _, _, err = getSftpClient(user) + assert.Error(t, err) + } + // the client is now banned + user.Password = defaultPassword + _, _, err = getSftpClient(user) + assert.Error(t, err) + // check the email notification + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 3000*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 2) + assert.True(t, slices.Contains(email.To, "test3@example.com")) + assert.True(t, slices.Contains(email.To, "test4@example.com")) + assert.Contains(t, email.Data, `Subject: New "IP Blocked"`) + + err = dataprovider.DeleteEventRule(rule1.Name, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteEventRule(rule2.Name, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteEventAction(action1.Name, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteEventAction(action2.Name, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = 
smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + err = common.Initialize(oldConfig, 0) + assert.NoError(t, err) +} + +func TestEventRuleRotateLog(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeRotateLogs, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + a2 := dataprovider.BaseEventAction{ + Name: "a2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"success@example.net"}, + Subject: `OK`, + Body: "OK action", + }, + }, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "rule1", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"mkdir"}, + Options: dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user.Username, + }, + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + }, + } + rule1, resp, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + lastReceivedEmail.reset() + err := client.Mkdir("just a test dir") + assert.NoError(t, err) + // just check that the action is executed + 
assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.Contains(t, email.To, "success@example.net") + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventRuleInactivityCheck(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: dataprovider.ActionTypeUserInactivityCheck, + Options: dataprovider.BaseEventActionOptions{ + UserInactivityConfig: dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + DeleteThreshold: 20, + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + a2 := dataprovider.BaseEventAction{ + Name: "a2", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"success@example.net"}, + Subject: `OK`, + Body: "OK action", + }, + }, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "rule1", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: 
dataprovider.EventConditions{ + FsEvents: []string{"mkdir"}, + Options: dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: user.Username, + }, + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 2, + }, + }, + } + rule1, resp, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + lastReceivedEmail.reset() + err := client.Mkdir("just a test dir") + assert.NoError(t, err) + // just check that the action is executed + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.Contains(t, email.To, "success@example.net") + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestEventRulePasswordExpiration(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 2525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + a1 := dataprovider.BaseEventAction{ + Name: "a1", + Type: 
dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"failure@example.net"}, + Subject: `Failure`, + Body: "Failure action", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + a2 := dataprovider.BaseEventAction{ + Name: "a2", + Type: dataprovider.ActionTypePasswordExpirationCheck, + Options: dataprovider.BaseEventActionOptions{ + PwdExpirationConfig: dataprovider.EventActionPasswordExpiration{ + Threshold: 10, + }, + }, + } + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + a3 := dataprovider.BaseEventAction{ + Name: "a3", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"success@example.net"}, + Subject: `OK`, + Body: "OK action", + }, + }, + } + action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "rule1", + Status: 1, + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"mkdir"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 2, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + }, + }, + } + rule1, resp, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + dirName := "aTestDir" + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + lastReceivedEmail.reset() + err := client.Mkdir(dirName) + assert.NoError(t, err) + // 
the user has no password expiration, the check will be skipped and the ok action executed + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.Contains(t, email.To, "success@example.net") + err = client.RemoveDirectory(dirName) + assert.NoError(t, err) + } + user.Filters.PasswordExpiration = 20 + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + lastReceivedEmail.reset() + err := client.Mkdir(dirName) + assert.NoError(t, err) + // the passowrd is not about to expire, the check will be skipped and the ok action executed + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.Contains(t, email.To, "success@example.net") + err = client.RemoveDirectory(dirName) + assert.NoError(t, err) + } + user.Filters.PasswordExpiration = 5 + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + lastReceivedEmail.reset() + err := client.Mkdir(dirName) + assert.NoError(t, err) + // the passowrd is about to expire, the user has no email, the failure action will be executed + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 1) + assert.Contains(t, email.To, "failure@example.net") + err = client.RemoveDirectory(dirName) + assert.NoError(t, err) + } + // remove the success action + rule1.Actions = []dataprovider.EventAction{ + { + BaseEventAction: 
dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + }, + } + _, _, err = httpdtest.UpdateEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + user.Email = "user@example.net" + user.Filters.AdditionalEmails = []string{"additional@example.net"} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + lastReceivedEmail.reset() + err := client.Mkdir(dirName) + assert.NoError(t, err) + // the passowrd expiration will be notified + assert.Eventually(t, func() bool { + return lastReceivedEmail.get().From != "" + }, 1500*time.Millisecond, 100*time.Millisecond) + email := lastReceivedEmail.get() + assert.Len(t, email.To, 2) + assert.Contains(t, email.To, user.Email) + assert.Contains(t, email.To, user.Filters.AdditionalEmails[0]) + assert.Contains(t, email.Data, "your SFTPGo password expires in 5 days") + err = client.RemoveDirectory(dirName) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestSyncUploadAction(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + uploadScriptPath := filepath.Join(os.TempDir(), "upload.sh") + 
common.Config.Actions.ExecuteOn = []string{"upload"} + common.Config.Actions.ExecuteSync = []string{"upload"} + common.Config.Actions.Hook = uploadScriptPath + + u := getTestUser() + u.QuotaFiles = 1000 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + movedFileName := "moved.dat" + movedPath := filepath.Join(user.HomeDir, movedFileName) + err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, "", 0), 0755) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + size := int64(32768) + err = writeSFTPFileNoCheck(testFileName, size, client) + assert.NoError(t, err) + _, err = client.Stat(testFileName) + assert.Error(t, err) + info, err := client.Stat(movedFileName) + if assert.NoError(t, err) { + assert.Equal(t, size, info.Size()) + } + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, size, user.UsedQuotaSize) + // test some hook failure + // the uploaded file is moved and the hook fails, it will be not removed from the quota + err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, "", 1), 0755) + assert.NoError(t, err) + err = writeSFTPFileNoCheck(testFileName+"_1", size, client) + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, size*2, user.UsedQuotaSize) + + // the uploaded file is not moved and the hook fails, the uploaded file will be deleted + // and removed from the quota + movedPath = filepath.Join(user.HomeDir, "missing dir", movedFileName) + err = os.WriteFile(uploadScriptPath, getUploadScriptContent(movedPath, "", 1), 0755) + assert.NoError(t, err) + err = writeSFTPFileNoCheck(testFileName+"_2", size, client) + assert.Error(t, err) + user, _, err = 
httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, size*2, user.UsedQuotaSize) + // overwrite an existing file + _, err = client.Stat(movedFileName) + assert.NoError(t, err) + err = writeSFTPFileNoCheck(movedFileName, size, client) + assert.Error(t, err) + _, err = client.Stat(movedFileName) + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, size, user.UsedQuotaSize) + } + + err = os.Remove(uploadScriptPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.Actions.ExecuteOn = nil + common.Config.Actions.ExecuteSync = nil + common.Config.Actions.Hook = uploadScriptPath +} + +func TestQuotaTrackDisabled(t *testing.T) { + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.TrackQuota = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = writeSFTPFile(testFileName, 32, client) + assert.NoError(t, err) + err = client.Rename(testFileName, testFileName+"1") + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, 
configDir, true) + assert.NoError(t, err) +} + +func TestGetQuotaError(t *testing.T) { + if dataprovider.GetProviderStatus().Driver == "memory" { + t.Skip("this test is not available with the memory provider") + } + u := getTestUser() + u.TotalDataTransfer = 2000 + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + vdirPath := "/vpath" + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + QuotaSize: 0, + QuotaFiles: 10, + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = writeSFTPFile(testFileName, 32, client) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + + err = client.Rename(testFileName, path.Join(vdirPath, testFileName)) + assert.Error(t, err) + + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestRetentionAPI(t *testing.T) { + u := getTestUser() + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, + dataprovider.PermOverwrite, dataprovider.PermDownload, dataprovider.PermCreateDirs, + dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, 
http.StatusCreated) + assert.NoError(t, err) + uploadPath := path.Join(testDir, testFileName) + + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = writeSFTPFile(uploadPath, 32, client) + assert.NoError(t, err) + + folderRetention := []dataprovider.FolderRetention{ + { + Path: "/", + Retention: 24, + DeleteEmptyDirs: true, + }, + } + check := common.RetentionCheck{ + Folders: folderRetention, + } + c := common.RetentionChecks.Add(check, &user) + assert.NotNil(t, c) + err = c.Start() + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + _, err = client.Stat(uploadPath) + assert.NoError(t, err) + + err = client.Chtimes(uploadPath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) + assert.NoError(t, err) + + err = c.Start() + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + _, err = client.Stat(uploadPath) + assert.ErrorIs(t, err, os.ErrNotExist) + + _, err = client.Stat(testDir) + assert.ErrorIs(t, err, os.ErrNotExist) + + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = writeSFTPFile(uploadPath, 32, client) + assert.NoError(t, err) + + check.Folders[0].DeleteEmptyDirs = false + err = client.Chtimes(uploadPath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) + assert.NoError(t, err) + + c = common.RetentionChecks.Add(check, &user) + assert.NotNil(t, c) + err = c.Start() + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + _, err = client.Stat(uploadPath) + assert.ErrorIs(t, err, os.ErrNotExist) + + _, err = client.Stat(testDir) + assert.NoError(t, err) + + err = 
writeSFTPFile(uploadPath, 32, client) + assert.NoError(t, err) + err = client.Chtimes(uploadPath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) + assert.NoError(t, err) + conn.Close() + client.Close() + } + + // remove delete permissions to the user, it will be automatically granted + user.Permissions["/"+testDir] = []string{dataprovider.PermListItems, dataprovider.PermUpload, + dataprovider.PermCreateDirs, dataprovider.PermChtimes} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + conn, client, err = getSftpClient(user) + if assert.NoError(t, err) { + innerUploadFilePath := path.Join("/"+testDir, testDir, testFileName) + err = client.Mkdir(path.Join(testDir, testDir)) + assert.NoError(t, err) + + err = writeSFTPFile(innerUploadFilePath, 32, client) + assert.NoError(t, err) + err = client.Chtimes(innerUploadFilePath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) + assert.NoError(t, err) + + folderRetention := []dataprovider.FolderRetention{ + { + Path: "/missing", + Retention: 24, + }, + { + Path: "/" + testDir, + Retention: 24, + DeleteEmptyDirs: true, + }, + { + Path: path.Dir(innerUploadFilePath), + Retention: 0, + }, + } + check := common.RetentionCheck{ + Folders: folderRetention, + } + c := common.RetentionChecks.Add(check, &user) + assert.NotNil(t, c) + err = c.Start() + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + _, err = client.Stat(uploadPath) + assert.ErrorIs(t, err, os.ErrNotExist) + _, err = client.Stat(innerUploadFilePath) + assert.NoError(t, err) + + folderRetention = []dataprovider.FolderRetention{ + + { + Path: "/" + testDir, + Retention: 24, + DeleteEmptyDirs: true, + }, + } + + check = common.RetentionCheck{ + Folders: folderRetention, + } + c = common.RetentionChecks.Add(check, &user) + assert.NotNil(t, c) + err = c.Start() + assert.NoError(t, err) + + 
assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + _, err = client.Stat(innerUploadFilePath) + assert.ErrorIs(t, err, os.ErrNotExist) + conn.Close() + client.Close() + } + // finally test some errors removing files or folders + if runtime.GOOS != osWindows { + dirPath := filepath.Join(user.HomeDir, "adir", "sub") + err := os.MkdirAll(dirPath, os.ModePerm) + assert.NoError(t, err) + filePath := filepath.Join(dirPath, "f.dat") + err = os.WriteFile(filePath, nil, os.ModePerm) + assert.NoError(t, err) + + err = os.Chtimes(filePath, time.Now().Add(-72*time.Hour), time.Now().Add(-72*time.Hour)) + assert.NoError(t, err) + + err = os.Chmod(dirPath, 0001) + assert.NoError(t, err) + + folderRetention := []dataprovider.FolderRetention{ + + { + Path: "/adir", + Retention: 24, + DeleteEmptyDirs: true, + }, + } + + check := common.RetentionCheck{ + Folders: folderRetention, + } + c := common.RetentionChecks.Add(check, &user) + assert.NotNil(t, c) + err = c.Start() + assert.ErrorIs(t, err, os.ErrPermission) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + err = os.Chmod(dirPath, 0555) + assert.NoError(t, err) + + c = common.RetentionChecks.Add(check, &user) + assert.NotNil(t, c) + err = c.Start() + assert.ErrorIs(t, err, os.ErrPermission) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + err = os.Chmod(dirPath, os.ModePerm) + assert.NoError(t, err) + + check = common.RetentionCheck{ + Folders: folderRetention, + } + c = common.RetentionChecks.Add(check, &user) + assert.NotNil(t, c) + err = c.Start() + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + assert.NoDirExists(t, dirPath) + } + + _, err = 
httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1*time.Second, 50*time.Millisecond) +} + +func TestPerUserTransferLimits(t *testing.T) { + oldMaxPerHostConns := common.Config.MaxPerHostConnections + + common.Config.MaxPerHostConnections = 2 + + u := getTestUser() + u.UploadBandwidth = 32 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + if !assert.NoError(t, err) { + printLatestLogs(20) + } + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + var wg sync.WaitGroup + numErrors := 0 + for i := 0; i <= 2; i++ { + wg.Add(1) + go func(counter int) { + defer wg.Done() + + time.Sleep(20 * time.Millisecond) + err := writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client) + if err != nil { + numErrors++ + } + }(i) + } + wg.Wait() + + assert.Equal(t, 1, numErrors) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxPerHostConnections = oldMaxPerHostConns +} + +func TestMaxSessionsSameConnection(t *testing.T) { + u := getTestUser() + u.UploadBandwidth = 32 + u.MaxSessions = 2 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + var wg sync.WaitGroup + numErrors := 0 + for i := 0; i <= 2; i++ { + wg.Add(1) + go func(counter int) { + defer wg.Done() + + var err error + if counter < 2 { + err = writeSFTPFile(fmt.Sprintf("%s_%d", testFileName, counter), 64*1024, client) + } else { + // wait for the transfers to start + time.Sleep(50 * time.Millisecond) + _, _, err = getSftpClient(user) + } + if err != nil { + numErrors++ + } + }(i) + } + + 
		wg.Wait()
		assert.Equal(t, 1, numErrors)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestRenameDir verifies that renaming a directory fails with a permission
// error when the user lacks rename permission on that path.
func TestRenameDir(t *testing.T) {
	u := getTestUser()
	testDir := "/dir-to-rename"
	// only list + upload on testDir: rename must be denied
	u.Permissions[testDir] = []string{dataprovider.PermListItems, dataprovider.PermUpload}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, testFileName), 32, client)
		assert.NoError(t, err)
		err = client.Rename(testDir, testDir+"_rename")
		assert.ErrorIs(t, err, os.ErrPermission)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestBuiltinKeyboardInteractiveAuthentication checks the built-in
// keyboard-interactive authentication, first with password only and then
// with password + TOTP passcode after enabling MFA for SSH.
func TestBuiltinKeyboardInteractiveAuthentication(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	authMethods := []ssh.AuthMethod{
		ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) {
			return []string{defaultPassword}, nil
		}),
	}
	conn, client, err := getCustomAuthSftpClient(user, authMethods)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
	}
	// add multi-factor authentication
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH},
	}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	passcode, err := generateTOTPPasscode(key.Secret(), otp.AlgorithmSHA1)
	assert.NoError(t, err)
	passwordAsked := false
	passcodeAsked := false
	// answer the password or the passcode prompt depending on the question
	authMethods = []ssh.AuthMethod{
		ssh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
			var answers []string
			if strings.HasPrefix(questions[0], "Password") {
				answers = append(answers, defaultPassword)
				passwordAsked = true
			} else {
				answers = append(answers, passcode)
				passcodeAsked = true
			}
			return answers, nil
		}),
	}
	conn, client, err = getCustomAuthSftpClient(user, authMethods)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
	}
	assert.True(t, passwordAsked)
	assert.True(t, passcodeAsked)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestMultiStepBuiltinKeyboardAuth checks multi-step authentication:
// public key followed by keyboard-interactive (password, then TOTP passcode
// once MFA is enabled), with the single-method logins explicitly denied.
func TestMultiStepBuiltinKeyboardAuth(t *testing.T) {
	u := getTestUser()
	u.PublicKeys = []string{testPubKey}
	u.Filters.DeniedLoginMethods = []string{
		dataprovider.SSHLoginMethodPublicKey,
		dataprovider.LoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive,
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey))
	assert.NoError(t, err)
	// public key + password
	authMethods := []ssh.AuthMethod{
		ssh.PublicKeys(signer),
		ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) {
			return []string{defaultPassword}, nil
		}),
	}
	conn, client, err := getCustomAuthSftpClient(user, authMethods)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
	}
	// add multi-factor authentication
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	user.Password = defaultPassword
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH},
	}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	passcode, err := generateTOTPPasscode(key.Secret(), otp.AlgorithmSHA1)
	assert.NoError(t, err)
	// public key + passcode
	authMethods = []ssh.AuthMethod{
		ssh.PublicKeys(signer),
		ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) {
			return []string{passcode}, nil
		}),
	}
	conn, client, err = getCustomAuthSftpClient(user, authMethods)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestRenameSymlink verifies that a symlink cannot be renamed into a
// directory where creating links is not permitted, while renaming it to an
// allowed location succeeds.
func TestRenameSymlink(t *testing.T) {
	u := getTestUser()
	testDir := "/dir-no-create-links"
	otherDir := "otherdir"
	// testDir has no create-symlinks permission
	u.Permissions[testDir] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete,
		dataprovider.PermCreateDirs}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.Mkdir(otherDir)
		assert.NoError(t, err)
		err = client.Symlink(otherDir, otherDir+".link")
		assert.NoError(t, err)
		err = client.Rename(otherDir+".link", path.Join(testDir, "symlink"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(otherDir+".link", "allowed_link")
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSplittedDeletePerms checks that delete_dirs and delete_files are
// enforced independently of each other.
func TestSplittedDeletePerms(t *testing.T) {
	u := getTestUser()
	// delete_dirs only: removing files must fail
	u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDeleteDirs,
		dataprovider.PermCreateDirs}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.Error(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.RemoveDirectory(testDir)
		assert.NoError(t, err)
	}
	// delete_files only: removing directories must fail
	u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDeleteFiles,
		dataprovider.PermCreateDirs, dataprovider.PermOverwrite}
	_, _, err = httpdtest.UpdateUser(u, http.StatusOK, "")
	assert.NoError(t, err)

	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.RemoveDirectory(testDir)
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSplittedRenamePerms checks that rename_dirs and rename_files are
// enforced independently of each other.
func TestSplittedRenamePerms(t *testing.T) {
	u := getTestUser()
	// rename_dirs only: renaming files must fail
	u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermRenameDirs,
		dataprovider.PermCreateDirs}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Rename(testFileName, testFileName+"_renamed")
		assert.Error(t, err)
		err = client.Rename(testDir, testDir+"_renamed")
		assert.NoError(t, err)
	}
	// rename_files only: renaming directories must fail
	u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermRenameFiles,
		dataprovider.PermCreateDirs, dataprovider.PermOverwrite}
	_, _, err = httpdtest.UpdateUser(u, http.StatusOK, "")
	assert.NoError(t, err)

	conn, client, err = getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Rename(testFileName, testFileName+"_renamed")
		assert.NoError(t, err)
		err = client.Rename(testDir, testDir+"_renamed")
		assert.Error(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSFTPLoopError creates two users that reference each other via SFTP
// filesystems (a storage loop), verifies that the loop triggers a failure
// event (checked via the SMTP failure-action email) and that resolving a
// path inside the looping virtual folder fails for every protocol.
func TestSFTPLoopError(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          2525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	user1 := getTestUser()
	user2 := getTestUser()
	user1.Username += "1"
	user2.Username += "2"
	// user1 is a local account with a virtual SFTP folder to user2
	// user2 has user1 as SFTP fs
	f := vfs.BaseVirtualFolder{
		Name: "sftp",
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig: vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint: sftpServerAddr,
					Username: user2.Username,
				},
				Password: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	folder, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user1.VirtualFolders = append(user1.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folder.Name,
		},
		VirtualPath: "/vdir",
	})

	user2.FsConfig.Provider = sdk.SFTPFilesystemProvider
	user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{
		BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
			Endpoint: sftpServerAddr,
			Username: user1.Username,
		},
		Password: kms.NewPlainSecret(defaultPassword),
	}

	user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	// quota reset action: it will fail because of the storage loop
	a1 := dataprovider.BaseEventAction{
		Name: "a1",
		Type: dataprovider.ActionTypeUserQuotaReset,
	}
	action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated)
	assert.NoError(t, err)
	// email action executed on failure of the first one
	a2 := dataprovider.BaseEventAction{
		Name: "a2",
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"failure@example.com"},
				Subject:    `Failed action"`,
				Body:       "Test body",
			},
		},
	}
	action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated)
	assert.NoError(t, err)
	r1 := dataprovider.EventRule{
		Name:    "rule1",
		Status:  1,
		Trigger: dataprovider.EventTriggerProviderEvent,
		Conditions: dataprovider.EventConditions{
			ProviderEvents: []string{"update"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action1.Name,
				},
				Order: 1,
			},
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action2.Name,
				},
				Order: 2,
				Options: dataprovider.EventActionOptions{
					IsFailureAction: true,
				},
			},
		},
	}
	rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated)
	assert.NoError(t, err)

	// updating user2 triggers the rule; the quota reset fails on the loop
	// and the failure email must be sent
	lastReceivedEmail.reset()
	_, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "")
	assert.NoError(t, err)
	assert.Eventually(t, func() bool {
		return lastReceivedEmail.get().From != ""
	}, 3000*time.Millisecond, 100*time.Millisecond)
	email := lastReceivedEmail.get()
	assert.Len(t, email.To, 1)
	assert.True(t, slices.Contains(email.To, "failure@example.com"))
	assert.Contains(t, email.Data, `Subject: Failed action`)

	user1.VirtualFolders[0].FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
	user2.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)

	// resolving a path inside the looping folder must fail for any protocol
	conn := common.NewBaseConnection("", common.ProtocolWebDAV, "", "", user1)
	_, _, err = conn.GetFsAndResolvedPath(user1.VirtualFolders[0].VirtualPath)
	assert.ErrorIs(t, err, os.ErrPermission)

	conn = common.NewBaseConnection("", common.ProtocolSFTP, "", "", user1)
	_, _, err = conn.GetFsAndResolvedPath(user1.VirtualFolders[0].VirtualPath)
	assert.Error(t, err)
	conn = common.NewBaseConnection("", common.ProtocolFTP, "", "", user1)
	_, _, err = conn.GetFsAndResolvedPath(user1.VirtualFolders[0].VirtualPath)
	assert.Error(t, err)

	_, err = httpdtest.RemoveEventRule(rule1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user1.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user2, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user2.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)

	// restore the default (disabled) SMTP configuration
	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestNonLocalCrossRename checks that renames across different storage
// backends (local, SFTP folder, crypt folder) are rejected while renames
// inside a single backend succeed.
func TestNonLocalCrossRename(t *testing.T) {
	baseUser, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	u := getTestUser()
	u.HomeDir += "_folders"
	u.Username += "_folders"
	mappedPathSFTP :=
	filepath.Join(os.TempDir(), "sftp")
	folderNameSFTP := filepath.Base(mappedPathSFTP)
	vdirSFTPPath := "/vdir/sftp"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameSFTP,
		},
		VirtualPath: vdirSFTPPath,
	})
	mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
	folderNameCrypt := filepath.Base(mappedPathCrypt)
	vdirCryptPath := "/vdir/crypt"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameCrypt,
		},
		VirtualPath: vdirCryptPath,
	})
	// SFTP-backed virtual folder pointing to baseUser's account
	f1 := vfs.BaseVirtualFolder{
		Name: folderNameSFTP,
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig: vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint: sftpServerAddr,
					Username: baseUser.Username,
				},
				Password: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	// encrypted virtual folder
	f2 := vfs.BaseVirtualFolder{
		Name: folderNameCrypt,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
		MappedPath: mappedPathCrypt,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)

	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirSFTPPath, testFileName), 8192, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), 16384, client)
		assert.NoError(t, err)
		// cross-backend renames must be denied
		err = client.Rename(path.Join(vdirSFTPPath, testFileName), path.Join(vdirCryptPath,
			testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirSFTPPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testFileName, path.Join(vdirCryptPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testFileName, path.Join(vdirSFTPPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirSFTPPath, testFileName), testFileName+".rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), testFileName+".rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		// rename on local fs or on the same folder must work
		err = client.Rename(testFileName, testFileName+".rename")
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirSFTPPath, testFileName), path.Join(vdirSFTPPath, testFileName+"_rename"))
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirCryptPath, testFileName+"_rename"))
		assert.NoError(t, err)
		// renaming a virtual folder is not allowed
		err = client.Rename(vdirSFTPPath, vdirSFTPPath+"_rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(vdirCryptPath, vdirCryptPath+"_rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(vdirCryptPath, path.Join(vdirCryptPath, "rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Mkdir(path.Join(vdirCryptPath, "subcryptdir"))
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirCryptPath, "subcryptdir"), vdirCryptPath)
		assert.ErrorIs(t, err, os.ErrPermission)
		// renaming root folder is not allowed
		err = client.Rename("/", "new_name")
		assert.ErrorIs(t, err, os.ErrPermission)
		// renaming a path to a virtual folder is not allowed
		err = client.Rename("/vdir", "new_vdir")
		if assert.Error(t, err) {
			assert.Contains(t,
				err.Error(), "SSH_FX_OP_UNSUPPORTED")
		}
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameSFTP}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(baseUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathCrypt)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathSFTP)
	assert.NoError(t, err)
}

// TestNonLocalCrossRenameNonLocalBaseUser is the same cross-backend rename
// check as above, but with an SFTP-backed base user and local/crypt virtual
// folders.
func TestNonLocalCrossRenameNonLocalBaseUser(t *testing.T) {
	baseUser, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	u := getTestSFTPUser()
	mappedPathLocal := filepath.Join(os.TempDir(), "local")
	folderNameLocal := filepath.Base(mappedPathLocal)
	vdirLocalPath := "/vdir/local"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameLocal,
		},
		VirtualPath: vdirLocalPath,
	})
	mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
	folderNameCrypt := filepath.Base(mappedPathCrypt)
	vdirCryptPath := "/vdir/crypt"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameCrypt,
		},
		VirtualPath: vdirCryptPath,
	})
	f1 := vfs.BaseVirtualFolder{
		Name:       folderNameLocal,
		MappedPath: mappedPathLocal,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name: folderNameCrypt,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
		MappedPath: mappedPathCrypt,
	}
	_, _, err =
	httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)

	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
		err = writeSFTPFile(testFileName, 4096, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirLocalPath, testFileName), 8192, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), 16384, client)
		assert.NoError(t, err)
		// cross-backend renames must be denied
		err = client.Rename(path.Join(vdirLocalPath, testFileName), path.Join(vdirCryptPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirLocalPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testFileName, path.Join(vdirCryptPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(testFileName, path.Join(vdirLocalPath, testFileName+".rename"))
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirLocalPath, testFileName), testFileName+".rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), testFileName+".rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		// rename on local fs or on the same folder must work
		err = client.Rename(testFileName, testFileName+".rename")
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirLocalPath, testFileName), path.Join(vdirLocalPath, testFileName+"_rename"))
		assert.NoError(t, err)
		err = client.Rename(path.Join(vdirCryptPath, testFileName), path.Join(vdirCryptPath, testFileName+"_rename"))
		assert.NoError(t, err)
		// renaming a virtual folder is not allowed
		err = client.Rename(vdirLocalPath, vdirLocalPath+"_rename")
		assert.ErrorIs(t, err,
			os.ErrPermission)
		err = client.Rename(vdirCryptPath, vdirCryptPath+"_rename")
		assert.ErrorIs(t, err, os.ErrPermission)
		// renaming root folder is not allowed
		err = client.Rename("/", "new_name")
		assert.ErrorIs(t, err, os.ErrPermission)
		// renaming a path to a virtual folder is not allowed
		err = client.Rename("/vdir", "new_vdir")
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
		}
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameLocal}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(baseUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathCrypt)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathLocal)
	assert.NoError(t, err)
}

// TestCopyAndRemoveSSHCommands exercises the sftpgo-copy and sftpgo-remove
// SSH commands and verifies that the user quota (files and size) is kept
// consistent after each copy/remove.
func TestCopyAndRemoveSSHCommands(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 1000
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		fileSize := int64(32)
		err = writeSFTPFile(testFileName, fileSize, client)
		assert.NoError(t, err)

		testFileNameCopy := testFileName + "_copy"
		out, err := runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileNameCopy), user)
		assert.NoError(t, err, string(out))
		// the resolved destination path match the source path
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, path.Dir(testFileName)), user)
		assert.Error(t, err, string(out))

		info, err := client.Stat(testFileNameCopy)
		if assert.NoError(t, err) {
			assert.Equal(t, fileSize,
				info.Size())
		}

		testDir := "test dir"
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s '%s'`, testFileName, testDir), user)
		assert.NoError(t, err, string(out))
		info, err = client.Stat(path.Join(testDir, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, fileSize, info.Size())
		}

		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3*fileSize, user.UsedQuotaSize)
		assert.Equal(t, 3, user.UsedQuotaFiles)

		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %s", testFileNameCopy), user)
		assert.NoError(t, err, string(out))
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-remove '%s'`, testDir), user)
		assert.NoError(t, err, string(out))

		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, fileSize, user.UsedQuotaSize)
		assert.Equal(t, 1, user.UsedQuotaFiles)

		_, err = client.Stat(testFileNameCopy)
		assert.ErrorIs(t, err, os.ErrNotExist)
		// create a dir tree
		dir1 := "dir1"
		dir2 := "dir 2"
		err = client.MkdirAll(path.Join(dir1, dir2))
		assert.NoError(t, err)
		toCreate := []string{
			path.Join(dir1, testFileName),
			path.Join(dir1, dir2, testFileName),
		}
		for _, p := range toCreate {
			err = writeSFTPFile(p, fileSize, client)
			assert.NoError(t, err)
		}
		// create a symlink, copying a symlink is not supported
		err = client.Symlink(path.Join("/", dir1, testFileName), path.Join("/", dir1, testFileName+"_link"))
		assert.NoError(t, err)
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", dir1, testFileName+"_link"),
			path.Join("/", testFileName+"_link")), user)
		assert.Error(t, err, string(out))
		// copying a dir inside itself should fail
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", dir1),
			path.Join("/", dir1, "sub")), user)
		assert.Error(t, err, string(out))
		// copy source and dest must differ
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", dir1),
			path.Join("/", dir1)), user)
		assert.Error(t, err, string(out))
		// copy a missing file/dir should fail
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", "missing_entry"),
			path.Join("/", dir1)), user)
		assert.Error(t, err, string(out))
		// try to overwrite a file with a dir
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join("/", dir1), testFileName), user)
		assert.Error(t, err, string(out))

		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s "%s"`, dir1, dir2), user)
		assert.NoError(t, err, string(out))

		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 5*fileSize, user.UsedQuotaSize)
		assert.Equal(t, 5, user.UsedQuotaFiles)

		// copy again, quota must remain unchanged
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s/ "%s"`, dir1, dir2), user)
		assert.NoError(t, err, string(out))
		_, err = client.Stat(dir2)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 5*fileSize, user.UsedQuotaSize)
		assert.Equal(t, 5, user.UsedQuotaFiles)
		// now copy inside target
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s "%s"`, dir1, dir2), user)
		assert.NoError(t, err, string(out))
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 7*fileSize, user.UsedQuotaSize)
		assert.Equal(t, 7, user.UsedQuotaFiles)

		for _, p := range []string{dir1, dir2} {
			out, err = runSSHCommand(fmt.Sprintf(`sftpgo-remove "%s"`, p), user)
			assert.NoError(t, err, string(out))
			_, err = client.Stat(p)
			assert.ErrorIs(t, err, os.ErrNotExist)
		}
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, fileSize, user.UsedQuotaSize)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		// test quota errors
		user.QuotaFiles = 1
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		// quota files exceeded
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileNameCopy), user)
		assert.Error(t, err, string(out))
		user.QuotaFiles = 1000
		user.QuotaSize = fileSize + 1
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		// quota size exceeded after the copy
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileNameCopy), user)
		assert.Error(t, err, string(out))
		user.QuotaSize = fileSize - 1
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		// quota size exceeded
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileNameCopy), user)
		assert.Error(t, err, string(out))
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestCopyAndRemovePermissions checks that sftpgo-copy/sftpgo-remove honor
// per-directory permissions and file pattern filters.
func TestCopyAndRemovePermissions(t *testing.T) {
	u := getTestUser()
	restrictedPath := "/dir/path"
	patternFilterPath := "/patterns"
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           patternFilterPath,
			DeniedPatterns: []string{"*.dat"},
		},
	}
	// no permissions at all on restrictedPath
	u.Permissions[restrictedPath] = []string{}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		err = client.MkdirAll(restrictedPath)
		assert.NoError(t, err)
		err = client.MkdirAll(patternFilterPath)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName, 100, client)
		assert.NoError(t, err)
		// getting file writer will fail
		out, err := runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, restrictedPath), user)
		assert.Error(t, err, string(out))
		// file pattern not allowed
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, patternFilterPath), user)
		assert.Error(t, err, string(out))

		testDir := path.Join("/", path.Base(restrictedPath))
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, testFileName), 100, client)
		assert.NoError(t, err)
		// creating target dir will fail
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s/`, testDir, restrictedPath), user)
		assert.Error(t, err, string(out))
		// get dir contents will fail
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s /`, restrictedPath), user)
		assert.Error(t, err, string(out))
		// get dir contents will fail
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-remove %s`, restrictedPath), user)
		assert.Error(t, err, string(out))
		// give list dir permissions and retry, now delete will fail
		user.Permissions[restrictedPath] = []string{dataprovider.PermListItems, dataprovider.PermUpload}
		user.Permissions[testDir] = []string{dataprovider.PermListItems}
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		// no copy permission
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, restrictedPath), user)
		assert.Error(t, err, string(out))
		user.Permissions[restrictedPath] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermCopy}
		user.Permissions[testDir] = []string{dataprovider.PermListItems, dataprovider.PermCopy}
		_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, restrictedPath), user)
		assert.NoError(t, err, string(out))
		// overwrite will fail, no permission
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, testFileName, restrictedPath), user)
		assert.Error(t, err, string(out))
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-remove %s`, restrictedPath), user)
		assert.Error(t, err, string(out))
		// try to copy a file from testDir, we have only list permissions so getFileReader will fail
		out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, path.Join(testDir, testFileName), testFileName+".copy"), user)
		assert.Error(t, err, string(out))
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

func TestCrossFoldersCopy(t *testing.T) {
	baseUser, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err, string(resp))

	u := getTestUser()
	u.Username += "_1"
	u.HomeDir = filepath.Join(os.TempDir(), u.Username)
	u.QuotaFiles = 1000
	mappedPath1 := filepath.Join(os.TempDir(), "mapped1")
	folderName1 := filepath.Base(mappedPath1)
	vpath1 := "/vdirs/vdir1"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vpath1,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	mappedPath2 := filepath.Join(os.TempDir(), "mapped1", "dir", "mapped2")
	folderName2 := filepath.Base(mappedPath2)
	vpath2 := "/vdirs/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vpath2,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	mappedPath3 := filepath.Join(os.TempDir(), "mapped3")
	folderName3 := filepath.Base(mappedPath3)
	vpath3 := "/vdirs/vdir3"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName3,
		},
		VirtualPath: vpath3,
		QuotaSize:   -1,
		QuotaFiles:  -1,
	})
	mappedPath4 := filepath.Join(os.TempDir(), "mapped4")
	folderName4 := filepath.Base(mappedPath4)
	vpath4 := "/vdirs/vdir4"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName4,
		},
		VirtualPath: vpath4,
		QuotaSize:   -1,
		QuotaFiles:  -1,
+ }) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + f3 := vfs.BaseVirtualFolder{ + Name: folderName3, + MappedPath: mappedPath3, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + f4 := vfs.BaseVirtualFolder{ + Name: folderName4, + MappedPath: mappedPath4, + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: baseUser.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f4, http.StatusCreated) + assert.NoError(t, err) + + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + conn, client, err := getSftpClient(user) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + baseFileSize := int64(100) + err = writeSFTPFile(path.Join(vpath1, testFileName), baseFileSize+1, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vpath2, testFileName), baseFileSize+2, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vpath3, testFileName), baseFileSize+3, client) + assert.NoError(t, err) + err = writeSFTPFile(path.Join(vpath4, testFileName), baseFileSize+4, client) + assert.NoError(t, err) + // cannot remove a directory with virtual folders inside + out, err := runSSHCommand(fmt.Sprintf(`sftpgo-remove %s`, path.Dir(vpath1)), user) + assert.Error(t, err, string(out)) + // copy across virtual folders + copyDir := 
"/copy" + out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s/`, path.Dir(vpath1), copyDir), user) + assert.NoError(t, err, string(out)) + // check the copy + info, err := client.Stat(path.Join(copyDir, vpath1, testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, baseFileSize+1, info.Size()) + } + info, err = client.Stat(path.Join(copyDir, vpath2, testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, baseFileSize+2, info.Size()) + } + info, err = client.Stat(path.Join(copyDir, vpath3, testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, baseFileSize+3, info.Size()) + } + info, err = client.Stat(path.Join(copyDir, vpath4, testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, baseFileSize+4, info.Size()) + } + // nested fs paths + out, err = runSSHCommand(fmt.Sprintf(`sftpgo-copy %s %s`, vpath1, vpath2), user) + assert.Error(t, err, string(out)) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) + for _, folderName := range []string{folderName1, folderName2, folderName3, folderName4} { + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(filepath.Join(os.TempDir(), folderName)) + assert.NoError(t, err) + } +} + +func TestHTTPFs(t *testing.T) { + u := getTestUserWithHTTPFs() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + + conn := common.NewBaseConnection(xid.New().String(), common.ProtocolFTP, "", "", user) + err = conn.CreateDir(httpFsWellKnowDir, false) + assert.NoError(t, err) + + err = os.WriteFile(filepath.Join(os.TempDir(), "httpfs", defaultHTTPFsUsername, httpFsWellKnowDir, 
"file.txt"), []byte("data"), 0666) + assert.NoError(t, err) + + err = conn.Copy(httpFsWellKnowDir, httpFsWellKnowDir+"_copy") + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestProxyProtocol(t *testing.T) { + resp, err := httpclient.Get(fmt.Sprintf("http://%v", httpProxyAddr)) + if !assert.Error(t, err) { + resp.Body.Close() + } +} + +func TestSetProtocol(t *testing.T) { + conn := common.NewBaseConnection("id", "sshd_exec", "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}}) + conn.SetProtocol(common.ProtocolSCP) + require.Equal(t, "SCP_id", conn.GetID()) +} + +func TestGetFsError(t *testing.T) { + u := getTestUser() + u.FsConfig.Provider = sdk.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "test" + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") + conn := common.NewBaseConnection("", common.ProtocolFTP, "", "", u) + _, _, err := conn.GetFsAndResolvedPath("/vpath") + assert.Error(t, err) +} + +func waitTCPListening(address string) { + for { + conn, err := net.Dial("tcp", address) + if err != nil { + logger.WarnToConsole("tcp server %v not listening: %v", address, err) + time.Sleep(100 * time.Millisecond) + continue + } + logger.InfoToConsole("tcp server %v now listening", address) + conn.Close() + break + } +} + +func checkBasicSFTP(client *sftp.Client) error { + _, err := client.Getwd() + if err != nil { + return err + } + _, err = client.ReadDir(".") + return err +} + +func getCustomAuthSftpClient(user dataprovider.User, authMethods []ssh.AuthMethod) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Auth: authMethods, + Timeout: 5 * time.Second, + } + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return conn, sftpClient, 
err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + +func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if user.Password != "" { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return conn, sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + +func runSSHCommand(command string, user dataprovider.User) ([]byte, error) { + var sshSession *ssh.Session + var output []byte + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if user.Password != "" { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return output, err + } + defer conn.Close() + sshSession, err = conn.NewSession() + if err != nil { + return output, err + } + var stdout, stderr bytes.Buffer + sshSession.Stdout = &stdout + sshSession.Stderr = &stderr + err = sshSession.Run(command) + if err != nil { + return nil, fmt.Errorf("failed to run command %v: %v", command, stderr.Bytes()) + } + return stdout.Bytes(), err +} + +func getWebDavClient(user dataprovider.User) *gowebdav.Client { + rootPath := fmt.Sprintf("http://localhost:%d/", webDavServerPort) + pwd := defaultPassword + if user.Password != "" { + pwd = user.Password + } + client := gowebdav.NewClient(rootPath, user.Username, pwd) + client.SetTimeout(10 * time.Second) + return client +} + +func 
getTestUser() dataprovider.User { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: defaultUsername, + Password: defaultPassword, + HomeDir: filepath.Join(homeBasePath, defaultUsername), + Status: 1, + ExpirationDate: 0, + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = allPerms + return user +} + +func getTestSFTPUser() dataprovider.User { + u := getTestUser() + u.Username = defaultSFTPUsername + u.FsConfig.Provider = sdk.SFTPFilesystemProvider + u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr + u.FsConfig.SFTPConfig.Username = defaultUsername + u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) + return u +} + +func getCryptFsUser() dataprovider.User { + u := getTestUser() + u.Username += "_crypt" + u.FsConfig.Provider = sdk.CryptedFilesystemProvider + u.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(defaultPassword) + return u +} + +func getTestUserWithHTTPFs() dataprovider.User { + u := getTestUser() + u.FsConfig.Provider = sdk.HTTPFilesystemProvider + u.FsConfig.HTTPConfig = vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: fmt.Sprintf("http://127.0.0.1:%d/api/v1", httpFsPort), + Username: defaultHTTPFsUsername, + }, + } + return u +} + +func writeSFTPFile(name string, size int64, client *sftp.Client) error { + err := writeSFTPFileNoCheck(name, size, client) + if err != nil { + return err + } + info, err := client.Stat(name) + if err != nil { + return err + } + if info.Size() != size { + return fmt.Errorf("file size mismatch, wanted %v, actual %v", size, info.Size()) + } + return nil +} + +func writeSFTPFileNoCheck(name string, size int64, client *sftp.Client) error { + content := make([]byte, size) + _, err := rand.Read(content) + if err != nil { + return err + } + f, err := client.Create(name) + if err != nil { + return err + } + _, err = io.Copy(f, bytes.NewBuffer(content)) + if err != nil { + f.Close() + return err + } + return f.Close() +} + +func 
// getUploadScriptEnvContent returns a shell script that exits with status 0
// when the environment variable envVar is set and with status 1 when it is
// empty or unset.
func getUploadScriptEnvContent(envVar string) []byte {
	script := "#!/bin/sh\n\n" +
		fmt.Sprintf("if [ -z \"$%s\" ]\n", envVar) +
		"then\n" +
		" exit 1\n" +
		"else\n" +
		" exit 0\n" +
		"fi\n"
	return []byte(script)
}
+ return content +} + +func generateTOTPPasscode(secret string, algo otp.Algorithm) (string, error) { + return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: algo, + }) +} + +func isDbDefenderSupported() bool { + // SQLite shares the implementation with other SQL-based provider but it makes no sense + // to use it outside test cases + switch dataprovider.GetProviderStatus().Driver { + case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName, + dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName: + return true + default: + return false + } +} + +func getEncryptedFileSize(size int64) (int64, error) { + encSize, err := sio.EncryptedSize(uint64(size)) + return int64(encSize) + 33, err +} + +func printLatestLogs(maxNumberOfLines int) { + var lines []string + f, err := os.Open(logFilePath) + if err != nil { + return + } + defer f.Close() + scanner := bufio.NewScanner(f) + for scanner.Scan() { + lines = append(lines, scanner.Text()+"\r\n") + for len(lines) > maxNumberOfLines { + lines = lines[1:] + } + } + if scanner.Err() != nil { + logger.WarnToConsole("Unable to print latest logs: %v", scanner.Err()) + return + } + for _, line := range lines { + logger.DebugToConsole("%s", line) + } +} + +type receivedEmail struct { + sync.RWMutex + From string + To []string + Data string +} + +func (e *receivedEmail) set(from string, to []string, data []byte) { + e.Lock() + defer e.Unlock() + + e.From = from + e.To = to + e.Data = strings.ReplaceAll(string(data), "=\r\n", "") +} + +func (e *receivedEmail) reset() { + e.Lock() + defer e.Unlock() + + e.From = "" + e.To = nil + e.Data = "" +} + +func (e *receivedEmail) get() receivedEmail { + e.RLock() + defer e.RUnlock() + + return receivedEmail{ + From: e.From, + To: e.To, + Data: e.Data, + } +} + +func startHTTPFs() { + go func() { + readdirCallback := func(name string) []os.FileInfo { + if name == 
httpFsWellKnowDir { + return []os.FileInfo{vfs.NewFileInfo("ghost.txt", false, 0, time.Unix(0, 0), false)} + } + return nil + } + callbacks := &httpdtest.HTTPFsCallbacks{ + Readdir: readdirCallback, + } + if err := httpdtest.StartTestHTTPFs(httpFsPort, callbacks); err != nil { + logger.ErrorToConsole("could not start HTTPfs test server: %v", err) + os.Exit(1) + } + }() + waitTCPListening(fmt.Sprintf(":%d", httpFsPort)) +} diff --git a/internal/common/ratelimiter.go b/internal/common/ratelimiter.go new file mode 100644 index 00000000..85d88092 --- /dev/null +++ b/internal/common/ratelimiter.go @@ -0,0 +1,245 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "errors" + "fmt" + "slices" + "sort" + "sync" + "sync/atomic" + "time" + + "golang.org/x/time/rate" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + errNoBucket = errors.New("no bucket found") + errReserve = errors.New("unable to reserve token") + rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP} +) + +// RateLimiterType defines the supported rate limiters types +type RateLimiterType int + +// Supported rate limiter types +const ( + rateLimiterTypeGlobal RateLimiterType = iota + 1 + rateLimiterTypeSource +) + +// RateLimiterConfig defines the configuration for a rate limiter +type RateLimiterConfig struct { + // Average defines the maximum rate allowed. 
0 means disabled + Average int64 `json:"average" mapstructure:"average"` + // Period defines the period as milliseconds. Default: 1000 (1 second). + // The rate is actually defined by dividing average by period. + // So for a rate below 1 req/s, one needs to define a period larger than a second. + Period int64 `json:"period" mapstructure:"period"` + // Burst is the maximum number of requests allowed to go through in the + // same arbitrarily small period of time. Default: 1. + Burst int `json:"burst" mapstructure:"burst"` + // Type defines the rate limiter type: + // - rateLimiterTypeGlobal is a global rate limiter independent from the source + // - rateLimiterTypeSource is a per-source rate limiter + Type int `json:"type" mapstructure:"type"` + // Protocols defines the protocols for this rate limiter. + // Available protocols are: "SFTP", "FTP", "DAV". + // A rate limiter with no protocols defined is disabled + Protocols []string `json:"protocols" mapstructure:"protocols"` + // If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter, + // a new defender event will be generated + GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"` + // The number of per-ip rate limiters kept in memory will vary between the + // soft and hard limit + EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"` + EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"` +} + +func (r *RateLimiterConfig) isEnabled() bool { + return r.Average > 0 && len(r.Protocols) > 0 +} + +func (r *RateLimiterConfig) validate() error { + if r.Burst < 1 { + return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst) + } + if r.Period < 100 { + return fmt.Errorf("invalid period %v. 
It must be >= 100", r.Period) + } + if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) { + return fmt.Errorf("invalid type %v", r.Type) + } + if r.Type != int(rateLimiterTypeGlobal) { + if r.EntriesSoftLimit <= 0 { + return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit) + } + if r.EntriesHardLimit <= r.EntriesSoftLimit { + return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit) + } + } + r.Protocols = util.RemoveDuplicates(r.Protocols, true) + for _, protocol := range r.Protocols { + if !slices.Contains(rateLimiterProtocolValues, protocol) { + return fmt.Errorf("invalid protocol %q", protocol) + } + } + return nil +} + +func (r *RateLimiterConfig) getLimiter() *rateLimiter { + limiter := &rateLimiter{ + burst: r.Burst, + globalBucket: nil, + generateDefenderEvents: r.GenerateDefenderEvents, + } + var maxDelay time.Duration + period := time.Duration(r.Period) * time.Millisecond + rtl := float64(r.Average*int64(time.Second)) / float64(period) + limiter.rate = rate.Limit(rtl) + if rtl < 1 { + maxDelay = period / 2 + } else { + maxDelay = time.Second / (time.Duration(rtl) * 2) + } + if maxDelay > 10*time.Second { + maxDelay = 10 * time.Second + } + limiter.maxDelay = maxDelay + limiter.buckets = sourceBuckets{ + buckets: make(map[string]sourceRateLimiter), + hardLimit: r.EntriesHardLimit, + softLimit: r.EntriesSoftLimit, + } + if r.Type != int(rateLimiterTypeSource) { + limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst) + } + return limiter +} + +// RateLimiter defines a rate limiter +type rateLimiter struct { + rate rate.Limit + burst int + maxDelay time.Duration + globalBucket *rate.Limiter + buckets sourceBuckets + generateDefenderEvents bool +} + +// Wait blocks until the limit allows one event to happen +// or returns an error if the time to wait exceeds the max +// allowed delay +func (rl *rateLimiter) Wait(source, protocol string) (time.Duration, error) 
{ + var res *rate.Reservation + if rl.globalBucket != nil { + res = rl.globalBucket.Reserve() + } else { + var err error + res, err = rl.buckets.reserve(source) + if err != nil { + rateLimiter := rate.NewLimiter(rl.rate, rl.burst) + res = rl.buckets.addAndReserve(rateLimiter, source) + } + } + if !res.OK() { + return 0, errReserve + } + delay := res.Delay() + if delay > rl.maxDelay { + res.Cancel() + if rl.generateDefenderEvents && rl.globalBucket == nil { + AddDefenderEvent(source, protocol, HostEventLimitExceeded) + } + return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay) + } + time.Sleep(delay) + return 0, nil +} + +type sourceRateLimiter struct { + lastActivity *atomic.Int64 + bucket *rate.Limiter +} + +func (s *sourceRateLimiter) updateLastActivity() { + s.lastActivity.Store(time.Now().UnixNano()) +} + +func (s *sourceRateLimiter) getLastActivity() int64 { + return s.lastActivity.Load() +} + +type sourceBuckets struct { + sync.RWMutex + buckets map[string]sourceRateLimiter + hardLimit int + softLimit int +} + +func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) { + b.RLock() + defer b.RUnlock() + + if src, ok := b.buckets[source]; ok { + src.updateLastActivity() + return src.bucket.Reserve(), nil + } + + return nil, errNoBucket +} + +func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation { + b.Lock() + defer b.Unlock() + + b.cleanup() + + src := sourceRateLimiter{ + lastActivity: new(atomic.Int64), + bucket: r, + } + src.updateLastActivity() + b.buckets[source] = src + return src.bucket.Reserve() +} + +func (b *sourceBuckets) cleanup() { + if len(b.buckets) >= b.hardLimit { + numToRemove := len(b.buckets) - b.softLimit + + kvList := make(kvList, 0, len(b.buckets)) + + for k, v := range b.buckets { + kvList = append(kvList, kv{ + Key: k, + Value: v.getLastActivity(), + }) + } + + sort.Sort(kvList) + + for idx, kv := range kvList { + if idx 
>= numToRemove { + break + } + + delete(b.buckets, kv.Key) + } + } +} diff --git a/internal/common/ratelimiter_test.go b/internal/common/ratelimiter_test.go new file mode 100644 index 00000000..dda2690d --- /dev/null +++ b/internal/common/ratelimiter_test.go @@ -0,0 +1,149 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRateLimiterConfig(t *testing.T) { + config := RateLimiterConfig{} + err := config.validate() + require.Error(t, err) + config.Burst = 1 + config.Period = 10 + err = config.validate() + require.Error(t, err) + config.Period = 1000 + config.Type = 100 + err = config.validate() + require.Error(t, err) + config.Type = int(rateLimiterTypeSource) + config.EntriesSoftLimit = 0 + err = config.validate() + require.Error(t, err) + config.EntriesSoftLimit = 150 + config.EntriesHardLimit = 0 + err = config.validate() + require.Error(t, err) + config.EntriesHardLimit = 200 + config.Protocols = []string{"unsupported protocol"} + err = config.validate() + require.Error(t, err) + config.Protocols = rateLimiterProtocolValues + err = config.validate() + require.NoError(t, err) + + limiter := config.getLimiter() + require.Equal(t, 500*time.Millisecond, limiter.maxDelay) + require.Nil(t, limiter.globalBucket) + config.Type = int(rateLimiterTypeGlobal) + 
config.Average = 1 + config.Period = 10000 + limiter = config.getLimiter() + require.Equal(t, 5*time.Second, limiter.maxDelay) + require.NotNil(t, limiter.globalBucket) + config.Period = 100000 + limiter = config.getLimiter() + require.Equal(t, 10*time.Second, limiter.maxDelay) + config.Period = 500 + config.Average = 1 + limiter = config.getLimiter() + require.Equal(t, 250*time.Millisecond, limiter.maxDelay) +} + +func TestRateLimiter(t *testing.T) { + config := RateLimiterConfig{ + Average: 1, + Period: 1000, + Burst: 1, + Type: int(rateLimiterTypeGlobal), + Protocols: rateLimiterProtocolValues, + } + limiter := config.getLimiter() + _, err := limiter.Wait("", ProtocolFTP) + require.NoError(t, err) + _, err = limiter.Wait("", ProtocolSSH) + require.Error(t, err) + + config.Type = int(rateLimiterTypeSource) + config.GenerateDefenderEvents = true + config.EntriesSoftLimit = 5 + config.EntriesHardLimit = 10 + limiter = config.getLimiter() + + source := "192.168.1.2" + _, err = limiter.Wait(source, ProtocolSSH) + require.NoError(t, err) + _, err = limiter.Wait(source, ProtocolSSH) + require.Error(t, err) + // a different source should work + _, err = limiter.Wait(source+"1", ProtocolSSH) + require.NoError(t, err) + + config.Burst = 0 + limiter = config.getLimiter() + _, err = limiter.Wait(source, ProtocolSSH) + require.ErrorIs(t, err, errReserve) +} + +func TestLimiterCleanup(t *testing.T) { + config := RateLimiterConfig{ + Average: 100, + Period: 1000, + Burst: 1, + Type: int(rateLimiterTypeSource), + Protocols: rateLimiterProtocolValues, + EntriesSoftLimit: 1, + EntriesHardLimit: 3, + } + limiter := config.getLimiter() + source1 := "10.8.0.1" + source2 := "10.8.0.2" + source3 := "10.8.0.3" + source4 := "10.8.0.4" + _, err := limiter.Wait(source1, ProtocolSSH) + assert.NoError(t, err) + time.Sleep(20 * time.Millisecond) + _, err = limiter.Wait(source2, ProtocolSSH) + assert.NoError(t, err) + time.Sleep(20 * time.Millisecond) + assert.Len(t, limiter.buckets.buckets, 
2) + _, ok := limiter.buckets.buckets[source1] + assert.True(t, ok) + _, ok = limiter.buckets.buckets[source2] + assert.True(t, ok) + _, err = limiter.Wait(source3, ProtocolSSH) + assert.NoError(t, err) + assert.Len(t, limiter.buckets.buckets, 3) + _, ok = limiter.buckets.buckets[source1] + assert.True(t, ok) + _, ok = limiter.buckets.buckets[source2] + assert.True(t, ok) + _, ok = limiter.buckets.buckets[source3] + assert.True(t, ok) + time.Sleep(20 * time.Millisecond) + _, err = limiter.Wait(source4, ProtocolSSH) + assert.NoError(t, err) + assert.Len(t, limiter.buckets.buckets, 2) + _, ok = limiter.buckets.buckets[source3] + assert.True(t, ok) + _, ok = limiter.buckets.buckets[source4] + assert.True(t, ok) +} diff --git a/internal/common/tlsutils.go b/internal/common/tlsutils.go new file mode 100644 index 00000000..a7dfb0c6 --- /dev/null +++ b/internal/common/tlsutils.go @@ -0,0 +1,316 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package common + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "io/fs" + "math/rand" + "os" + "path/filepath" + "slices" + "sync" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + // DefaultTLSKeyPaidID defines the id to use for non-binding specific key pairs + DefaultTLSKeyPaidID = "default" + pemCRLType = "X509 CRL" +) + +var ( + pemCRLPrefix = []byte("-----BEGIN X509 CRL") +) + +// TLSKeyPair defines the paths and the unique identifier for a TLS key pair +type TLSKeyPair struct { + Cert string + Key string + ID string +} + +// CertManager defines a TLS certificate manager +type CertManager struct { + keyPairs []TLSKeyPair + configDir string + logSender string + sync.RWMutex + caCertificates []string + caRevocationLists []string + monitorList []string + certs map[string]*tls.Certificate + certsInfo map[string]fs.FileInfo + rootCAs *x509.CertPool + crls []*x509.RevocationList +} + +// Reload tries to reload certificate and CRLs +func (m *CertManager) Reload() error { + errCrt := m.loadCertificates() + errCRLs := m.LoadCRLs() + + if errCrt != nil { + return errCrt + } + return errCRLs +} + +// LoadCertificates tries to load the configured x509 key pairs +func (m *CertManager) loadCertificates() error { + if len(m.keyPairs) == 0 { + return errors.New("no key pairs defined") + } + certs := make(map[string]*tls.Certificate) + for _, keyPair := range m.keyPairs { + if keyPair.ID == "" { + return errors.New("TLS certificate without ID") + } + newCert, err := tls.LoadX509KeyPair(keyPair.Cert, keyPair.Key) + if err != nil { + logger.Error(m.logSender, "", "unable to load X509 key pair, cert file %q key file %q error: %v", + keyPair.Cert, keyPair.Key, err) + return err + } + if _, ok := certs[keyPair.ID]; ok { + logger.Error(m.logSender, "", "TLS certificate with id %q is duplicated", keyPair.ID) + return fmt.Errorf("TLS certificate with id %q is duplicated", 
keyPair.ID) + } + logger.Debug(m.logSender, "", "TLS certificate %q successfully loaded, id %v", keyPair.Cert, keyPair.ID) + certs[keyPair.ID] = &newCert + if !slices.Contains(m.monitorList, keyPair.Cert) { + m.monitorList = append(m.monitorList, keyPair.Cert) + } + } + + m.Lock() + defer m.Unlock() + + m.certs = certs + return nil +} + +// HasCertificate returns true if there is a certificate for the specified certID +func (m *CertManager) HasCertificate(certID string) bool { + m.RLock() + defer m.RUnlock() + + _, ok := m.certs[certID] + return ok +} + +// GetCertificateFunc returns the loaded certificate +func (m *CertManager) GetCertificateFunc(certID string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + m.RLock() + defer m.RUnlock() + + val, ok := m.certs[certID] + if !ok { + logger.Error(m.logSender, "", "no certificate for id %s", certID) + return nil, fmt.Errorf("no certificate for id %s", certID) + } + + return val, nil + } +} + +// IsRevoked returns true if the specified certificate has been revoked +func (m *CertManager) IsRevoked(crt *x509.Certificate, caCrt *x509.Certificate) bool { + m.RLock() + defer m.RUnlock() + + if crt == nil || caCrt == nil { + logger.Error(m.logSender, "", "unable to verify crt %v, ca crt %v", crt, caCrt) + return len(m.crls) > 0 + } + + for _, crl := range m.crls { + if crl.CheckSignatureFrom(caCrt) == nil { + for _, rc := range crl.RevokedCertificateEntries { + if rc.SerialNumber.Cmp(crt.SerialNumber) == 0 { + return true + } + } + } + } + + return false +} + +// LoadCRLs tries to load certificate revocation lists from the given paths +func (m *CertManager) LoadCRLs() error { + if len(m.caRevocationLists) == 0 { + return nil + } + + var crls []*x509.RevocationList + + for _, revocationList := range m.caRevocationLists { + if !util.IsFileInputValid(revocationList) { + return fmt.Errorf("invalid root CA revocation list %q", revocationList) + } + if 
revocationList != "" && !filepath.IsAbs(revocationList) { + revocationList = filepath.Join(m.configDir, revocationList) + } + crlBytes, err := os.ReadFile(revocationList) + if err != nil { + logger.Error(m.logSender, "", "unable to read revocation list %q", revocationList) + return err + } + if bytes.HasPrefix(crlBytes, pemCRLPrefix) { + block, _ := pem.Decode(crlBytes) + if block != nil && block.Type == pemCRLType { + crlBytes = block.Bytes + } + } + crl, err := x509.ParseRevocationList(crlBytes) + if err != nil { + logger.Error(m.logSender, "", "unable to parse revocation list %q", revocationList) + return err + } + + logger.Debug(m.logSender, "", "CRL %q successfully loaded", revocationList) + crls = append(crls, crl) + if !slices.Contains(m.monitorList, revocationList) { + m.monitorList = append(m.monitorList, revocationList) + } + } + + m.Lock() + defer m.Unlock() + + m.crls = crls + + return nil +} + +// GetRootCAs returns the set of root certificate authorities that servers +// use if required to verify a client certificate +func (m *CertManager) GetRootCAs() *x509.CertPool { + m.RLock() + defer m.RUnlock() + + return m.rootCAs +} + +// LoadRootCAs tries to load root CA certificate authorities from the given paths +func (m *CertManager) LoadRootCAs() error { + if len(m.caCertificates) == 0 { + return nil + } + + rootCAs := x509.NewCertPool() + + for _, rootCA := range m.caCertificates { + if !util.IsFileInputValid(rootCA) { + return fmt.Errorf("invalid root CA certificate %q", rootCA) + } + if rootCA != "" && !filepath.IsAbs(rootCA) { + rootCA = filepath.Join(m.configDir, rootCA) + } + crt, err := os.ReadFile(rootCA) + if err != nil { + logger.Error(m.logSender, "", "unable to read root CA from file %q: %v", rootCA, err) + return err + } + if rootCAs.AppendCertsFromPEM(crt) { + logger.Debug(m.logSender, "", "TLS certificate authority %q successfully loaded", rootCA) + } else { + err := fmt.Errorf("unable to load TLS certificate authority %q", rootCA) + 
logger.Error(m.logSender, "", "%v", err) + return err + } + } + + m.Lock() + defer m.Unlock() + + m.rootCAs = rootCAs + return nil +} + +// SetCACertificates sets the root CA authorities file paths. +// This should not be changed at runtime +func (m *CertManager) SetCACertificates(caCertificates []string) { + m.caCertificates = util.RemoveDuplicates(caCertificates, true) +} + +// SetCARevocationLists sets the CA revocation lists file paths. +// This should not be changed at runtime +func (m *CertManager) SetCARevocationLists(caRevocationLists []string) { + m.caRevocationLists = util.RemoveDuplicates(caRevocationLists, true) +} + +func (m *CertManager) monitor() { + certsInfo := make(map[string]fs.FileInfo) + + for _, crt := range m.monitorList { + info, err := os.Stat(crt) + if err != nil { + logger.Warn(m.logSender, "", "unable to stat certificate to monitor %q: %v", crt, err) + return + } + certsInfo[crt] = info + } + + m.Lock() + + isChanged := false + for k, oldInfo := range m.certsInfo { + newInfo, ok := certsInfo[k] + if ok { + if newInfo.Size() != oldInfo.Size() || newInfo.ModTime() != oldInfo.ModTime() { + logger.Debug(m.logSender, "", "change detected for certificate %q, reload required", k) + isChanged = true + } + } + } + m.certsInfo = certsInfo + + m.Unlock() + + if isChanged { + m.Reload() //nolint:errcheck + } +} + +// NewCertManager creates a new certificate manager +func NewCertManager(keyPairs []TLSKeyPair, configDir, logSender string) (*CertManager, error) { + manager := &CertManager{ + keyPairs: keyPairs, + configDir: configDir, + logSender: logSender, + certs: make(map[string]*tls.Certificate), + certsInfo: make(map[string]fs.FileInfo), + } + err := manager.loadCertificates() + if err != nil { + return nil, err + } + randSecs := rand.Intn(59) + manager.monitor() + if eventScheduler != nil { + _, err = eventScheduler.AddFunc(fmt.Sprintf("@every 8h0m%ds", randSecs), manager.monitor) + } + return manager, err +} diff --git 
a/internal/common/tlsutils_test.go b/internal/common/tlsutils_test.go new file mode 100644 index 00000000..e6b98f51 --- /dev/null +++ b/internal/common/tlsutils_test.go @@ -0,0 +1,526 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "crypto/tls" + "crypto/x509" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + serverCert = `-----BEGIN CERTIFICATE----- +MIIEIjCCAgqgAwIBAgIQfxHX0pnvRtkmtfLklgrcNzANBgkqhkiG9w0BAQsFADAT +MREwDwYDVQQDEwhDZXJ0QXV0aDAeFw0yMzAxMDMxMDIyMDdaFw0zMzAxMDMxMDMw +NDVaMBQxEjAQBgNVBAMTCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBAKbMWjMhyjMnDsq/19J9D44Y13uPSMN26NFOCfjVgV23zcqvI8W1 +csosYj89gSmIRxpcL2FtX7NjIT4vaqXob/en1lYy8hstacOs2cy2LcVZHfxu/hv3 +6hEKLY28tOD41L1CYZesBt3yV8vGcYIOnnAdIiG52SChnduTafBVE9Pq5P7qJ1gZ +d4uBYxe8/Za0metKDvMN6FTK+THq56eD830iRwFOdSw3Z4NS/nQNeVW263E4CC4u +BVxgwIHu6giqEfIoV6oVTY64y8X2YlwqvbVN/OtWNIJBLu+mN2EhR2ygpZdAyc82 +1yrk/X2/Dd3OiKSrrvXL1fOuNGlLNGD+3vUCAwEAAaNxMG8wDgYDVR0PAQH/BAQD +AgO4MB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAdBgNVHQ4EFgQUabrE +6ATHRqEf/CDQiNWI+0e/nhIwHwYDVR0jBBgwFoAUKPyWZxHuWgH3MA/996i3V4gd +aYgwDQYJKoZIhvcNAQELBQADggIBAHFtnPXxCCeeGw4RiIai3bavGtyK5qooZUia +hN8abJp9VJKYthLwF75c0wn8W0ZMTY8z9xgmFK9afWHCBNyK+0KCpd/LdDUfwvIn +3RwR4HRFjNG+n1UZBA4l1W6X6kCq9/x7YaKLrek9aBHfxwMnoMrOeMUybm6D+B5E 
+lSkAyJRq5VHVatM7UGmdux2MXK5IMpzlIBzz1pXddnzF3f9nfS54xt6ilWst9bMi +6mBxisJmqc51L/Fyb2SoCJoO/6kv+3V5HnRNBcZuVE8G5/Uc+WRnyy9dh996W83b +jNvSJ9UpspqMtKx7DKU4fC/3xYDjRimZvZ3akfIdkf3j5GVWMtVbx+QVSZ8aKBSM +Zx35p8aF0zppTjp2JvBpiQlGIXKfPkmmH4bLpU7Z7qLXFFnp+fs3CjcIng19gGgi +XQldgHVsl8FtIebxgW6wc5jb2y/fXjgx9c0SKEeeA3Pp6fExH8PdQdyHHmkHKQzO +ozon1tZhQbcjkNz8kXFp3x3X/0i4TsR6vsUigSFHXT7DgusBK8eAiRVOLSpbfIyp +7Ul/9DjhtYxcZjNI/xNJcECPGazNDdKh4TdLh35pnQHOsRXDWB873rr5xkJIUXbU +ubo+q0VpmF7OtfPO9PrPilWAUhVDRx7CCTW3YUsWrYJkr8d6F/n6y7QPKMtB9Y2P +jRJ4LDqX +-----END CERTIFICATE-----` + serverKey = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEApsxaMyHKMycOyr/X0n0PjhjXe49Iw3bo0U4J+NWBXbfNyq8j +xbVyyixiPz2BKYhHGlwvYW1fs2MhPi9qpehv96fWVjLyGy1pw6zZzLYtxVkd/G7+ +G/fqEQotjby04PjUvUJhl6wG3fJXy8Zxgg6ecB0iIbnZIKGd25Np8FUT0+rk/uon +WBl3i4FjF7z9lrSZ60oO8w3oVMr5Mernp4PzfSJHAU51LDdng1L+dA15VbbrcTgI +Li4FXGDAge7qCKoR8ihXqhVNjrjLxfZiXCq9tU3861Y0gkEu76Y3YSFHbKCll0DJ +zzbXKuT9fb8N3c6IpKuu9cvV8640aUs0YP7e9QIDAQABAoIBADbD9gG/4HH3KwYr +AyPbaBYR1f59xzhWfI7sfp2zDGzHAsy/wJETyILVG9UDzrriQeZHyk7E6J0vuSR/ +0RZ0QP8hnmBjDdcajBVxVXm/fzvCzPOrRcfNGI9LtjVJdmI/kSoq93wjQYXyIh2I +JJC9WAwbpK9KJB5wsjH8LtZ4OLBlcdeB8jcvO6FzGij6HwyxqyPctxetlvpcmc/w +zNJhps6t+TJ8PpNtEmTpOOmx85V6HMb3QJexwmUYygRaOoiQKBKZSNaOnGoC8w1d +WahyyXJk4B3OUllqG1TLUgabFGqq2PeJSP8RvYFH8DUj+fdxD78qDHAygrL8ELLZ +2O3Wi0ECgYEAyREnS/kylyIcAsyKczsKEDMIDUF9rGvm2B+QG7cLKHTu24oiNg5B +Ik5nkaYmSSrC3O2/s4v47mYzMtWbLxlogiNK6ljLPpdU5/JaeHncZC+18seBoePQ +9nOW3AvY2A6ihzy8sKRMfl3FUx/1rcXLdNwkMQo0FWR7nqVPUme9QkkCgYEA1F5n +lhfDptiHekagKMTf9SGw4B2UiG6SLjMWhcG2AEFeXpZlsk7Qubnuzk0krjYp+JAI +brlzMOkmBXBQywKLe3SG0s0McbRGWVFbEA1SA+WZV5rwJe5PO7W6ndCF2+slyZ5T +dPwOY1RybV6R07EvjtfnE8Wtdyko4X22sTkyd00CgYA5MYnuEHqVhvxUx33yfS7F +oN5/dsuayi6l94R0fcLMxUZUaJyGp9NbQNYxFgP5+BHp6i8HkZ9DoQqbQSudYCrc +KdHbi1p0+XMLb2LQtkk8rl2hK6LyO+1qzUJyYWRTQQZ2VY6O6I1hvKaumH636XWQ +TjZ1RKPAGg8X94nytNOfEQKBgQC/+TL0iDjyGyykyTFAiW/WXQVSIwtBJYr5Pm9u +rESFCJJxOM1nmT2vlrecQDoXTZk1O6aTyQqrPSeEpRoz2fISwKyb5IYKRyeM2DFU 
+WmY4ZZXvjnzmHP39APNYc8Z9nZzEHF5fEvdCrXTfDy0Ny08tdlhKFFkRreBprkW3 +APhwxQKBgDBdionnjdB9jdGbYHrsPaweMGdQNXkrTTCFfBA47F+qZswfon12yu4A ++cBKCnQe2dQHl8AV3IeUKpmNghu4iICOASQEO9dS6OWZI5vBxZMePBm6+bjTOuf6 +ozecw3yR55tKpPImt87rhrWlwp35uWuhOr9GHYBdFSwgrEkVMw++ +-----END RSA PRIVATE KEY-----` + caCRT = `-----BEGIN CERTIFICATE----- +MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 +QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT +CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m +fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 +v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh +yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL +CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U +4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb +KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo +NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb +S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i +M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY +/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy +OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD +AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu +XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F +sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki +pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE +jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX +y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy +WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV +oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV +L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 +DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E +eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli 
+SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w +kK8u1YiiVenmoQ== +-----END CERTIFICATE-----` + caKey = `-----BEGIN RSA PRIVATE KEY----- +MIIJKgIBAAKCAgEA7WHW216mfi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7x +b64rkpdzx1aWetSiCrEyc3D1v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvD +OBUYZgtMqHZzpE6xRrqQ84zhyzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKz +n/2uEVt33qmO85WtN3RzbSqLCdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj +7B5P5MeamkkogwbExUjdHp3U4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZ +De67V/Q8iB2May1k7zBz1ZtbKF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmOb +cn8AIfH6smLQrn0C3cs7CYfoNlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLq +R/BJTbbyXUB0imne1u00fuzbS7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb7 +8x7ivdyXSF5LVQJ1JvhhWu6iM6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpB +P8d2jcRZVUVrSXGc2mAGuGOY/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2Z +xugCXULtRWJ9p4C9zUl40HEyOQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEA +AQKCAgEA4x0OoceG54ZrVxifqVaQd8qw3uRmUKUMIMdfuMlsdideeLO97ynmSlRY +00kGo/I4Lp6mNEjI9gUie9+uBrcUhri4YLcujHCH+YlNnCBDbGjwbe0ds9SLCWaa +KztZHMSlW5Q4Bqytgu+MpOnxSgqjlOk+vz9TcGFKVnUkHIkAcqKFJX8gOFxPZA/t +Ob1kJaz4kuv5W2Kur/ISKvQtvFvOtQeV0aJyZm8LqXnvS4cPI7yN4329NDU0HyDR +y/deqS2aqV4zII3FFqbz8zix/m1xtVQzWCugZGMKrz0iuJMfNeCABb8rRGc6GsZz ++465v/kobqgeyyneJ1s5rMFrLp2o+dwmnIVMNsFDUiN1lIZDHLvlgonaUO3IdTZc +9asamFWKFKUMgWqM4zB1vmUO12CKowLNIIKb0L+kf1ixaLLDRGf/f9vLtSHE+oyx +lATiS18VNA8+CGsHF6uXMRwf2auZdRI9+s6AAeyRISSbO1khyWKHo+bpOvmPAkDR +nknTjbYgkoZOV+mrsU5oxV8s6vMkuvA3rwFhT2gie8pokuACFcCRrZi9MVs4LmUQ +u0GYTHvp2WJUjMWBm6XX7Hk3g2HV842qpk/mdtTjNsXws81djtJPn4I/soIXSgXz +pY3SvKTuOckP9OZVF0yqKGeZXKpD288PKpC+MAg3GvEJaednagECggEBAPsfLwuP +L1kiDjXyMcRoKlrQ6Q/zBGyBmJbZ5uVGa02+XtYtDAzLoVupPESXL0E7+r8ZpZ39 +0dV4CEJKpbVS/BBtTEkPpTK5kz778Ib04TAyj+YLhsZjsnuja3T5bIBZXFDeDVDM +0ZaoFoKpIjTu2aO6pzngsgXs6EYbo2MTuJD3h0nkGZsICL7xvT9Mw0P1p2Ftt/hN ++jKk3vN220wTWUsq43AePi45VwK+PNP12ZXv9HpWDxlPo3j0nXtgYXittYNAT92u +BZbFAzldEIX9WKKZgsWtIzLaASjVRntpxDCTby/nlzQ5dw3DHU1DV3PIqxZS2+Oe 
+KV+7XFWgZ44YjYECggEBAPH+VDu3QSrqSahkZLkgBtGRkiZPkZFXYvU6kL8qf5wO +Z/uXMeqHtznAupLea8I4YZLfQim/NfC0v1cAcFa9Ckt9g3GwTSirVcN0AC1iOyv3 +/hMZCA1zIyIcuUplNr8qewoX71uPOvCNH0dix77423mKFkJmNwzy4Q+rV+qkRdLn +v+AAgh7g5N91pxNd6LQJjoyfi1Ka6rRP2yGXM5v7QOwD16eN4JmExUxX1YQ7uNuX +pVS+HRxnBquA+3/DB1LtBX6pa2cUa+LRUmE/NCPHMvJcyuNkYpJKlNTd9vnbfo0H +RNSJSWm+aGxDFMjuPjV3JLj2OdKMPwpnXdh2vBZCPpMCggEAM+yTvrEhmi2HgLIO +hkz/jP2rYyfdn04ArhhqPLgd0dpuI5z24+Jq/9fzZT9ZfwSW6VK1QwDLlXcXRhXH +Q8Hf6smev3CjuORURO61IkKaGWwrAucZPAY7ToNQ4cP9ImDXzMTNPgrLv3oMBYJR +V16X09nxX+9NABqnQG/QjdjzDc6Qw7+NZ9f2bvzvI5qMuY2eyW91XbtJ45ThoLfP +ymAp03gPxQwL0WT7z85kJ3OrROxzwaPvxU0JQSZbNbqNDPXmFTiECxNDhpRAAWlz +1DC5Vg2l05fkMkyPdtD6nOQWs/CYSfB5/EtxiX/xnBszhvZUIe6KFvuKFIhaJD5h +iykagQKCAQEAoBRm8k3KbTIo4ZzvyEq4V/+dF3zBRczx6FkCkYLygXBCNvsQiR2Y +BjtI8Ijz7bnQShEoOmeDriRTAqGGrspEuiVgQ1+l2wZkKHRe/aaij/Zv+4AuhH8q +uZEYvW7w5Uqbs9SbgQzhp2kjTNy6V8lVnjPLf8cQGZ+9Y9krwktC6T5m/i435WdN +38h7amNP4XEE/F86Eb3rDrZYtgLIoCF4E+iCyxMehU+AGH1uABhls9XAB6vvo+8/ +SUp8lEqWWLP0U5KNOtYWfCeOAEiIHDbUq+DYUc4BKtbtV1cx3pzlPTOWw6XBi5Lq +jttdL4HyYvnasAQpwe8GcMJqIRyCVZMiwwKCAQEAhQTTS3CC8PwcoYrpBdTjW1ck +vVFeF1YbfqPZfYxASCOtdx6wRnnEJ+bjqntagns9e88muxj9UhxSL6q9XaXQBD8+ +2AmKUxphCZQiYFZcTucjQEQEI2nN+nAKgRrUSMMGiR8Ekc2iFrcxBU0dnSohw+aB +PbMKVypQCREu9PcDFIp9rXQTeElbaNsIg1C1w/SQjODbmN/QFHTVbRODYqLeX1J/ +VcGsykSIq7hv6bjn7JGkr2JTdANbjk9LnMjMdJFsKRYxPKkOQfYred6Hiojp5Sor +PW5am8ejnNSPhIfqQp3uV3KhwPDKIeIpzvrB4uPfTjQWhekHCb8cKSWux3flqw== +-----END RSA PRIVATE KEY-----` + caCRL = `-----BEGIN X509 CRL----- +MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN +MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg +ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r +rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV ++jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 +XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE +jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF 
+DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD +MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 +iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE +ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 +hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 +cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 +sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn +w7bXJGfJcIMrrKs= +-----END X509 CRL-----` + client1Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 +HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE +OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 +Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf +XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO +c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo +Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU +OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC +sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I +l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb +BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz +MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 +QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz +J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ +offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX +G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr 
+kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au +MU3Bo0A= +-----END CERTIFICATE-----` + client1Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki +J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi +85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW +eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 +ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy +bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO +hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw +eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl +X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ +CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL +j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU +NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf +sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB +Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB +LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g +mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 +T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn +RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa +Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW +asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ +DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 +0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls +ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY +nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe +yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== +-----END RSA PRIVATE KEY-----` + // client 2 crt is revoked + client2Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw 
+EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b +Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua +eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ +cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT +A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 +6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f +Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz +PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY +xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D +EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB +4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO +78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 +7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e +qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N +f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv +/ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S +ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR +E0+r2+4= +-----END CERTIFICATE-----` + client2Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV +jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP +kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK +h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P +QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu +4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq +op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM +qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt 
+Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL +VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO +qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy +VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl +NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 +4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl +/owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd +aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu +khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz +3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ +eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i +vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB +GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb +Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D +8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl +RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA +ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl +-----END RSA PRIVATE KEY-----` +) + +func TestLoadCertificate(t *testing.T) { + startEventScheduler() + caCrtPath := filepath.Join(os.TempDir(), "testca.crt") + caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt") + certPath := filepath.Join(os.TempDir(), "test.crt") + keyPath := filepath.Join(os.TempDir(), "test.key") + err := os.WriteFile(caCrtPath, []byte(caCRT), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(certPath, []byte(serverCert), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(keyPath, []byte(serverKey), os.ModePerm) + assert.NoError(t, err) + keyPairs := []TLSKeyPair{ + { + Cert: certPath, + Key: keyPath, + ID: DefaultTLSKeyPaidID, + }, + { + Cert: certPath, + Key: keyPath, + ID: DefaultTLSKeyPaidID, + }, + } + certManager, err := 
NewCertManager(keyPairs, configDir, logSenderTest) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is duplicated") + } + assert.Nil(t, certManager) + + keyPairs = []TLSKeyPair{ + { + Cert: certPath, + Key: keyPath, + ID: DefaultTLSKeyPaidID, + }, + } + + certManager, err = NewCertManager(keyPairs, configDir, logSenderTest) + assert.NoError(t, err) + assert.True(t, certManager.HasCertificate(DefaultTLSKeyPaidID)) + assert.False(t, certManager.HasCertificate("unknownID")) + certFunc := certManager.GetCertificateFunc(DefaultTLSKeyPaidID) + if assert.NotNil(t, certFunc) { + hello := &tls.ClientHelloInfo{ + ServerName: "localhost", + CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}, + } + cert, err := certFunc(hello) + assert.NoError(t, err) + assert.Equal(t, certManager.certs[DefaultTLSKeyPaidID], cert) + } + certFunc = certManager.GetCertificateFunc("unknownID") + if assert.NotNil(t, certFunc) { + hello := &tls.ClientHelloInfo{ + ServerName: "localhost", + CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}, + } + _, err = certFunc(hello) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no certificate for id unknownID") + } + } + certManager.SetCACertificates(nil) + err = certManager.LoadRootCAs() + assert.NoError(t, err) + + certManager.SetCACertificates([]string{""}) + err = certManager.LoadRootCAs() + assert.Error(t, err) + + certManager.SetCACertificates([]string{"invalid"}) + err = certManager.LoadRootCAs() + assert.Error(t, err) + + // laoding the key as root CA must fail + certManager.SetCACertificates([]string{keyPath}) + err = certManager.LoadRootCAs() + assert.Error(t, err) + + certManager.SetCACertificates([]string{certPath}) + err = certManager.LoadRootCAs() + assert.NoError(t, err) + + rootCa := certManager.GetRootCAs() + assert.NotNil(t, rootCa) + + err = certManager.Reload() + assert.NoError(t, err) + + certManager.SetCARevocationLists(nil) + err = certManager.LoadCRLs() + assert.NoError(t, err) 
+ + certManager.SetCARevocationLists([]string{""}) + err = certManager.LoadCRLs() + assert.Error(t, err) + + certManager.SetCARevocationLists([]string{"invalid crl"}) + err = certManager.LoadCRLs() + assert.Error(t, err) + + // this is not a crl and must fail + certManager.SetCARevocationLists([]string{caCrtPath}) + err = certManager.LoadCRLs() + assert.Error(t, err) + + certManager.SetCARevocationLists([]string{caCrlPath}) + err = certManager.LoadCRLs() + assert.NoError(t, err) + + crt, err := tls.X509KeyPair([]byte(caCRT), []byte(caKey)) + assert.NoError(t, err) + + x509CAcrt, err := x509.ParseCertificate(crt.Certificate[0]) + assert.NoError(t, err) + + crt, err = tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + x509crt, err := x509.ParseCertificate(crt.Certificate[0]) + if assert.NoError(t, err) { + assert.False(t, certManager.IsRevoked(x509crt, x509CAcrt)) + } + + crt, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) + assert.NoError(t, err) + x509crt, err = x509.ParseCertificate(crt.Certificate[0]) + if assert.NoError(t, err) { + assert.True(t, certManager.IsRevoked(x509crt, x509CAcrt)) + } + + assert.True(t, certManager.IsRevoked(nil, nil)) + + err = os.Remove(caCrlPath) + assert.NoError(t, err) + err = certManager.Reload() + assert.Error(t, err) + + err = os.Remove(certPath) + assert.NoError(t, err) + err = os.Remove(keyPath) + assert.NoError(t, err) + err = certManager.Reload() + assert.Error(t, err) + + err = os.Remove(caCrtPath) + assert.NoError(t, err) + stopEventScheduler() +} + +func TestLoadInvalidCert(t *testing.T) { + startEventScheduler() + certManager, err := NewCertManager(nil, configDir, logSenderTest) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no key pairs defined") + } + assert.Nil(t, certManager) + + keyPairs := []TLSKeyPair{ + { + Cert: "test.crt", + Key: "test.key", + ID: DefaultTLSKeyPaidID, + }, + } + certManager, err = NewCertManager(keyPairs, configDir, logSenderTest) + 
assert.Error(t, err) + assert.Nil(t, certManager) + + keyPairs = []TLSKeyPair{ + { + Cert: "test.crt", + Key: "test.key", + }, + } + certManager, err = NewCertManager(keyPairs, configDir, logSenderTest) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "TLS certificate without ID") + } + assert.Nil(t, certManager) + stopEventScheduler() +} + +func TestCertificateMonitor(t *testing.T) { + startEventScheduler() + defer stopEventScheduler() + + certPath := filepath.Join(os.TempDir(), "test.crt") + keyPath := filepath.Join(os.TempDir(), "test.key") + caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt") + err := os.WriteFile(certPath, []byte(serverCert), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(keyPath, []byte(serverKey), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm) + assert.NoError(t, err) + + keyPairs := []TLSKeyPair{ + { + Cert: certPath, + Key: keyPath, + ID: DefaultTLSKeyPaidID, + }, + } + certManager, err := NewCertManager(keyPairs, configDir, logSenderTest) + assert.NoError(t, err) + assert.Len(t, certManager.monitorList, 1) + require.Len(t, certManager.certsInfo, 1) + info := certManager.certsInfo[certPath] + require.NotNil(t, info) + certManager.SetCARevocationLists([]string{caCrlPath}) + err = certManager.LoadCRLs() + assert.NoError(t, err) + assert.Len(t, certManager.monitorList, 2) + certManager.monitor() + require.Len(t, certManager.certsInfo, 2) + + err = os.Remove(certPath) + assert.NoError(t, err) + certManager.monitor() + + time.Sleep(100 * time.Millisecond) + err = os.WriteFile(certPath, []byte(serverCert), os.ModePerm) + assert.NoError(t, err) + certManager.monitor() + require.Len(t, certManager.certsInfo, 2) + newInfo := certManager.certsInfo[certPath] + require.NotNil(t, newInfo) + assert.Equal(t, info.Size(), newInfo.Size()) + assert.NotEqual(t, info.ModTime(), newInfo.ModTime()) + + err = os.Remove(caCrlPath) + assert.NoError(t, err) + + err = os.Remove(certPath) 
+ assert.NoError(t, err) + err = os.Remove(keyPath) + assert.NoError(t, err) +} diff --git a/internal/common/transfer.go b/internal/common/transfer.go new file mode 100644 index 00000000..df9638b1 --- /dev/null +++ b/internal/common/transfer.go @@ -0,0 +1,570 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package common + +import ( + "errors" + "fmt" + "io/fs" + "path" + "sync" + "sync/atomic" + "time" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/metric" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +var ( + // ErrTransferClosed defines the error returned for a closed transfer + ErrTransferClosed = errors.New("transfer already closed") +) + +// BaseTransfer contains protocols common transfer details for an upload or a download. 
+type BaseTransfer struct { + ID int64 + BytesSent atomic.Int64 + BytesReceived atomic.Int64 + Fs vfs.Fs + File vfs.File + Connection *BaseConnection + cancelFn func() + fsPath string + effectiveFsPath string + requestPath string + ftpMode string + start time.Time + MaxWriteSize int64 + MinWriteOffset int64 + InitialSize int64 + truncatedSize int64 + isNewFile bool + transferType int + AbortTransfer atomic.Bool + aTime time.Time + mTime time.Time + transferQuota dataprovider.TransferQuota + metadata map[string]string + sync.Mutex + errAbort error + ErrTransfer error +} + +// NewBaseTransfer returns a new BaseTransfer and adds it to the given connection +func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPath, effectiveFsPath, requestPath string, + transferType int, minWriteOffset, initialSize, maxWriteSize, truncatedSize int64, isNewFile bool, fs vfs.Fs, + transferQuota dataprovider.TransferQuota, +) *BaseTransfer { + t := &BaseTransfer{ + ID: conn.GetTransferID(), + File: file, + Connection: conn, + cancelFn: cancelFn, + fsPath: fsPath, + effectiveFsPath: effectiveFsPath, + start: time.Now(), + transferType: transferType, + MinWriteOffset: minWriteOffset, + InitialSize: initialSize, + isNewFile: isNewFile, + requestPath: requestPath, + MaxWriteSize: maxWriteSize, + truncatedSize: truncatedSize, + transferQuota: transferQuota, + Fs: fs, + } + t.AbortTransfer.Store(false) + t.BytesSent.Store(0) + t.BytesReceived.Store(0) + + conn.AddTransfer(t) + return t +} + +// GetTransferQuota returns data transfer quota limits +func (t *BaseTransfer) GetTransferQuota() dataprovider.TransferQuota { + return t.transferQuota +} + +// SetFtpMode sets the FTP mode for the current transfer +func (t *BaseTransfer) SetFtpMode(mode string) { + t.ftpMode = mode +} + +// GetID returns the transfer ID +func (t *BaseTransfer) GetID() int64 { + return t.ID +} + +// GetType returns the transfer type +func (t *BaseTransfer) GetType() int { + return t.transferType +} + 
// GetSize returns the transferred size: bytes sent for downloads,
// bytes received for uploads
func (t *BaseTransfer) GetSize() int64 {
	if t.transferType == TransferDownload {
		return t.BytesSent.Load()
	}
	return t.BytesReceived.Load()
}

// GetDownloadedSize returns the downloaded size
func (t *BaseTransfer) GetDownloadedSize() int64 {
	return t.BytesSent.Load()
}

// GetUploadedSize returns the uploaded size
func (t *BaseTransfer) GetUploadedSize() int64 {
	return t.BytesReceived.Load()
}

// GetStartTime returns the start time
func (t *BaseTransfer) GetStartTime() time.Time {
	return t.start
}

// GetAbortError returns the error to send to the client if the transfer was aborted
func (t *BaseTransfer) GetAbortError() error {
	t.Lock()
	defer t.Unlock()

	if t.errAbort != nil {
		return t.errAbort
	}
	// no specific abort error was set, fall back to a protocol-appropriate
	// quota exceeded error
	return getQuotaExceededError(t.Connection.protocol)
}

// SignalClose signals that the transfer should be closed after the next read/write.
// The optional error argument allows sending a specific error, otherwise a generic
// transfer aborted error is sent
func (t *BaseTransfer) SignalClose(err error) {
	t.Lock()
	t.errAbort = err
	t.Unlock()
	// set the flag after errAbort is published so readers of AbortTransfer
	// observe the error via GetAbortError
	t.AbortTransfer.Store(true)
}

// GetTruncatedSize returns the truncated size if this is an upload overwriting
// an existing file
func (t *BaseTransfer) GetTruncatedSize() int64 {
	return t.truncatedSize
}

// HasSizeLimit returns true if there is an upload or download size limit
func (t *BaseTransfer) HasSizeLimit() bool {
	if t.MaxWriteSize > 0 {
		return true
	}
	if t.transferQuota.HasSizeLimits() {
		return true
	}

	return false
}

// GetVirtualPath returns the transfer virtual path
func (t *BaseTransfer) GetVirtualPath() string {
	return t.requestPath
}

// GetFsPath returns the transfer filesystem path
func (t *BaseTransfer) GetFsPath() string {
	return t.fsPath
}

// SetTimes stores access and modification times if fsPath matches the current file.
// The stored times are applied when the transfer is closed, see updateTimes.
// It returns false if fsPath refers to a different file
func (t *BaseTransfer) SetTimes(fsPath string, atime time.Time, mtime time.Time) bool {
	if fsPath == t.GetFsPath() {
		t.aTime = atime
		t.mTime = mtime
		return true
	}
	return false
}

// GetRealFsPath returns the real transfer filesystem path.
// If atomic uploads are enabled this differs from fsPath.
// It returns an empty string if fsPath refers to a different file
func (t *BaseTransfer) GetRealFsPath(fsPath string) string {
	if fsPath == t.GetFsPath() {
		if t.File != nil || vfs.IsLocalOsFs(t.Fs) {
			return t.effectiveFsPath
		}
		return t.fsPath
	}
	return ""
}

// SetMetadata sets the metadata for the file
func (t *BaseTransfer) SetMetadata(val map[string]string) {
	t.metadata = val
}

// SetCancelFn sets the cancel function for the transfer
func (t *BaseTransfer) SetCancelFn(cancelFn func()) {
	t.cancelFn = cancelFn
}

// ConvertError accepts an error that occurs during a read or write and
// converts it into a more understandable form for the client if it is a
// well-known type of error
func (t *BaseTransfer) ConvertError(err error) error {
	var pathError *fs.PathError
	if errors.As(err, &pathError) {
		// replace the real filesystem path with the virtual one so the
		// client never sees server-side paths
		return fmt.Errorf("%s %s: %s", pathError.Op, t.GetVirtualPath(), pathError.Err.Error())
	}
	return t.Connection.GetFsError(t.Fs, err)
}

// CheckRead returns an error if read is not allowed
func (t *BaseTransfer) CheckRead() error {
	// no download limits configured
	if t.transferQuota.AllowedDLSize == 0 && t.transferQuota.AllowedTotalSize == 0 {
		return nil
	}
	// a total (ul+dl) limit takes precedence over the download-only limit
	if t.transferQuota.AllowedTotalSize > 0 {
		if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
			return t.Connection.GetReadQuotaExceededError()
		}
	} else if t.transferQuota.AllowedDLSize > 0 {
		if t.BytesSent.Load() > t.transferQuota.AllowedDLSize {
			return t.Connection.GetReadQuotaExceededError()
		}
	}
	return nil
}

// CheckWrite returns an error if write is not allowed
func (t *BaseTransfer) CheckWrite() error {
	if t.MaxWriteSize > 0 && t.BytesReceived.Load() > t.MaxWriteSize {
		return t.Connection.GetQuotaExceededError()
	}
	// no upload limits configured
	if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 {
		return nil
	}
	// a total (ul+dl) limit takes precedence over the upload-only limit
	if t.transferQuota.AllowedTotalSize > 0 {
		if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
			return t.Connection.GetQuotaExceededError()
		}
	} else if t.transferQuota.AllowedULSize > 0 {
		if t.BytesReceived.Load() > t.transferQuota.AllowedULSize {
			return t.Connection.GetQuotaExceededError()
		}
	}
	return nil
}

// Truncate changes the size of the opened file.
// Supported for local fs only.
// It returns the previous initial size and errTransferMismatch if fsPath
// refers to a different file
func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
	if fsPath == t.GetFsPath() {
		if t.File != nil {
			initialSize := t.InitialSize
			err := t.File.Truncate(size)
			if err == nil {
				t.Lock()
				t.InitialSize = size
				if t.MaxWriteSize > 0 {
					// the client can now write the truncated bytes again,
					// so enlarge the remaining write allowance
					sizeDiff := initialSize - size
					t.MaxWriteSize += sizeDiff
					metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
						t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
					if t.transferQuota.HasSizeLimits() {
						// flush the transfer counters to the provider asynchronously,
						// the values are captured as arguments before the reset below
						go func(ulSize, dlSize int64, user dataprovider.User) {
							dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
						}(t.BytesReceived.Load(), t.BytesSent.Load(), t.Connection.User)
					}
					t.BytesReceived.Store(0)
				}
				t.Unlock()
			}
			t.Connection.Log(logger.LevelDebug, "file %q truncated to size %v max write size %v new initial size %v err: %v",
				fsPath, size, t.MaxWriteSize, t.InitialSize, err)
			return initialSize, err
		}
		if size == 0 && t.BytesSent.Load() == 0 {
			// for cloud providers the file is always truncated to zero, we don't support append/resume for uploads.
			// For buffered SFTP and local fs we can have buffered bytes so we return an error
			if !vfs.IsBufferedLocalOrSFTPFs(t.Fs) {
				return 0, nil
			}
		}
		return 0, vfs.ErrVfsUnsupported
	}
	return 0, errTransferMismatch
}

// TransferError is called if there is an unexpected error.
// For example network or client issues.
// Only the first error is kept: subsequent calls are no-ops
func (t *BaseTransfer) TransferError(err error) {
	t.Lock()
	defer t.Unlock()
	if t.ErrTransfer != nil {
		return
	}
	t.ErrTransfer = err
	if t.cancelFn != nil {
		t.cancelFn()
	}
	elapsed := time.Since(t.start).Nanoseconds() / 1000000
	t.Connection.Log(logger.LevelError, "Unexpected error for transfer, path: %q, error: \"%v\" bytes sent: %v, "+
		"bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, t.BytesSent.Load(),
		t.BytesReceived.Load(), elapsed)
}

// getUploadFileSize returns the on-disk size of the uploaded file, the number
// of files deleted as a side effect (partial crypto files), and an error if
// quota tracking is disabled or stat fails
func (t *BaseTransfer) getUploadFileSize() (int64, int, error) {
	var fileSize int64
	var deletedFiles int

	// quota tracking mode: 0 = disabled, 2 = only for users/folders with
	// quota restrictions (per dataprovider.GetQuotaTracking)
	switch dataprovider.GetQuotaTracking() {
	case 0:
		return fileSize, deletedFiles, errors.New("quota tracking disabled")
	case 2:
		if !t.Connection.User.HasQuotaRestrictions() {
			vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
			if err != nil {
				return fileSize, deletedFiles, errors.New("quota tracking disabled for this user")
			}
			if vfolder.IsIncludedInUserQuota() {
				return fileSize, deletedFiles, errors.New("quota tracking disabled for this user and folder included in user quota")
			}
		}
	}

	info, err := t.Fs.Stat(t.fsPath)
	if err == nil {
		fileSize = info.Size()
	}
	// a failed upload on an encrypted fs leaves an unreadable partial file:
	// remove it and report it as deleted so quota can be adjusted
	if t.ErrTransfer != nil && vfs.IsCryptOsFs(t.Fs) {
		errDelete := t.Fs.Remove(t.fsPath, false)
		if errDelete != nil {
			t.Connection.Log(logger.LevelWarn, "error removing partial crypto file %q: %v", t.fsPath, errDelete)
		} else {
			fileSize = 0
			deletedFiles = 1
			t.BytesReceived.Store(0)
			t.MinWriteOffset = 0
		}
	}
	return fileSize, deletedFiles, err
}

// checkUploadOutsideHomeDir returns 1 if the file is outside the user home dir
// and was therefore removed, 0 otherwise. err is the rename error from the
// atomic upload completion
func (t *BaseTransfer) checkUploadOutsideHomeDir(err error) int {
	if err == nil {
		return 0
	}
	if t.ErrTransfer == nil {
		t.ErrTransfer = err
	}
	// uploads can only land outside the home dir when a shared temp path
	// is configured
	if Config.TempPath == "" {
		return 0
	}
	err = t.Fs.Remove(t.effectiveFsPath, false)
	t.Connection.Log(logger.LevelWarn, "upload in temp path cannot be renamed, delete temporary file: %q, deletion error: %v",
		t.effectiveFsPath, err)
	// the file is outside the home dir so don't update the quota
	t.BytesReceived.Store(0)
	t.MinWriteOffset = 0
	return 1
}

// Close is called when the transfer is completed.
// It logs the transfer info, updates the user quota (for uploads)
// and executes any defined action.
// If there is an error no action will be executed and, in atomic mode,
// we try to delete the temporary file
func (t *BaseTransfer) Close() error {
	defer t.Connection.RemoveTransfer(t)

	var err error
	numFiles := t.getUploadedFiles()
	metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
		t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
	if t.transferQuota.HasSizeLimits() {
		dataprovider.UpdateUserTransferQuota(&t.Connection.User, t.BytesReceived.Load(), //nolint:errcheck
			t.BytesSent.Load(), false)
	}
	if (t.File != nil || vfs.IsLocalOsFs(t.Fs)) && t.Connection.IsQuotaExceededError(t.ErrTransfer) {
		// if quota is exceeded we try to remove the partial file for uploads to local filesystem
		err = t.Fs.Remove(t.effectiveFsPath, false)
		if err == nil {
			t.BytesReceived.Store(0)
			t.MinWriteOffset = 0
		}
		t.Connection.Log(logger.LevelWarn, "upload denied due to space limit, delete temporary file: %q, deletion error: %v",
			t.effectiveFsPath, err)
	} else if t.isAtomicUpload() {
		if t.ErrTransfer == nil || Config.UploadMode&UploadModeAtomicWithResume != 0 {
			// promote the temporary file to its final path
			_, _, err = t.Fs.Rename(t.effectiveFsPath, t.fsPath, 0)
			t.Connection.Log(logger.LevelDebug, "atomic upload completed, rename: %q -> %q, error: %v",
				t.effectiveFsPath, t.fsPath, err)
			// the file must be removed if it is uploaded to a path outside the home dir and cannot be renamed
			t.checkUploadOutsideHomeDir(err)
		} else {
			err = t.Fs.Remove(t.effectiveFsPath, false)
			t.Connection.Log(logger.LevelWarn, "atomic upload completed with error: \"%v\", delete temporary file: %q, deletion error: %v",
				t.ErrTransfer, t.effectiveFsPath, err)
			if err == nil {
				t.BytesReceived.Store(0)
				t.MinWriteOffset = 0
			}
		}
	}
	elapsed := time.Since(t.start).Nanoseconds() / 1000000
	var uploadFileSize int64
	if t.transferType == TransferDownload {
		logger.TransferLog(downloadLogSender, t.fsPath, elapsed, t.BytesSent.Load(), t.Connection.User.Username,
			t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode,
			t.ErrTransfer)
		ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "", //nolint:errcheck
			t.BytesSent.Load(), t.ErrTransfer, elapsed, t.metadata)
	} else {
		statSize, deletedFiles, errStat := t.getUploadFileSize()
		if errStat == nil {
			uploadFileSize = statSize
		} else {
			// stat failed: fall back to the counters we tracked ourselves
			uploadFileSize = t.BytesReceived.Load() + t.MinWriteOffset
			if t.Fs.IsNotExist(errStat) {
				uploadFileSize = 0
				numFiles--
			}
		}
		numFiles -= deletedFiles
		t.Connection.Log(logger.LevelDebug, "upload file size %d, num files %d, deleted files %d, fs path %q",
			uploadFileSize, numFiles, deletedFiles, t.fsPath)
		numFiles, uploadFileSize = t.executeUploadHook(numFiles, uploadFileSize, elapsed)
		t.updateQuota(numFiles, uploadFileSize)
		t.updateTimes()
		logger.TransferLog(uploadLogSender, t.fsPath, elapsed, t.BytesReceived.Load(), t.Connection.User.Username,
			t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode,
			t.ErrTransfer)
	}
	if t.ErrTransfer != nil {
		t.Connection.Log(logger.LevelError, "transfer error: %v, path: %q", t.ErrTransfer, t.fsPath)
		if err == nil {
			err = t.ErrTransfer
		}
	}
	t.updateTransferTimestamps(uploadFileSize, elapsed)
	return err
}

// isAtomicUpload reports whether this upload writes to a temporary path
// that must be renamed to the final one on close
func (t *BaseTransfer) isAtomicUpload() bool {
	return t.transferType == TransferUpload && t.effectiveFsPath != t.fsPath
}

// updateTransferTimestamps records the user's first upload/download timestamp
// and fires the corresponding "first transfer" notification, once per connection
func (t *BaseTransfer) updateTransferTimestamps(uploadFileSize, elapsed int64) {
	if t.ErrTransfer != nil {
		return
	}
	if t.transferType == TransferUpload {
		if t.Connection.User.FirstUpload == 0 && !t.Connection.uploadDone.Load() {
			if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, true); err == nil {
				t.Connection.uploadDone.Store(true)
				ExecuteActionNotification(t.Connection, operationFirstUpload, t.fsPath, t.requestPath, "", //nolint:errcheck
					"", "", uploadFileSize, t.ErrTransfer, elapsed, t.metadata)
			}
		}
		return
	}
	if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && t.BytesSent.Load() > 0 {
		if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, false); err == nil {
			t.Connection.downloadDone.Store(true)
			ExecuteActionNotification(t.Connection, operationFirstDownload, t.fsPath, t.requestPath, "", //nolint:errcheck
				"", "", t.BytesSent.Load(), t.ErrTransfer, elapsed, t.metadata)
		}
	}
}

// executeUploadHook runs the upload notification hook; if the hook fails the
// uploaded file is removed and the adjusted file count and size are returned
func (t *BaseTransfer) executeUploadHook(numFiles int, fileSize, elapsed int64) (int, int64) {
	err := ExecuteActionNotification(t.Connection, operationUpload, t.fsPath, t.requestPath, "", "", "",
		fileSize, t.ErrTransfer, elapsed, t.metadata)
	if err != nil {
		if t.ErrTransfer == nil {
			t.ErrTransfer = err
		}
		// try to remove the uploaded file
		err = t.Fs.Remove(t.fsPath, false)
		if err == nil {
			numFiles--
			fileSize = 0
			t.BytesReceived.Store(0)
			t.MinWriteOffset = 0
		} else {
			t.Connection.Log(logger.LevelWarn, "unable to remove path %q after upload hook failure: %v", t.fsPath, err)
		}
	}
	return numFiles, fileSize
}

// getUploadedFiles returns 1 if this upload created a new file, 0 otherwise
func (t *BaseTransfer) getUploadedFiles() int {
	numFiles := 0
	if t.isNewFile {
		numFiles = 1
	}
	return numFiles
}

// updateTimes applies the access/modification times stored via SetTimes, if any
func (t *BaseTransfer) updateTimes() {
	if !t.aTime.IsZero() && !t.mTime.IsZero() {
		err := t.Fs.Chtimes(t.fsPath, t.aTime, t.mTime, false)
		t.Connection.Log(logger.LevelDebug, "set times for file %q, atime: %v, mtime: %v, err: %v",
			t.fsPath, t.aTime, t.mTime, err)
	}
}

// updateQuota updates the disk quota for the user or, if the upload targets a
// virtual folder, for that folder. It returns true if an update was attempted
func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool {
	// Uploads on some filesystem (S3 and similar) are atomic, if there is an error nothing is uploaded
	if t.File == nil && t.ErrTransfer != nil && vfs.HasImplicitAtomicUploads(t.Fs) {
		return false
	}
	sizeDiff := fileSize - t.InitialSize
	if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff != 0) {
		vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
		if err == nil {
			dataprovider.UpdateUserFolderQuota(&vfolder, &t.Connection.User, numFiles,
				sizeDiff, false)
		} else {
			dataprovider.UpdateUserQuota(&t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck
		}
		return true
	}
	return false
}

// HandleThrottle manages bandwidth throttling by sleeping when the transfer
// is ahead of the configured rate.
// NOTE(review): the local name "trasferredBytes" is misspelled
// ("transferredBytes") — cosmetic only, consider renaming separately
func (t *BaseTransfer) HandleThrottle() {
	var wantedBandwidth int64
	var trasferredBytes int64
	if t.transferType == TransferDownload {
		wantedBandwidth = t.Connection.User.DownloadBandwidth
		trasferredBytes = t.BytesSent.Load()
	} else {
		wantedBandwidth = t.Connection.User.UploadBandwidth
		trasferredBytes = t.BytesReceived.Load()
	}
	if wantedBandwidth > 0 {
		// real and wanted elapsed as milliseconds, bytes as kilobytes
		realElapsed := time.Since(t.start).Nanoseconds() / 1000000
		// trasferredBytes / 1024 = KB/s, we multiply for 1000 to get milliseconds
		wantedElapsed := 1000 * (trasferredBytes / 1024) / wantedBandwidth
		if wantedElapsed > realElapsed {
			toSleep := time.Duration(wantedElapsed - realElapsed)
			time.Sleep(toSleep * time.Millisecond)
		}
	}
}
diff --git a/internal/common/transfer_test.go b/internal/common/transfer_test.go
new file mode 100644
index 00000000..1bf0dbe1
--- /dev/null
+++ b/internal/common/transfer_test.go
@@ -0,0 +1,477 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

package common

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/pkg/sftp"
	"github.com/sftpgo/sdk"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// TestTransferUpdateQuota verifies that quota is updated for uploads into the
// user home and into virtual folders, and skipped for implicit-atomic
// filesystems when the transfer failed
func TestTransferUpdateQuota(t *testing.T) {
	conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
	transfer := BaseTransfer{
		Connection:   conn,
		transferType: TransferUpload,
		Fs:           vfs.NewOsFs("", os.TempDir(), "", nil),
	}
	transfer.BytesReceived.Store(123)
	errFake := errors.New("fake error")
	transfer.TransferError(errFake)
	err := transfer.Close()
	if assert.Error(t, err) {
		assert.EqualError(t, err, errFake.Error())
	}
	mappedPath := filepath.Join(os.TempDir(), "vdir")
	vdirPath := "/vdir"
	conn.User.VirtualFolders = append(conn.User.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			MappedPath: mappedPath,
		},
		VirtualPath: vdirPath,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	transfer.ErrTransfer = nil
	transfer.BytesReceived.Store(1)
	transfer.requestPath = "/vdir/file"
	assert.True(t, transfer.updateQuota(1, 0))
	err = transfer.Close()
	assert.NoError(t, err)

	// failed upload on a mock S3-like fs: quota must not be updated
	transfer.ErrTransfer = errFake
	transfer.Fs = newMockOsFs(true, "", "", "S3Fs fake", nil)
	assert.False(t, transfer.updateQuota(1, 0))
}

// TestTransferThrottling checks that HandleThrottle sleeps long enough to
// respect the configured per-user upload and download bandwidth.
// NOTE(review): the upload tolerance is computed from wantedDownloadElapsed;
// possibly wantedUploadElapsed was intended — harmless, it only loosens the bound
func TestTransferThrottling(t *testing.T) {
	u := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:          "test",
			UploadBandwidth:   50,
			DownloadBandwidth: 40,
		},
	}
	fs := vfs.NewOsFs("", os.TempDir(), "", nil)
	testFileSize := int64(131072)
	wantedUploadElapsed := 1000 * (testFileSize / 1024) / u.UploadBandwidth
	wantedDownloadElapsed := 1000 * (testFileSize / 1024) / u.DownloadBandwidth
	// some tolerance
	wantedUploadElapsed -= wantedDownloadElapsed / 10
	wantedDownloadElapsed -= wantedDownloadElapsed / 10
	conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
	transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	transfer.BytesReceived.Store(testFileSize)
	transfer.Connection.UpdateLastActivity()
	startTime := transfer.Connection.GetLastActivity()
	transfer.HandleThrottle()
	elapsed := time.Since(startTime).Nanoseconds() / 1000000
	assert.GreaterOrEqual(t, elapsed, wantedUploadElapsed, "upload bandwidth throttling not respected")
	err := transfer.Close()
	assert.NoError(t, err)

	transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	transfer.BytesSent.Store(testFileSize)
	transfer.Connection.UpdateLastActivity()
	startTime = transfer.Connection.GetLastActivity()

	transfer.HandleThrottle()
	elapsed = time.Since(startTime).Nanoseconds() / 1000000
	assert.GreaterOrEqual(t, elapsed, wantedDownloadElapsed, "download bandwidth throttling not respected")
	err = transfer.Close()
	assert.NoError(t, err)
}

// TestRealPath exercises GetRealFsPath with and without an open file handle
// and with a mismatched path
func TestRealPath(t *testing.T) {
	testFile := filepath.Join(os.TempDir(), "afile.txt")
	fs := vfs.NewOsFs("123", os.TempDir(), "", nil)
	u := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "user",
			HomeDir:  os.TempDir(),
		},
	}
	u.Permissions = make(map[string][]string)
	u.Permissions["/"] = []string{dataprovider.PermAny}
	file, err := os.Create(testFile)
	require.NoError(t, err)
	conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
	transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file",
		TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	rPath := transfer.GetRealFsPath(testFile)
	assert.Equal(t, testFile, rPath)
	rPath = conn.getRealFsPath(testFile)
	assert.Equal(t, testFile, rPath)
	err = transfer.Close()
	assert.NoError(t, err)
	err = file.Close()
	assert.NoError(t, err)
	transfer.File = nil
	rPath = transfer.GetRealFsPath(testFile)
	assert.Equal(t, testFile, rPath)
	rPath = transfer.GetRealFsPath("")
	assert.Empty(t, rPath)
	err = os.Remove(testFile)
	assert.NoError(t, err)
	assert.Len(t, conn.GetTransfers(), 0)
}

// TestTruncate covers Truncate on an open local file (including MaxWriteSize
// adjustment), on a closed file, and the mismatch/unsupported error paths
func TestTruncate(t *testing.T) {
	testFile := filepath.Join(os.TempDir(), "transfer_test_file")
	fs := vfs.NewOsFs("123", os.TempDir(), "", nil)
	u := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "user",
			HomeDir:  os.TempDir(),
		},
	}
	u.Permissions = make(map[string][]string)
	u.Permissions["/"] = []string{dataprovider.PermAny}
	file, err := os.Create(testFile)
	if !assert.NoError(t, err) {
		assert.FailNow(t, "unable to open test file")
	}
	_, err = file.Write([]byte("hello"))
	assert.NoError(t, err)
	conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
	transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 5,
		100, 0, false, fs, dataprovider.TransferQuota{})

	err = conn.SetStat("/transfer_test_file", &StatAttributes{
		Size:  2,
		Flags: StatAttrSize,
	})
	assert.NoError(t, err)
	// MaxWriteSize grows by the 3 truncated bytes: 100 + (5 - 2)
	assert.Equal(t, int64(103), transfer.MaxWriteSize)
	err = transfer.Close()
	assert.NoError(t, err)
	err = file.Close()
	assert.NoError(t, err)
	fi, err := os.Stat(testFile)
	if assert.NoError(t, err) {
		assert.Equal(t, int64(2), fi.Size())
	}

	transfer = NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, 0, 0,
		100, 0, true, fs, dataprovider.TransferQuota{})
	// file.Stat will fail on a closed file
	err = conn.SetStat("/transfer_test_file", &StatAttributes{
		Size:  2,
		Flags: StatAttrSize,
	})
	assert.Error(t, err)
	err = transfer.Close()
	assert.NoError(t, err)

	transfer = NewBaseTransfer(nil, conn, nil, testFile, testFile, "", TransferUpload, 0, 0, 0, 0, true,
		fs, dataprovider.TransferQuota{})
	_, err = transfer.Truncate("mismatch", 0)
	assert.EqualError(t, err, errTransferMismatch.Error())
	_, err = transfer.Truncate(testFile, 0)
	assert.NoError(t, err)
	_, err = transfer.Truncate(testFile, 1)
	assert.EqualError(t, err, vfs.ErrVfsUnsupported.Error())

	err = transfer.Close()
	assert.NoError(t, err)

	err = os.Remove(testFile)
	assert.NoError(t, err)

	assert.Len(t, conn.GetTransfers(), 0)
}

// TestTransferErrors covers ConvertError mapping, error propagation from
// TransferError to Close, and temporary-file cleanup for failed atomic uploads.
// NOTE(review): several "assert.Error(t, err, msg)" calls pass the expected
// error string as a log message, not as an expectation; consider
// assert.EqualError/ErrorContains if the match should be enforced
func TestTransferErrors(t *testing.T) {
	isCancelled := false
	cancelFn := func() {
		isCancelled = true
	}
	testFile := filepath.Join(os.TempDir(), "transfer_test_file")
	fs := vfs.NewOsFs("id", os.TempDir(), "", nil)
	u := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: "test",
			HomeDir:  os.TempDir(),
		},
	}
	err := os.WriteFile(testFile, []byte("test data"), os.ModePerm)
	assert.NoError(t, err)
	file, err := os.Open(testFile)
	if !assert.NoError(t, err) {
		assert.FailNow(t, "unable to open test file")
	}
	conn := NewBaseConnection("id", ProtocolSFTP, "", "", u)
	transfer := NewBaseTransfer(file, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
		0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	pathError := &os.PathError{
		Op:   "test",
		Path: testFile,
		Err:  os.ErrInvalid,
	}
	err = transfer.ConvertError(pathError)
	assert.EqualError(t, err, fmt.Sprintf("%s %s: %s", pathError.Op, "/transfer_test_file", pathError.Err.Error()))
	err = transfer.ConvertError(os.ErrNotExist)
	assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile)
	err = transfer.ConvertError(os.ErrPermission)
	assert.ErrorIs(t, err, sftp.ErrSSHFxPermissionDenied)
	assert.Nil(t, transfer.cancelFn)
	assert.Equal(t, testFile, transfer.GetFsPath())
	transfer.SetMetadata(map[string]string{"key": "val"})
	transfer.SetCancelFn(cancelFn)
	errFake := errors.New("err fake")
	transfer.BytesReceived.Store(9)
	transfer.TransferError(ErrQuotaExceeded)
	assert.True(t, isCancelled)
	// only the first error is kept
	transfer.TransferError(errFake)
	assert.Error(t, transfer.ErrTransfer, ErrQuotaExceeded.Error())
	// the file is closed from the embedding struct before to call close
	err = file.Close()
	assert.NoError(t, err)
	err = transfer.Close()
	if assert.Error(t, err) {
		assert.Error(t, err, ErrQuotaExceeded.Error())
	}
	assert.NoFileExists(t, testFile)

	err = os.WriteFile(testFile, []byte("test data"), os.ModePerm)
	assert.NoError(t, err)
	file, err = os.Open(testFile)
	if !assert.NoError(t, err) {
		assert.FailNow(t, "unable to open test file")
	}
	fsPath := filepath.Join(os.TempDir(), "test_file")
	transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
		fs, dataprovider.TransferQuota{})
	transfer.BytesReceived.Store(9)
	transfer.TransferError(errFake)
	assert.Error(t, transfer.ErrTransfer, errFake.Error())
	// the file is closed from the embedding struct before to call close
	err = file.Close()
	assert.NoError(t, err)
	err = transfer.Close()
	if assert.Error(t, err) {
		assert.Error(t, err, errFake.Error())
	}
	assert.NoFileExists(t, testFile)

	err = os.WriteFile(testFile, []byte("test data"), os.ModePerm)
	assert.NoError(t, err)
	file, err = os.Open(testFile)
	if !assert.NoError(t, err) {
		assert.FailNow(t, "unable to open test file")
	}
	transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
		fs, dataprovider.TransferQuota{})
	transfer.BytesReceived.Store(9)
	// the file is closed from the embedding struct before to call close
	err = file.Close()
	assert.NoError(t, err)
	err = transfer.Close()
	assert.NoError(t, err)
	assert.NoFileExists(t, testFile)
	assert.FileExists(t, fsPath)
	err = os.Remove(fsPath)
	assert.NoError(t, err)

	assert.Len(t, conn.GetTransfers(), 0)
}

// TestRemovePartialCryptoFile verifies that a failed upload on an encrypted
// filesystem removes the unreadable partial file and reports it as deleted
func TestRemovePartialCryptoFile(t *testing.T) {
	testFile := filepath.Join(os.TempDir(), "transfer_test_file")
	fs, err := vfs.NewCryptFs("id", os.TempDir(), "", vfs.CryptFsConfig{Passphrase: kms.NewPlainSecret("secret")})
	require.NoError(t, err)
	u := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:   "test",
			HomeDir:    os.TempDir(),
			QuotaFiles: 1000000,
		},
	}
	conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, "", "", u)
	transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
		0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
	transfer.ErrTransfer = errors.New("test error")
	_, _, err = transfer.getUploadFileSize()
	assert.Error(t, err)
	err = os.WriteFile(testFile, []byte("test data"), os.ModePerm)
	assert.NoError(t, err)
	size, deletedFiles, err := transfer.getUploadFileSize()
	assert.NoError(t, err)
	assert.Equal(t, int64(0), size)
	assert.Equal(t, 1, deletedFiles)
	assert.NoFileExists(t, testFile)
	err = transfer.Close()
	assert.Error(t, err)
	assert.Len(t, conn.GetTransfers(), 0)
}

// TestFTPMode checks the SetFtpMode setter
func TestFTPMode(t *testing.T) {
	conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{})
	transfer := BaseTransfer{
		Connection:   conn,
		transferType: TransferUpload,
		Fs:           vfs.NewOsFs("", os.TempDir(), "", nil),
	}
	transfer.BytesReceived.Store(123)
	assert.Empty(t, transfer.ftpMode)
	transfer.SetFtpMode("active")
	assert.Equal(t, "active", transfer.ftpMode)
}

// TestTransferQuota exercises the data transfer limits: user limit
// conversion, TransferQuota space checks, and CheckRead/CheckWrite with
// total and per-direction allowances
func TestTransferQuota(t *testing.T) {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			TotalDataTransfer:    3,
			UploadDataTransfer:   2,
			DownloadDataTransfer: 1,
		},
	}
	// limits are configured in MB and returned in bytes
	ul, dl, total := user.GetDataTransferLimits()
	assert.Equal(t, int64(2*1048576), ul)
	assert.Equal(t, int64(1*1048576), dl)
	assert.Equal(t, int64(3*1048576), total)
	user.TotalDataTransfer = -1
	user.UploadDataTransfer = -1
	user.DownloadDataTransfer = -1
	ul, dl, total = user.GetDataTransferLimits()
	assert.Equal(t, int64(0), ul)
	assert.Equal(t, int64(0), dl)
	assert.Equal(t, int64(0), total)
	transferQuota := dataprovider.TransferQuota{}
	assert.True(t, transferQuota.HasDownloadSpace())
	assert.True(t, transferQuota.HasUploadSpace())
	transferQuota.TotalSize = -1
	transferQuota.ULSize = -1
	transferQuota.DLSize = -1
	assert.True(t, transferQuota.HasDownloadSpace())
	assert.True(t, transferQuota.HasUploadSpace())
	transferQuota.TotalSize = 100
	transferQuota.AllowedTotalSize = 10
	assert.True(t, transferQuota.HasDownloadSpace())
	assert.True(t, transferQuota.HasUploadSpace())
	transferQuota.AllowedTotalSize = 0
	assert.False(t, transferQuota.HasDownloadSpace())
	assert.False(t, transferQuota.HasUploadSpace())
	transferQuota.TotalSize = 0
	transferQuota.DLSize = 100
	transferQuota.ULSize = 50
	transferQuota.AllowedTotalSize = 0
	assert.False(t, transferQuota.HasDownloadSpace())
	assert.False(t, transferQuota.HasUploadSpace())
	transferQuota.AllowedDLSize = 1
	transferQuota.AllowedULSize = 1
	assert.True(t, transferQuota.HasDownloadSpace())
	assert.True(t, transferQuota.HasUploadSpace())
	transferQuota.AllowedDLSize = -10
	transferQuota.AllowedULSize = -1
	assert.False(t, transferQuota.HasDownloadSpace())
	assert.False(t, transferQuota.HasUploadSpace())

	conn := NewBaseConnection("", ProtocolSFTP, "", "", user)
	transfer := NewBaseTransfer(nil, conn, nil, "file.txt", "file.txt", "/transfer_test_file", TransferUpload,
		0, 0, 0, 0, true, vfs.NewOsFs("", os.TempDir(), "", nil), dataprovider.TransferQuota{})
	err := transfer.CheckRead()
	assert.NoError(t, err)
	err = transfer.CheckWrite()
	assert.NoError(t, err)

	transfer.transferQuota = dataprovider.TransferQuota{
		AllowedTotalSize: 10,
	}
	transfer.BytesReceived.Store(5)
	transfer.BytesSent.Store(4)
	err = transfer.CheckRead()
	assert.NoError(t, err)
	err = transfer.CheckWrite()
	assert.NoError(t, err)

	transfer.BytesSent.Store(6)
	err = transfer.CheckRead()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
	}
	err = transfer.CheckWrite()
	assert.True(t, conn.IsQuotaExceededError(err))

	transferQuota = dataprovider.TransferQuota{
		AllowedTotalSize: 0,
		AllowedULSize:    10,
		AllowedDLSize:    5,
	}
	transfer.transferQuota = transferQuota
	assert.Equal(t, transferQuota, transfer.GetTransferQuota())
	err = transfer.CheckRead()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
	}
	err = transfer.CheckWrite()
	assert.NoError(t, err)

	transfer.BytesReceived.Store(11)
	err = transfer.CheckRead()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
	}
	err = transfer.CheckWrite()
	assert.True(t, conn.IsQuotaExceededError(err))

	err = transfer.Close()
	assert.NoError(t, err)
	assert.Len(t, conn.GetTransfers(), 0)
	assert.Equal(t, int32(0), Connections.GetTotalTransfers())
}

// TestUploadOutsideHomeRenameError verifies temporary-file cleanup when an
// atomic upload targeting a path outside the home dir cannot be renamed
func TestUploadOutsideHomeRenameError(t *testing.T) {
	oldTempPath := Config.TempPath

	conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
	transfer := BaseTransfer{
		Connection:   conn,
		transferType: TransferUpload,
		Fs:           vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), "", nil),
	}
	transfer.BytesReceived.Store(123)

	fileName := filepath.Join(os.TempDir(), "_temp")
	err := os.WriteFile(fileName, []byte(`data`), 0644)
	assert.NoError(t, err)

	transfer.effectiveFsPath = fileName
	// no TempPath configured: nothing is removed
	res := transfer.checkUploadOutsideHomeDir(os.ErrPermission)
	assert.Equal(t, 0, res)

	Config.TempPath = filepath.Clean(os.TempDir())
	res = transfer.checkUploadOutsideHomeDir(nil)
	assert.Equal(t, 0, res)
	assert.Greater(t, transfer.BytesReceived.Load(), int64(0))
	res = transfer.checkUploadOutsideHomeDir(os.ErrPermission)
	assert.Equal(t, 1, res)
	assert.Equal(t, int64(0), transfer.BytesReceived.Load())
	assert.NoFileExists(t, fileName)

	Config.TempPath = oldTempPath
}
diff --git a/internal/common/transferschecker.go b/internal/common/transferschecker.go
new file mode 100644
index 00000000..345c95b7
--- /dev/null
+++ b/internal/common/transferschecker.go
@@ -0,0 +1,329 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

package common

import (
	"errors"
	"sync"
	"time"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// overquotaTransfer identifies a single active transfer that exceeded its quota
type overquotaTransfer struct {
	ConnID       string // owning connection ID
	TransferID   int64  // transfer ID within the connection
	TransferType int    // TransferUpload or TransferDownload
}

// uploadAggregationKey groups concurrent uploads by user and target folder
type uploadAggregationKey struct {
	Username   string
	FolderName string
}

// TransfersChecker defines the interface that transfer checkers must implement.
+// A transfer checker ensure that multiple concurrent transfers does not exceeded +// the remaining user quota +type TransfersChecker interface { + AddTransfer(transfer dataprovider.ActiveTransfer) + RemoveTransfer(ID int64, connectionID string) + UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) + GetOverquotaTransfers() []overquotaTransfer +} + +func getTransfersChecker(isShared int) TransfersChecker { + if isShared == 1 { + logger.Info(logSender, "", "using provider transfer checker") + return &transfersCheckerDB{} + } + logger.Info(logSender, "", "using memory transfer checker") + return &transfersCheckerMem{} +} + +type baseTransferChecker struct { + transfers []dataprovider.ActiveTransfer +} + +func (t *baseTransferChecker) isDataTransferExceeded(user dataprovider.User, transfer dataprovider.ActiveTransfer, ulSize, + dlSize int64, +) bool { + ulQuota, dlQuota, totalQuota := user.GetDataTransferLimits() + if totalQuota > 0 { + allowedSize := totalQuota - (user.UsedUploadDataTransfer + user.UsedDownloadDataTransfer) + if ulSize+dlSize > allowedSize { + return transfer.CurrentDLSize > 0 || transfer.CurrentULSize > 0 + } + } + if dlQuota > 0 { + allowedSize := dlQuota - user.UsedDownloadDataTransfer + if dlSize > allowedSize { + return transfer.CurrentDLSize > 0 + } + } + if ulQuota > 0 { + allowedSize := ulQuota - user.UsedUploadDataTransfer + if ulSize > allowedSize { + return transfer.CurrentULSize > 0 + } + } + return false +} + +func (t *baseTransferChecker) getRemainingDiskQuota(user dataprovider.User, folderName string) (int64, error) { + var result int64 + + if folderName != "" { + for _, folder := range user.VirtualFolders { + if folder.Name == folderName { + if folder.QuotaSize > 0 { + return folder.QuotaSize - folder.UsedQuotaSize, nil + } + } + } + } else { + if user.QuotaSize > 0 { + return user.QuotaSize - user.UsedQuotaSize, nil + } + } + + return result, errors.New("no quota limit defined") +} + +func (t 
*baseTransferChecker) aggregateTransfersByUser(usersToFetch map[string]bool, +) (map[string]bool, map[string][]dataprovider.ActiveTransfer) { + aggregations := make(map[string][]dataprovider.ActiveTransfer) + for _, transfer := range t.transfers { + aggregations[transfer.Username] = append(aggregations[transfer.Username], transfer) + if len(aggregations[transfer.Username]) > 1 { + if _, ok := usersToFetch[transfer.Username]; !ok { + usersToFetch[transfer.Username] = false + } + } + } + + return usersToFetch, aggregations +} + +func (t *baseTransferChecker) aggregateUploadTransfers() (map[string]bool, map[int][]dataprovider.ActiveTransfer) { + usersToFetch := make(map[string]bool) + aggregations := make(map[int][]dataprovider.ActiveTransfer) + var keys []uploadAggregationKey + + for _, transfer := range t.transfers { + if transfer.Type != TransferUpload { + continue + } + key := -1 + for idx, k := range keys { + if k.Username == transfer.Username && k.FolderName == transfer.FolderName { + key = idx + break + } + } + if key == -1 { + key = len(keys) + } + keys = append(keys, uploadAggregationKey{ + Username: transfer.Username, + FolderName: transfer.FolderName, + }) + + aggregations[key] = append(aggregations[key], transfer) + if len(aggregations[key]) > 1 { + if transfer.FolderName != "" { + usersToFetch[transfer.Username] = true + } else { + if _, ok := usersToFetch[transfer.Username]; !ok { + usersToFetch[transfer.Username] = false + } + } + } + } + + return usersToFetch, aggregations +} + +func (t *baseTransferChecker) getUsersToCheck(usersToFetch map[string]bool) (map[string]dataprovider.User, error) { + users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch) + if err != nil { + return nil, err + } + + usersMap := make(map[string]dataprovider.User) + + for _, user := range users { + usersMap[user.Username] = user + } + + return usersMap, nil +} + +func (t *baseTransferChecker) getOverquotaTransfers(usersToFetch map[string]bool, + uploadAggregations 
map[int][]dataprovider.ActiveTransfer, + userAggregations map[string][]dataprovider.ActiveTransfer, +) []overquotaTransfer { + if len(usersToFetch) == 0 { + return nil + } + usersMap, err := t.getUsersToCheck(usersToFetch) + if err != nil { + logger.Warn(logSender, "", "unable to check transfers, error getting users quota: %v", err) + return nil + } + + var overquotaTransfers []overquotaTransfer + + for _, transfers := range uploadAggregations { + username := transfers[0].Username + folderName := transfers[0].FolderName + remaningDiskQuota, err := t.getRemainingDiskQuota(usersMap[username], folderName) + if err != nil { + continue + } + var usedDiskQuota int64 + for _, tr := range transfers { + // We optimistically assume that a cloud transfer that replaces an existing + // file will be successful + usedDiskQuota += tr.CurrentULSize - tr.TruncatedSize + } + logger.Debug(logSender, "", "username %q, folder %q, concurrent transfers: %v, remaining disk quota (bytes): %v, disk quota used in ongoing transfers (bytes): %v", + username, folderName, len(transfers), remaningDiskQuota, usedDiskQuota) + if usedDiskQuota > remaningDiskQuota { + for _, tr := range transfers { + if tr.CurrentULSize > tr.TruncatedSize { + overquotaTransfers = append(overquotaTransfers, overquotaTransfer{ + ConnID: tr.ConnID, + TransferID: tr.ID, + TransferType: tr.Type, + }) + } + } + } + } + + for username, transfers := range userAggregations { + var ulSize, dlSize int64 + for _, tr := range transfers { + ulSize += tr.CurrentULSize + dlSize += tr.CurrentDLSize + } + logger.Debug(logSender, "", "username %q, concurrent transfers: %v, quota (bytes) used in ongoing transfers, ul: %v, dl: %v", + username, len(transfers), ulSize, dlSize) + for _, tr := range transfers { + if t.isDataTransferExceeded(usersMap[username], tr, ulSize, dlSize) { + overquotaTransfers = append(overquotaTransfers, overquotaTransfer{ + ConnID: tr.ConnID, + TransferID: tr.ID, + TransferType: tr.Type, + }) + } + } + } + + 
return overquotaTransfers +} + +type transfersCheckerMem struct { + sync.RWMutex + baseTransferChecker +} + +func (t *transfersCheckerMem) AddTransfer(transfer dataprovider.ActiveTransfer) { + t.Lock() + defer t.Unlock() + + t.transfers = append(t.transfers, transfer) +} + +func (t *transfersCheckerMem) RemoveTransfer(ID int64, connectionID string) { + t.Lock() + defer t.Unlock() + + for idx, transfer := range t.transfers { + if transfer.ID == ID && transfer.ConnID == connectionID { + lastIdx := len(t.transfers) - 1 + t.transfers[idx] = t.transfers[lastIdx] + t.transfers = t.transfers[:lastIdx] + return + } + } +} + +func (t *transfersCheckerMem) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) { + t.Lock() + defer t.Unlock() + + for idx := range t.transfers { + if t.transfers[idx].ID == ID && t.transfers[idx].ConnID == connectionID { + t.transfers[idx].CurrentDLSize = dlSize + t.transfers[idx].CurrentULSize = ulSize + t.transfers[idx].UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + return + } + } +} + +func (t *transfersCheckerMem) GetOverquotaTransfers() []overquotaTransfer { + t.RLock() + + usersToFetch, uploadAggregations := t.aggregateUploadTransfers() + usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch) + + t.RUnlock() + + return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations) +} + +type transfersCheckerDB struct { + baseTransferChecker + lastCleanup time.Time +} + +func (t *transfersCheckerDB) AddTransfer(transfer dataprovider.ActiveTransfer) { + dataprovider.AddActiveTransfer(transfer) +} + +func (t *transfersCheckerDB) RemoveTransfer(ID int64, connectionID string) { + dataprovider.RemoveActiveTransfer(ID, connectionID) +} + +func (t *transfersCheckerDB) UpdateTransferCurrentSizes(ulSize, dlSize, ID int64, connectionID string) { + dataprovider.UpdateActiveTransferSizes(ulSize, dlSize, ID, connectionID) +} + +func (t *transfersCheckerDB) GetOverquotaTransfers() 
[]overquotaTransfer { + if t.lastCleanup.IsZero() || t.lastCleanup.Add(periodicTimeoutCheckInterval*15).Before(time.Now()) { + before := time.Now().Add(-periodicTimeoutCheckInterval * 5) + err := dataprovider.CleanupActiveTransfers(before) + logger.Debug(logSender, "", "cleanup active transfers completed, err: %v", err) + if err == nil { + t.lastCleanup = time.Now() + } + } + var err error + from := time.Now().Add(-periodicTimeoutCheckInterval * 2) + t.transfers, err = dataprovider.GetActiveTransfers(from) + if err != nil { + logger.Error(logSender, "", "unable to check overquota transfers, error getting active transfers: %v", err) + return nil + } + + usersToFetch, uploadAggregations := t.aggregateUploadTransfers() + usersToFetch, userAggregations := t.aggregateTransfersByUser(usersToFetch) + + return t.getOverquotaTransfers(usersToFetch, uploadAggregations, userAggregations) +} diff --git a/internal/common/transferschecker_test.go b/internal/common/transferschecker_test.go new file mode 100644 index 00000000..0528de61 --- /dev/null +++ b/internal/common/transferschecker_test.go @@ -0,0 +1,768 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package common + +import ( + "fmt" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/rs/xid" + "github.com/sftpgo/sdk" + "github.com/stretchr/testify/assert" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +func TestTransfersCheckerDiskQuota(t *testing.T) { + username := "transfers_check_username" + folderName := "test_transfers_folder" + groupName := "test_transfers_group" + vdirPath := "/vdir" + group := dataprovider.Group{ + BaseGroup: sdk.BaseGroup{ + Name: groupName, + }, + UserSettings: dataprovider.GroupUserSettings{ + BaseGroupUserSettings: sdk.BaseGroupUserSettings{ + QuotaSize: 120, + }, + }, + } + folder := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), folderName), + } + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: "testpwd", + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + QuotaSize: 0, // the quota size defined for the group is used + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + VirtualFolders: []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + QuotaSize: 100, + }, + }, + Groups: []sdk.GroupMapping{ + { + Name: groupName, + Type: sdk.GroupTypePrimary, + }, + }, + } + err := dataprovider.AddGroup(&group, "", "", "") + assert.NoError(t, err) + group, err = dataprovider.GroupExists(groupName) + assert.NoError(t, err) + err = dataprovider.AddFolder(&folder, "", "", "") + assert.NoError(t, err) + assert.Equal(t, int64(120), group.UserSettings.QuotaSize) + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + user, err = dataprovider.GetUserWithGroupSettings(username, "") + assert.NoError(t, err) + + connID1 := xid.New().String() + fsUser, err := user.GetFilesystemForPath("/file1", connID1) + 
assert.NoError(t, err) + conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "", user) + fakeConn1 := &fakeConnection{ + BaseConnection: conn1, + } + transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), + "/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) + transfer1.BytesReceived.Store(150) + err = Connections.Add(fakeConn1) + assert.NoError(t, err) + // the transferschecker will do nothing if there is only one ongoing transfer + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + + connID2 := xid.New().String() + conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "", user) + fakeConn2 := &fakeConnection{ + BaseConnection: conn2, + } + transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), + "/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{}) + transfer1.BytesReceived.Store(50) + transfer2.BytesReceived.Store(60) + err = Connections.Add(fakeConn2) + assert.NoError(t, err) + + connID3 := xid.New().String() + conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user) + fakeConn3 := &fakeConnection{ + BaseConnection: conn3, + } + transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"), + "/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) + transfer3.BytesReceived.Store(60) // this value will be ignored, this is a download + err = Connections.Add(fakeConn3) + assert.NoError(t, err) + + // the transfers are not overquota + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + + transfer1.BytesReceived.Store(80) // truncated size will be subtracted, we are not overquota + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, 
transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + transfer1.BytesReceived.Store(120) + // we are now overquota + // if another check is in progress nothing is done + Connections.transfersCheckStatus.Store(true) + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + Connections.transfersCheckStatus.Store(false) + + Connections.checkTransfers() + assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort) + assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort), transfer2.errAbort) + assert.True(t, conn1.IsQuotaExceededError(transfer1.GetAbortError())) + assert.Nil(t, transfer3.errAbort) + assert.True(t, conn3.IsQuotaExceededError(transfer3.GetAbortError())) + // update the user quota size + group.UserSettings.QuotaSize = 1000 + err = dataprovider.UpdateGroup(&group, []string{username}, "", "", "") + assert.NoError(t, err) + transfer1.errAbort = nil + transfer2.errAbort = nil + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + + group.UserSettings.QuotaSize = 0 + err = dataprovider.UpdateGroup(&group, []string{username}, "", "", "") + assert.NoError(t, err) + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + // now check a public folder + transfer1.BytesReceived.Store(0) + transfer2.BytesReceived.Store(0) + connID4 := xid.New().String() + fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4) + assert.NoError(t, err) + conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user) + fakeConn4 := &fakeConnection{ + BaseConnection: conn4, + } + transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(os.TempDir(), folderName, "file1"), + filepath.Join(os.TempDir(), folderName, "file1"), path.Join(vdirPath, "/file1"), TransferUpload, 0, 0, + 
100, 0, true, fsFolder, dataprovider.TransferQuota{}) + err = Connections.Add(fakeConn4) + assert.NoError(t, err) + connID5 := xid.New().String() + conn5 := NewBaseConnection(connID5, ProtocolSFTP, "", "", user) + fakeConn5 := &fakeConnection{ + BaseConnection: conn5, + } + transfer5 := NewBaseTransfer(nil, conn5, nil, filepath.Join(os.TempDir(), folderName, "file2"), + filepath.Join(os.TempDir(), folderName, "file2"), path.Join(vdirPath, "/file2"), TransferUpload, 0, 0, + 100, 0, true, fsFolder, dataprovider.TransferQuota{}) + + err = Connections.Add(fakeConn5) + assert.NoError(t, err) + transfer4.BytesReceived.Store(50) + transfer5.BytesReceived.Store(40) + Connections.checkTransfers() + assert.Nil(t, transfer4.errAbort) + assert.Nil(t, transfer5.errAbort) + transfer5.BytesReceived.Store(60) + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + assert.True(t, conn1.IsQuotaExceededError(transfer4.errAbort)) + assert.True(t, conn2.IsQuotaExceededError(transfer5.errAbort)) + + if dataprovider.GetProviderStatus().Driver != dataprovider.MemoryDataProviderName { + providerConf := dataprovider.GetProviderConfig() + err = dataprovider.Close() + assert.NoError(t, err) + + transfer4.errAbort = nil + transfer5.errAbort = nil + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + assert.Nil(t, transfer3.errAbort) + assert.Nil(t, transfer4.errAbort) + assert.Nil(t, transfer5.errAbort) + + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + } + + err = transfer1.Close() + assert.NoError(t, err) + err = transfer2.Close() + assert.NoError(t, err) + err = transfer3.Close() + assert.NoError(t, err) + err = transfer4.Close() + assert.NoError(t, err) + err = transfer5.Close() + assert.NoError(t, err) + + Connections.Remove(fakeConn1.GetID()) + Connections.Remove(fakeConn2.GetID()) + 
Connections.Remove(fakeConn3.GetID()) + Connections.Remove(fakeConn4.GetID()) + Connections.Remove(fakeConn5.GetID()) + stats := Connections.GetStats("") + assert.Len(t, stats, 0) + assert.Equal(t, int32(0), Connections.GetTotalTransfers()) + + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.DeleteFolder(folderName, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(filepath.Join(os.TempDir(), folderName)) + assert.NoError(t, err) + err = dataprovider.DeleteGroup(groupName, "", "", "") + assert.NoError(t, err) +} + +func TestTransferCheckerTransferQuota(t *testing.T) { + username := "transfers_check_username" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: "test_pwd", + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + TotalDataTransfer: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + + connID1 := xid.New().String() + fsUser, err := user.GetFilesystemForPath("/file1", connID1) + assert.NoError(t, err) + conn1 := NewBaseConnection(connID1, ProtocolSFTP, "", "192.168.1.1", user) + fakeConn1 := &fakeConnection{ + BaseConnection: conn1, + } + transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), + "/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) + transfer1.BytesReceived.Store(150) + err = Connections.Add(fakeConn1) + assert.NoError(t, err) + // the transferschecker will do nothing if there is only one ongoing transfer + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + + connID2 := xid.New().String() + conn2 := NewBaseConnection(connID2, ProtocolSFTP, "", "127.0.0.1", user) + fakeConn2 := &fakeConnection{ + BaseConnection: conn2, + } + 
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), + "/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) + transfer2.BytesReceived.Store(150) + err = Connections.Add(fakeConn2) + assert.NoError(t, err) + Connections.checkTransfers() + assert.Nil(t, transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + // now test overquota + transfer1.BytesReceived.Store(1024*1024 + 1) + transfer2.BytesReceived.Store(0) + Connections.checkTransfers() + assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort) + assert.Nil(t, transfer2.errAbort) + transfer1.errAbort = nil + transfer1.BytesReceived.Store(1024*1024 + 1) + transfer2.BytesReceived.Store(1024) + Connections.checkTransfers() + assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort)) + assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort)) + transfer1.BytesReceived.Store(0) + transfer2.BytesReceived.Store(0) + transfer1.errAbort = nil + transfer2.errAbort = nil + + err = transfer1.Close() + assert.NoError(t, err) + err = transfer2.Close() + assert.NoError(t, err) + Connections.Remove(fakeConn1.GetID()) + Connections.Remove(fakeConn2.GetID()) + + connID3 := xid.New().String() + conn3 := NewBaseConnection(connID3, ProtocolSFTP, "", "", user) + fakeConn3 := &fakeConnection{ + BaseConnection: conn3, + } + transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), + "/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) + transfer3.BytesSent.Store(150) + err = Connections.Add(fakeConn3) + assert.NoError(t, err) + + connID4 := xid.New().String() + conn4 := NewBaseConnection(connID4, ProtocolSFTP, "", "", user) + fakeConn4 := &fakeConnection{ + BaseConnection: conn4, + } + transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), 
filepath.Join(user.HomeDir, "file2"), + "/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) + transfer4.BytesSent.Store(150) + err = Connections.Add(fakeConn4) + assert.NoError(t, err) + Connections.checkTransfers() + assert.Nil(t, transfer3.errAbort) + assert.Nil(t, transfer4.errAbort) + + transfer3.BytesSent.Store(512 * 1024) + transfer4.BytesSent.Store(512*1024 + 1) + Connections.checkTransfers() + if assert.Error(t, transfer3.errAbort) { + assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error()) + } + if assert.Error(t, transfer4.errAbort) { + assert.Contains(t, transfer4.errAbort.Error(), ErrReadQuotaExceeded.Error()) + } + err = transfer3.Close() + assert.NoError(t, err) + err = transfer4.Close() + assert.NoError(t, err) + + Connections.Remove(fakeConn3.GetID()) + Connections.Remove(fakeConn4.GetID()) + stats := Connections.GetStats("") + assert.Len(t, stats, 0) + assert.Equal(t, int32(0), Connections.GetTotalTransfers()) + + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestAggregateTransfers(t *testing.T) { + checker := transfersCheckerMem{} + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "1", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + usersToFetch, aggregations := checker.aggregateUploadTransfers() + assert.Len(t, usersToFetch, 0) + assert.Len(t, aggregations, 1) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferDownload, + ConnID: "2", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 0, + CurrentDLSize: 100, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), 
+ }) + + usersToFetch, aggregations = checker.aggregateUploadTransfers() + assert.Len(t, usersToFetch, 0) + assert.Len(t, aggregations, 1) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "3", + Username: "user", + FolderName: "folder", + TruncatedSize: 0, + CurrentULSize: 10, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateUploadTransfers() + assert.Len(t, usersToFetch, 0) + assert.Len(t, aggregations, 2) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "4", + Username: "user1", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateUploadTransfers() + assert.Len(t, usersToFetch, 0) + assert.Len(t, aggregations, 3) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "5", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateUploadTransfers() + assert.Len(t, usersToFetch, 1) + val, ok := usersToFetch["user"] + assert.True(t, ok) + assert.False(t, val) + assert.Len(t, aggregations, 3) + aggregate, ok := aggregations[0] + assert.True(t, ok) + assert.Len(t, aggregate, 2) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "6", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = 
checker.aggregateUploadTransfers() + assert.Len(t, usersToFetch, 1) + val, ok = usersToFetch["user"] + assert.True(t, ok) + assert.False(t, val) + assert.Len(t, aggregations, 3) + aggregate, ok = aggregations[0] + assert.True(t, ok) + assert.Len(t, aggregate, 3) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "7", + Username: "user", + FolderName: "folder", + TruncatedSize: 0, + CurrentULSize: 10, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateUploadTransfers() + assert.Len(t, usersToFetch, 1) + val, ok = usersToFetch["user"] + assert.True(t, ok) + assert.True(t, val) + assert.Len(t, aggregations, 3) + aggregate, ok = aggregations[0] + assert.True(t, ok) + assert.Len(t, aggregate, 3) + aggregate, ok = aggregations[1] + assert.True(t, ok) + assert.Len(t, aggregate, 2) + + checker.AddTransfer(dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferUpload, + ConnID: "8", + Username: "user", + FolderName: "", + TruncatedSize: 0, + CurrentULSize: 100, + CurrentDLSize: 0, + CreatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + UpdatedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + }) + + usersToFetch, aggregations = checker.aggregateUploadTransfers() + assert.Len(t, usersToFetch, 1) + val, ok = usersToFetch["user"] + assert.True(t, ok) + assert.True(t, val) + assert.Len(t, aggregations, 3) + aggregate, ok = aggregations[0] + assert.True(t, ok) + assert.Len(t, aggregate, 4) + aggregate, ok = aggregations[1] + assert.True(t, ok) + assert.Len(t, aggregate, 2) +} + +func TestDataTransferExceeded(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + TotalDataTransfer: 1, + }, + } + transfer := dataprovider.ActiveTransfer{ + CurrentULSize: 0, + CurrentDLSize: 0, + } + user.UsedDownloadDataTransfer = 1024 * 1024 + user.UsedUploadDataTransfer = 512 * 1024 + checker := 
transfersCheckerMem{} + res := checker.isDataTransferExceeded(user, transfer, 100, 100) + assert.False(t, res) + transfer.CurrentULSize = 1 + res = checker.isDataTransferExceeded(user, transfer, 100, 100) + assert.True(t, res) + user.UsedDownloadDataTransfer = 512*1024 - 100 + user.UsedUploadDataTransfer = 512*1024 - 100 + res = checker.isDataTransferExceeded(user, transfer, 100, 100) + assert.False(t, res) + res = checker.isDataTransferExceeded(user, transfer, 101, 100) + assert.True(t, res) + + user.TotalDataTransfer = 0 + user.DownloadDataTransfer = 1 + user.UsedDownloadDataTransfer = 512 * 1024 + transfer.CurrentULSize = 0 + transfer.CurrentDLSize = 100 + res = checker.isDataTransferExceeded(user, transfer, 0, 512*1024) + assert.False(t, res) + res = checker.isDataTransferExceeded(user, transfer, 0, 512*1024+1) + assert.True(t, res) + + user.DownloadDataTransfer = 0 + user.UploadDataTransfer = 1 + user.UsedUploadDataTransfer = 512 * 1024 + transfer.CurrentULSize = 0 + transfer.CurrentDLSize = 0 + res = checker.isDataTransferExceeded(user, transfer, 512*1024+1, 0) + assert.False(t, res) + transfer.CurrentULSize = 1 + res = checker.isDataTransferExceeded(user, transfer, 512*1024+1, 0) + assert.True(t, res) +} + +func TestGetUsersForQuotaCheck(t *testing.T) { + usersToFetch := make(map[string]bool) + for i := 0; i < 70; i++ { + usersToFetch[fmt.Sprintf("user%v", i)] = i%2 == 0 + } + + users, err := dataprovider.GetUsersForQuotaCheck(usersToFetch) + assert.NoError(t, err) + assert.Len(t, users, 0) + + for i := 0; i < 60; i++ { + folder := vfs.BaseVirtualFolder{ + Name: fmt.Sprintf("f%v", i), + MappedPath: filepath.Join(os.TempDir(), fmt.Sprintf("f%v", i)), + } + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: fmt.Sprintf("user%v", i), + Password: "pwd", + HomeDir: filepath.Join(os.TempDir(), fmt.Sprintf("user%v", i)), + Status: 1, + QuotaSize: 120, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + VirtualFolders: 
[]vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folder.Name, + }, + VirtualPath: "/vfolder", + QuotaSize: 100, + }, + }, + } + err = dataprovider.AddFolder(&folder, "", "", "") + assert.NoError(t, err) + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + err = dataprovider.UpdateVirtualFolderQuota(&vfs.BaseVirtualFolder{Name: fmt.Sprintf("f%v", i)}, 1, 50, false) + assert.NoError(t, err) + } + + users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch) + assert.NoError(t, err) + assert.Len(t, users, 60) + + for _, user := range users { + userIdxStr := strings.Replace(user.Username, "user", "", 1) + userIdx, err := strconv.Atoi(userIdxStr) + assert.NoError(t, err) + if userIdx%2 == 0 { + if assert.Len(t, user.VirtualFolders, 1, user.Username) { + assert.Equal(t, int64(100), user.VirtualFolders[0].QuotaSize) + assert.Equal(t, int64(50), user.VirtualFolders[0].UsedQuotaSize) + } + } else { + switch dataprovider.GetProviderStatus().Driver { + case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName, + dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName: + assert.Len(t, user.VirtualFolders, 0, user.Username) + } + } + ul, dl, total := user.GetDataTransferLimits() + assert.Equal(t, int64(0), ul) + assert.Equal(t, int64(0), dl) + assert.Equal(t, int64(0), total) + } + + for i := 0; i < 60; i++ { + err = dataprovider.DeleteUser(fmt.Sprintf("user%v", i), "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteFolder(fmt.Sprintf("f%v", i), "", "", "") + assert.NoError(t, err) + } + + users, err = dataprovider.GetUsersForQuotaCheck(usersToFetch) + assert.NoError(t, err) + assert.Len(t, users, 0) +} + +func TestDBTransferChecker(t *testing.T) { + if !isDbTransferCheckerSupported() { + t.Skip("this test is not supported with the current database provider") + } + providerConf := dataprovider.GetProviderConfig() + err := dataprovider.Close() + assert.NoError(t, err) + 
providerConf.IsShared = 1 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + c := getTransfersChecker(1) + checker, ok := c.(*transfersCheckerDB) + assert.True(t, ok) + assert.True(t, checker.lastCleanup.IsZero()) + transfer1 := dataprovider.ActiveTransfer{ + ID: 1, + Type: TransferDownload, + ConnID: xid.New().String(), + Username: "user1", + FolderName: "folder1", + IP: "127.0.0.1", + } + checker.AddTransfer(transfer1) + transfers, err := dataprovider.GetActiveTransfers(time.Now().Add(24 * time.Hour)) + assert.NoError(t, err) + assert.Len(t, transfers, 0) + transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2)) + assert.NoError(t, err) + var createdAt, updatedAt int64 + if assert.Len(t, transfers, 1) { + transfer := transfers[0] + assert.Equal(t, transfer1.ID, transfer.ID) + assert.Equal(t, transfer1.Type, transfer.Type) + assert.Equal(t, transfer1.ConnID, transfer.ConnID) + assert.Equal(t, transfer1.Username, transfer.Username) + assert.Equal(t, transfer1.IP, transfer.IP) + assert.Equal(t, transfer1.FolderName, transfer.FolderName) + assert.Greater(t, transfer.CreatedAt, int64(0)) + assert.Greater(t, transfer.UpdatedAt, int64(0)) + assert.Equal(t, int64(0), transfer.CurrentDLSize) + assert.Equal(t, int64(0), transfer.CurrentULSize) + createdAt = transfer.CreatedAt + updatedAt = transfer.UpdatedAt + } + time.Sleep(100 * time.Millisecond) + checker.UpdateTransferCurrentSizes(100, 150, transfer1.ID, transfer1.ConnID) + transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2)) + assert.NoError(t, err) + if assert.Len(t, transfers, 1) { + transfer := transfers[0] + assert.Equal(t, int64(150), transfer.CurrentDLSize) + assert.Equal(t, int64(100), transfer.CurrentULSize) + assert.Equal(t, createdAt, transfer.CreatedAt) + assert.Greater(t, transfer.UpdatedAt, updatedAt) + } + res := checker.GetOverquotaTransfers() + assert.Len(t, res, 0) + + 
checker.RemoveTransfer(transfer1.ID, transfer1.ConnID) + transfers, err = dataprovider.GetActiveTransfers(time.Now().Add(-periodicTimeoutCheckInterval * 2)) + assert.NoError(t, err) + assert.Len(t, transfers, 0) + + err = dataprovider.Close() + assert.NoError(t, err) + res = checker.GetOverquotaTransfers() + assert.Len(t, res, 0) + providerConf.IsShared = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func isDbTransferCheckerSupported() bool { + // SQLite shares the implementation with other SQL-based provider but it makes no sense + // to use it outside test cases + switch dataprovider.GetProviderStatus().Driver { + case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName, + dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName: + return true + default: + return false + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 00000000..58863c6d --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,2285 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +// Package config manages the configuration +package config + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + + kmsplugin "github.com/sftpgo/sdk/plugin/kms" + "github.com/spf13/viper" + "github.com/subosito/gotenv" + + "github.com/drakkan/sftpgo/v2/internal/acme" + "github.com/drakkan/sftpgo/v2/internal/command" + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/ftpd" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/httpd" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/telemetry" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +const ( + logSender = "config" + // configName defines the name for config file. 
+ // This name does not include the extension, viper will search for files + // with supported extensions such as "sftpgo.json", "sftpgo.yaml" and so on + configName = "sftpgo" + // ConfigEnvPrefix defines a prefix that environment variables will use + configEnvPrefix = "sftpgo" + envFileMaxSize = 1048576 +) + +var ( + globalConf globalConfig + defaultInstallCodeHint = "Installation code" + defaultSFTPDBinding = sftpd.Binding{ + Address: "", + Port: 2022, + ApplyProxyConfig: true, + } + defaultFTPDBinding = ftpd.Binding{ + Address: "", + Port: 0, + ApplyProxyConfig: true, + TLSMode: 0, + CertificateFile: "", + CertificateKeyFile: "", + MinTLSVersion: 12, + ForcePassiveIP: "", + PassiveIPOverrides: nil, + PassiveHost: "", + ClientAuthType: 0, + TLSCipherSuites: nil, + PassiveConnectionsSecurity: 0, + ActiveConnectionsSecurity: 0, + Debug: false, + } + defaultWebDAVDBinding = webdavd.Binding{ + Address: "", + Port: 0, + EnableHTTPS: false, + CertificateFile: "", + CertificateKeyFile: "", + MinTLSVersion: 12, + ClientAuthType: 0, + TLSCipherSuites: nil, + Protocols: nil, + Prefix: "", + ProxyMode: 0, + ProxyAllowed: nil, + ClientIPProxyHeader: "", + ClientIPHeaderDepth: 0, + DisableWWWAuthHeader: false, + } + defaultHTTPDBinding = httpd.Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: true, + EnableRESTAPI: true, + EnabledLoginMethods: 0, + DisabledLoginMethods: 0, + EnableHTTPS: false, + CertificateFile: "", + CertificateKeyFile: "", + MinTLSVersion: 12, + ClientAuthType: 0, + TLSCipherSuites: nil, + Protocols: nil, + ProxyMode: 0, + ProxyAllowed: nil, + ClientIPProxyHeader: "", + ClientIPHeaderDepth: 0, + HideLoginURL: 0, + RenderOpenAPI: true, + BaseURL: "", + Languages: []string{"en"}, + OIDC: httpd.OIDC{ + ClientID: "", + ClientSecret: "", + ClientSecretFile: "", + ConfigURL: "", + RedirectBaseURL: "", + UsernameField: "", + RoleField: "", + ImplicitRoles: false, + Scopes: []string{"openid", "profile", "email"}, + CustomFields: 
[]string{}, + InsecureSkipSignatureCheck: false, + Debug: false, + }, + Security: httpd.SecurityConf{ + Enabled: false, + AllowedHosts: nil, + AllowedHostsAreRegex: false, + HostsProxyHeaders: nil, + HTTPSRedirect: false, + HTTPSHost: "", + HTTPSProxyHeaders: nil, + STSSeconds: 0, + STSIncludeSubdomains: false, + STSPreload: false, + ContentTypeNosniff: false, + ContentSecurityPolicy: "", + PermissionsPolicy: "", + CrossOriginOpenerPolicy: "", + CrossOriginResourcePolicy: "", + CrossOriginEmbedderPolicy: "", + CacheControl: "", + }, + Branding: httpd.Branding{}, + } + defaultRateLimiter = common.RateLimiterConfig{ + Average: 0, + Period: 1000, + Burst: 1, + Type: 2, + Protocols: []string{common.ProtocolSSH, common.ProtocolFTP, common.ProtocolWebDAV, common.ProtocolHTTP}, + GenerateDefenderEvents: false, + EntriesSoftLimit: 100, + EntriesHardLimit: 150, + } + defaultTOTP = mfa.TOTPConfig{ + Name: "Default", + Issuer: "SFTPGo", + Algo: mfa.TOTPAlgoSHA1, + } +) + +type globalConfig struct { + Common common.Configuration `json:"common" mapstructure:"common"` + ACME acme.Configuration `json:"acme" mapstructure:"acme"` + SFTPD sftpd.Configuration `json:"sftpd" mapstructure:"sftpd"` + FTPD ftpd.Configuration `json:"ftpd" mapstructure:"ftpd"` + WebDAVD webdavd.Configuration `json:"webdavd" mapstructure:"webdavd"` + ProviderConf dataprovider.Config `json:"data_provider" mapstructure:"data_provider"` + HTTPDConfig httpd.Conf `json:"httpd" mapstructure:"httpd"` + HTTPConfig httpclient.Config `json:"http" mapstructure:"http"` + CommandConfig command.Config `json:"command" mapstructure:"command"` + KMSConfig kms.Configuration `json:"kms" mapstructure:"kms"` + MFAConfig mfa.Config `json:"mfa" mapstructure:"mfa"` + TelemetryConfig telemetry.Conf `json:"telemetry" mapstructure:"telemetry"` + PluginsConfig []plugin.Config `json:"plugins" mapstructure:"plugins"` + SMTPConfig smtp.Config `json:"smtp" mapstructure:"smtp"` +} + +func init() { + Init() +} + +// Init initializes the 
global configuration. +// It is not supposed to be called outside of this package. +// It is exported to minimize refactoring efforts. Will eventually disappear. +func Init() { + // create a default configuration to use if no config file is provided + globalConf = globalConfig{ + Common: common.Configuration{ + IdleTimeout: 15, + UploadMode: 0, + Actions: common.ProtocolActions{ + ExecuteOn: []string{}, + ExecuteSync: []string{}, + Hook: "", + }, + SetstatMode: 0, + RenameMode: 0, + ResumeMaxSize: 0, + TempPath: "", + ProxyProtocol: 0, + ProxyAllowed: []string{}, + ProxySkipped: []string{}, + PostConnectHook: "", + PostDisconnectHook: "", + MaxTotalConnections: 0, + MaxPerHostConnections: 20, + AllowListStatus: 0, + AllowSelfConnections: 0, + DefenderConfig: common.DefenderConfig{ + Enabled: false, + Driver: common.DefenderDriverMemory, + BanTime: 30, + BanTimeIncrement: 50, + Threshold: 15, + ScoreInvalid: 2, + ScoreValid: 1, + ScoreLimitExceeded: 3, + ScoreNoAuth: 0, + ObservationTime: 30, + EntriesSoftLimit: 100, + EntriesHardLimit: 150, + LoginDelay: common.LoginDelay{ + Success: 0, + PasswordFailed: 1000, + }, + }, + RateLimitersConfig: []common.RateLimiterConfig{defaultRateLimiter}, + Umask: "", + ServerVersion: "", + TZ: "", + Metadata: common.MetadataConfig{ + Read: 0, + }, + EventManager: common.EventManagerConfig{ + EnabledCommands: []string{}, + }, + }, + ACME: acme.Configuration{ + Email: "", + KeyType: "4096", + CertsPath: "certs", + CAEndpoint: "https://acme-v02.api.letsencrypt.org/directory", + Domains: []string{}, + RenewDays: 30, + HTTP01Challenge: acme.HTTP01Challenge{ + Port: 80, + WebRoot: "", + ProxyHeader: "", + }, + TLSALPN01Challenge: acme.TLSALPN01Challenge{ + Port: 0, + }, + }, + SFTPD: sftpd.Configuration{ + Bindings: []sftpd.Binding{defaultSFTPDBinding}, + MaxAuthTries: 0, + HostKeys: []string{}, + HostCertificates: []string{}, + HostKeyAlgorithms: []string{}, + KexAlgorithms: []string{}, + Ciphers: []string{}, + MACs: []string{}, + 
PublicKeyAlgorithms: []string{}, + TrustedUserCAKeys: []string{}, + RevokedUserCertsFile: "", + OPKSSHPath: "", + OPKSSHChecksum: "", + LoginBannerFile: "", + EnabledSSHCommands: []string{}, + KeyboardInteractiveAuthentication: true, + KeyboardInteractiveHook: "", + PasswordAuthentication: true, + }, + FTPD: ftpd.Configuration{ + Bindings: []ftpd.Binding{defaultFTPDBinding}, + BannerFile: "", + ActiveTransfersPortNon20: true, + PassivePortRange: ftpd.PortRange{ + Start: 50000, + End: 50100, + }, + DisableActiveMode: false, + EnableSite: false, + HASHSupport: 0, + CombineSupport: 0, + CertificateFile: "", + CertificateKeyFile: "", + CACertificates: []string{}, + CARevocationLists: []string{}, + }, + WebDAVD: webdavd.Configuration{ + Bindings: []webdavd.Binding{defaultWebDAVDBinding}, + CertificateFile: "", + CertificateKeyFile: "", + CACertificates: []string{}, + CARevocationLists: []string{}, + Cors: webdavd.CorsConfig{ + Enabled: false, + AllowedOrigins: []string{}, + AllowedMethods: []string{}, + AllowedHeaders: []string{}, + ExposedHeaders: []string{}, + AllowCredentials: false, + MaxAge: 0, + OptionsPassthrough: false, + OptionsSuccessStatus: 0, + AllowPrivateNetwork: false, + }, + Cache: webdavd.Cache{ + Users: webdavd.UsersCacheConfig{ + ExpirationTime: 0, + MaxSize: 50, + }, + MimeTypes: webdavd.MimeCacheConfig{ + Enabled: true, + MaxSize: 1000, + CustomMappings: nil, + }, + }, + }, + ProviderConf: dataprovider.Config{ + Driver: "sqlite", + Name: "sftpgo.db", + Host: "", + Port: 0, + Username: "", + Password: "", + ConnectionString: "", + SQLTablesPrefix: "", + SSLMode: 0, + DisableSNI: false, + TargetSessionAttrs: "", + RootCert: "", + ClientCert: "", + ClientKey: "", + TrackQuota: 2, + PoolSize: 0, + UsersBaseDir: "", + Actions: dataprovider.ObjectsActions{ + ExecuteOn: []string{}, + ExecuteFor: []string{}, + Hook: "", + }, + ExternalAuthHook: "", + ExternalAuthScope: 0, + PreLoginHook: "", + PostLoginHook: "", + PostLoginScope: 0, + CheckPasswordHook: "", 
+ CheckPasswordScope: 0, + PasswordHashing: dataprovider.PasswordHashing{ + Argon2Options: dataprovider.Argon2Options{ + Memory: 65536, + Iterations: 1, + Parallelism: 2, + }, + BcryptOptions: dataprovider.BcryptOptions{ + Cost: 10, + }, + Algo: dataprovider.HashingAlgoBcrypt, + }, + PasswordValidation: dataprovider.PasswordValidation{ + Admins: dataprovider.PasswordValidationRules{ + MinEntropy: 0, + }, + Users: dataprovider.PasswordValidationRules{ + MinEntropy: 0, + }, + }, + PasswordCaching: true, + UpdateMode: 0, + DelayedQuotaUpdate: 0, + CreateDefaultAdmin: false, + NamingRules: 1, + IsShared: 0, + Node: dataprovider.NodeConfig{ + Host: "", + Port: 0, + Proto: "http", + }, + BackupsPath: "backups", + }, + HTTPDConfig: httpd.Conf{ + Bindings: []httpd.Binding{defaultHTTPDBinding}, + TemplatesPath: "templates", + StaticFilesPath: "static", + OpenAPIPath: "openapi", + WebRoot: "", + CertificateFile: "", + CertificateKeyFile: "", + CACertificates: nil, + CARevocationLists: nil, + SigningPassphrase: "", + SigningPassphraseFile: "", + TokenValidation: 0, + CookieLifetime: 20, + ShareCookieLifetime: 120, + JWTLifetime: 20, + MaxUploadFileSize: 0, + Cors: httpd.CorsConfig{ + Enabled: false, + AllowedOrigins: []string{}, + AllowedMethods: []string{}, + AllowedHeaders: []string{}, + ExposedHeaders: []string{}, + AllowCredentials: false, + MaxAge: 0, + OptionsPassthrough: false, + OptionsSuccessStatus: 0, + AllowPrivateNetwork: false, + }, + Setup: httpd.SetupConfig{ + InstallationCode: "", + InstallationCodeHint: defaultInstallCodeHint, + }, + HideSupportLink: false, + }, + HTTPConfig: httpclient.Config{ + Timeout: 20, + RetryWaitMin: 2, + RetryWaitMax: 30, + RetryMax: 3, + CACertificates: nil, + Certificates: nil, + SkipTLSVerify: false, + Headers: nil, + }, + CommandConfig: command.Config{ + Timeout: 30, + Env: nil, + Commands: nil, + }, + KMSConfig: kms.Configuration{ + Secrets: kms.Secrets{ + URL: "", + MasterKeyString: "", + MasterKeyPath: "", + }, + }, + 
MFAConfig: mfa.Config{ + TOTP: []mfa.TOTPConfig{defaultTOTP}, + }, + TelemetryConfig: telemetry.Conf{ + BindPort: 0, + BindAddress: "127.0.0.1", + EnableProfiler: false, + AuthUserFile: "", + CertificateFile: "", + CertificateKeyFile: "", + MinTLSVersion: 12, + TLSCipherSuites: nil, + Protocols: nil, + }, + SMTPConfig: smtp.Config{ + Host: "", + Port: 587, + From: "", + User: "", + Password: "", + AuthType: 0, + Encryption: 0, + Domain: "", + TemplatesPath: "templates", + }, + PluginsConfig: nil, + } + + viper.SetEnvPrefix(configEnvPrefix) + replacer := strings.NewReplacer(".", "__") + viper.SetEnvKeyReplacer(replacer) + viper.SetConfigName(configName) + setViperDefaults() + viper.AutomaticEnv() + viper.AllowEmptyEnv(true) +} + +// GetCommonConfig returns the common protocols configuration +func GetCommonConfig() common.Configuration { + return globalConf.Common +} + +// SetCommonConfig sets the common protocols configuration +func SetCommonConfig(config common.Configuration) { + globalConf.Common = config +} + +// GetSFTPDConfig returns the configuration for the SFTP server +func GetSFTPDConfig() sftpd.Configuration { + return globalConf.SFTPD +} + +// SetSFTPDConfig sets the configuration for the SFTP server +func SetSFTPDConfig(config sftpd.Configuration) { + globalConf.SFTPD = config +} + +// GetFTPDConfig returns the configuration for the FTP server +func GetFTPDConfig() ftpd.Configuration { + return globalConf.FTPD +} + +// SetFTPDConfig sets the configuration for the FTP server +func SetFTPDConfig(config ftpd.Configuration) { + globalConf.FTPD = config +} + +// GetWebDAVDConfig returns the configuration for the WebDAV server +func GetWebDAVDConfig() webdavd.Configuration { + return globalConf.WebDAVD +} + +// SetWebDAVDConfig sets the configuration for the WebDAV server +func SetWebDAVDConfig(config webdavd.Configuration) { + globalConf.WebDAVD = config +} + +// GetHTTPDConfig returns the configuration for the HTTP server +func GetHTTPDConfig() httpd.Conf { 
+ return globalConf.HTTPDConfig +} + +// SetHTTPDConfig sets the configuration for the HTTP server +func SetHTTPDConfig(config httpd.Conf) { + globalConf.HTTPDConfig = config +} + +// GetProviderConf returns the configuration for the data provider +func GetProviderConf() dataprovider.Config { + return globalConf.ProviderConf +} + +// SetProviderConf sets the configuration for the data provider +func SetProviderConf(config dataprovider.Config) { + globalConf.ProviderConf = config +} + +// GetHTTPConfig returns the configuration for HTTP clients +func GetHTTPConfig() httpclient.Config { + return globalConf.HTTPConfig +} + +// GetCommandConfig returns the configuration for external commands +func GetCommandConfig() command.Config { + return globalConf.CommandConfig +} + +// GetKMSConfig returns the KMS configuration +func GetKMSConfig() kms.Configuration { + return globalConf.KMSConfig +} + +// SetKMSConfig sets the kms configuration +func SetKMSConfig(config kms.Configuration) { + globalConf.KMSConfig = config +} + +// GetTelemetryConfig returns the telemetry configuration +func GetTelemetryConfig() telemetry.Conf { + return globalConf.TelemetryConfig +} + +// SetTelemetryConfig sets the telemetry configuration +func SetTelemetryConfig(config telemetry.Conf) { + globalConf.TelemetryConfig = config +} + +// GetPluginsConfig returns the plugins configuration +func GetPluginsConfig() []plugin.Config { + return globalConf.PluginsConfig +} + +// SetPluginsConfig sets the plugin configuration +func SetPluginsConfig(config []plugin.Config) { + globalConf.PluginsConfig = config +} + +// HasKMSPlugin returns true if at least one KMS plugin is configured. 
+func HasKMSPlugin() bool { + for _, c := range globalConf.PluginsConfig { + if c.Type == kmsplugin.PluginName { + return true + } + } + return false +} + +// GetMFAConfig returns multi-factor authentication config +func GetMFAConfig() mfa.Config { + return globalConf.MFAConfig +} + +// GetSMTPConfig returns the SMTP configuration +func GetSMTPConfig() smtp.Config { + return globalConf.SMTPConfig +} + +// GetACMEConfig returns the ACME configuration +func GetACMEConfig() acme.Configuration { + return globalConf.ACME +} + +// HasServicesToStart returns true if the config defines at least a service to start. +// Supported services are SFTP, FTP and WebDAV +func HasServicesToStart() bool { + if globalConf.SFTPD.ShouldBind() { + return true + } + if globalConf.FTPD.ShouldBind() { + return true + } + if globalConf.WebDAVD.ShouldBind() { + return true + } + if globalConf.HTTPDConfig.ShouldBind() { + return true + } + return false +} + +func getRedactedPassword(value string) string { + if value == "" { + return value + } + return "[redacted]" +} + +func getRedactedGlobalConf() globalConfig { + conf := globalConf + conf.Common.Actions.Hook = util.GetRedactedURL(conf.Common.Actions.Hook) + conf.Common.StartupHook = util.GetRedactedURL(conf.Common.StartupHook) + conf.Common.PostConnectHook = util.GetRedactedURL(conf.Common.PostConnectHook) + conf.Common.PostDisconnectHook = util.GetRedactedURL(conf.Common.PostDisconnectHook) + conf.SFTPD.KeyboardInteractiveHook = util.GetRedactedURL(conf.SFTPD.KeyboardInteractiveHook) + conf.HTTPDConfig.SigningPassphrase = getRedactedPassword(conf.HTTPDConfig.SigningPassphrase) + conf.HTTPDConfig.Setup.InstallationCode = getRedactedPassword(conf.HTTPDConfig.Setup.InstallationCode) + conf.ProviderConf.Password = getRedactedPassword(conf.ProviderConf.Password) + conf.ProviderConf.Actions.Hook = util.GetRedactedURL(conf.ProviderConf.Actions.Hook) + conf.ProviderConf.ExternalAuthHook = util.GetRedactedURL(conf.ProviderConf.ExternalAuthHook) + 
conf.ProviderConf.PreLoginHook = util.GetRedactedURL(conf.ProviderConf.PreLoginHook) + conf.ProviderConf.PostLoginHook = util.GetRedactedURL(conf.ProviderConf.PostLoginHook) + conf.ProviderConf.CheckPasswordHook = util.GetRedactedURL(conf.ProviderConf.CheckPasswordHook) + conf.SMTPConfig.Password = getRedactedPassword(conf.SMTPConfig.Password) + conf.HTTPDConfig.Bindings = nil + for _, binding := range globalConf.HTTPDConfig.Bindings { + binding.OIDC.ClientID = getRedactedPassword(binding.OIDC.ClientID) + binding.OIDC.ClientSecret = getRedactedPassword(binding.OIDC.ClientSecret) + conf.HTTPDConfig.Bindings = append(conf.HTTPDConfig.Bindings, binding) + } + conf.KMSConfig.Secrets.MasterKeyString = getRedactedPassword(conf.KMSConfig.Secrets.MasterKeyString) + conf.PluginsConfig = nil + for _, plugin := range globalConf.PluginsConfig { + var args []string + for _, arg := range plugin.Args { + args = append(args, getRedactedPassword(arg)) + } + plugin.Args = args + conf.PluginsConfig = append(conf.PluginsConfig, plugin) + } + return conf +} + +func setConfigFile(configDir, configFile string) { + if configFile == "" { + return + } + if !filepath.IsAbs(configFile) && util.IsFileInputValid(configFile) { + configFile = filepath.Join(configDir, configFile) + } + viper.SetConfigFile(configFile) +} + +// readEnvFiles reads files inside the "env.d" directory relative to configDir +// and then export the valid variables into environment variables if they do +// not exist +func readEnvFiles(configDir string) { + envd := filepath.Join(configDir, "env.d") + entries, err := os.ReadDir(envd) + if err != nil { + logger.Info(logSender, "", "unable to read env files from %q: %v", envd, err) + return + } + for _, entry := range entries { + info, err := entry.Info() + if err == nil && info.Mode().IsRegular() { + envFile := filepath.Join(envd, entry.Name()) + if info.Size() > envFileMaxSize { + logger.Info(logSender, "", "env file %q too big: %s, skipping", entry.Name(), 
util.ByteCountIEC(info.Size())) + continue + } + err = gotenv.Load(envFile) + if err != nil { + logger.Error(logSender, "", "unable to load env vars from file %q, err: %v", envFile, err) + } else { + logger.Info(logSender, "", "set env vars from file %q", envFile) + } + } + } +} + +func checkOverrideDefaultSettings() { + // for slices we need to set the defaults to nil if the key is set in the config file, + // otherwise the values are merged and not replaced as expected + rateLimiters := viper.Get("common.rate_limiters") + if val, ok := rateLimiters.([]any); ok { + if len(val) > 0 { + if rl, ok := val[0].(map[string]any); ok { + if _, ok := rl["protocols"]; ok { + globalConf.Common.RateLimitersConfig[0].Protocols = nil + } + } + } + } + + httpdBindings := viper.Get("httpd.bindings") + if val, ok := httpdBindings.([]any); ok { + if len(val) > 0 { + if binding, ok := val[0].(map[string]any); ok { + if val, ok := binding["oidc"]; ok { + if oidc, ok := val.(map[string]any); ok { + if _, ok := oidc["scopes"]; ok { + globalConf.HTTPDConfig.Bindings[0].OIDC.Scopes = nil + } + } + } + } + } + } + + if slices.Contains(viper.AllKeys(), "mfa.totp") { + globalConf.MFAConfig.TOTP = nil + } +} + +// LoadConfig loads the configuration +// configDir will be added to the configuration search paths. +// The search path contains by default the current directory and on linux it contains +// $HOME/.config/sftpgo and /etc/sftpgo too. +// configFile is an absolute or relative path (to the config dir) to the configuration file. +func LoadConfig(configDir, configFile string) error { + var err error + readEnvFiles(configDir) + viper.AddConfigPath(configDir) + setViperAdditionalConfigPaths() + viper.AddConfigPath(".") + setConfigFile(configDir, configFile) + if err = viper.ReadInConfig(); err != nil { + // if the user specify a configuration file we get os.ErrNotExist. 
+ // viper.ConfigFileNotFoundError is returned if viper is unable + // to find sftpgo.{json,yaml, etc..} in any of the search paths + if errors.As(err, &viper.ConfigFileNotFoundError{}) { + logger.Debug(logSender, "", "no configuration file found") + } else { + logger.Warn(logSender, "", "error loading configuration file: %v", err) + logger.WarnToConsole("error loading configuration file: %v", err) + return err + } + } + checkOverrideDefaultSettings() + err = viper.Unmarshal(&globalConf) + if err != nil { + logger.Warn(logSender, "", "error parsing configuration file: %v", err) + logger.WarnToConsole("error parsing configuration file: %v", err) + return err + } + // viper only supports slice of strings from env vars, so we use our custom method + loadBindingsFromEnv() + loadWebDAVCacheMappingsFromEnv() + resetInvalidConfigs() + logger.Debug(logSender, "", "config file used: '%q', config loaded: %+v", viper.ConfigFileUsed(), getRedactedGlobalConf()) + return nil +} + +func isProxyProtocolValid() bool { + return globalConf.Common.ProxyProtocol >= 0 && globalConf.Common.ProxyProtocol <= 2 +} + +func isExternalAuthScopeValid() bool { + return globalConf.ProviderConf.ExternalAuthScope >= 0 && globalConf.ProviderConf.ExternalAuthScope <= 15 +} + +func resetInvalidConfigs() { + if strings.TrimSpace(globalConf.HTTPDConfig.Setup.InstallationCodeHint) == "" { + globalConf.HTTPDConfig.Setup.InstallationCodeHint = defaultInstallCodeHint + } + if globalConf.ProviderConf.UsersBaseDir != "" && !util.IsFileInputValid(globalConf.ProviderConf.UsersBaseDir) { + warn := fmt.Sprintf("invalid users base dir %q will be ignored", globalConf.ProviderConf.UsersBaseDir) + globalConf.ProviderConf.UsersBaseDir = "" + logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) + logger.WarnToConsole("Non-fatal configuration error: %v", warn) + } + if !isProxyProtocolValid() { + warn := fmt.Sprintf("invalid proxy_protocol 0, 1 and 2 are supported, configured: %v reset proxy_protocol to 
0", + globalConf.Common.ProxyProtocol) + globalConf.Common.ProxyProtocol = 0 + logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) + logger.WarnToConsole("Non-fatal configuration error: %v", warn) + } + if !isExternalAuthScopeValid() { + warn := fmt.Sprintf("invalid external_auth_scope: %v reset to 0", globalConf.ProviderConf.ExternalAuthScope) + globalConf.ProviderConf.ExternalAuthScope = 0 + logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) + logger.WarnToConsole("Non-fatal configuration error: %v", warn) + } + if globalConf.Common.DefenderConfig.Enabled && globalConf.Common.DefenderConfig.Driver == common.DefenderDriverProvider { + if !globalConf.ProviderConf.IsDefenderSupported() { + warn := fmt.Sprintf("provider based defender is not supported with data provider %q, "+ + "the memory defender implementation will be used. If you want to use the provider defender "+ + "implementation please switch to a shared/distributed data provider", + globalConf.ProviderConf.Driver) + globalConf.Common.DefenderConfig.Driver = common.DefenderDriverMemory + logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) + logger.WarnToConsole("Non-fatal configuration error: %v", warn) + } + } + if globalConf.Common.RenameMode < 0 || globalConf.Common.RenameMode > 1 { + warn := fmt.Sprintf("invalid rename mode %d, reset to 0", globalConf.Common.RenameMode) + globalConf.Common.RenameMode = 0 + logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) + logger.WarnToConsole("Non-fatal configuration error: %v", warn) + } +} + +func loadBindingsFromEnv() { + for idx := 0; idx < 10; idx++ { + getTOTPFromEnv(idx) + getRateLimitersFromEnv(idx) + getPluginsFromEnv(idx) + getSFTPDBindindFromEnv(idx) + getFTPDBindingFromEnv(idx) + getWebDAVDBindingFromEnv(idx) + getHTTPDBindingFromEnv(idx) + getHTTPClientCertificatesFromEnv(idx) + getHTTPClientHeadersFromEnv(idx) + getCommandConfigsFromEnv(idx) + } +} + +func getTOTPFromEnv(idx int) 
{ + totpConfig := defaultTOTP + if len(globalConf.MFAConfig.TOTP) > idx { + totpConfig = globalConf.MFAConfig.TOTP[idx] + } + + isSet := false + + name, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_MFA__TOTP__%v__NAME", idx)) + if ok { + totpConfig.Name = name + isSet = true + } + + issuer, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_MFA__TOTP__%v__ISSUER", idx)) + if ok { + totpConfig.Issuer = issuer + isSet = true + } + + algo, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_MFA__TOTP__%v__ALGO", idx)) + if ok { + totpConfig.Algo = algo + isSet = true + } + + if isSet { + if len(globalConf.MFAConfig.TOTP) > idx { + globalConf.MFAConfig.TOTP[idx] = totpConfig + } else { + globalConf.MFAConfig.TOTP = append(globalConf.MFAConfig.TOTP, totpConfig) + } + } +} + +func getRateLimitersFromEnv(idx int) { + rtlConfig := defaultRateLimiter + if len(globalConf.Common.RateLimitersConfig) > idx { + rtlConfig = globalConf.Common.RateLimitersConfig[idx] + } + + isSet := false + + average, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__AVERAGE", idx), 64) + if ok { + rtlConfig.Average = average + isSet = true + } + + period, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__PERIOD", idx), 64) + if ok { + rtlConfig.Period = period + isSet = true + } + + burst, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__BURST", idx), 32) + if ok { + rtlConfig.Burst = int(burst) + isSet = true + } + + rtlType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__TYPE", idx), 32) + if ok { + rtlConfig.Type = int(rtlType) + isSet = true + } + + protocols, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__PROTOCOLS", idx)) + if ok { + rtlConfig.Protocols = protocols + isSet = true + } + + generateEvents, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__GENERATE_DEFENDER_EVENTS", idx)) + if ok { + rtlConfig.GenerateDefenderEvents = generateEvents + isSet = true + } + + softLimit, ok := 
lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_SOFT_LIMIT", idx), 32) + if ok { + rtlConfig.EntriesSoftLimit = int(softLimit) + isSet = true + } + + hardLimit, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMON__RATE_LIMITERS__%v__ENTRIES_HARD_LIMIT", idx), 32) + if ok { + rtlConfig.EntriesHardLimit = int(hardLimit) + isSet = true + } + + if isSet { + if len(globalConf.Common.RateLimitersConfig) > idx { + globalConf.Common.RateLimitersConfig[idx] = rtlConfig + } else { + globalConf.Common.RateLimitersConfig = append(globalConf.Common.RateLimitersConfig, rtlConfig) + } + } +} + +func getKMSPluginFromEnv(idx int, pluginConfig *plugin.Config) bool { + isSet := false + + kmsScheme, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__KMS_OPTIONS__SCHEME", idx)) + if ok { + pluginConfig.KMSOptions.Scheme = kmsScheme + isSet = true + } + + kmsEncStatus, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__KMS_OPTIONS__ENCRYPTED_STATUS", idx)) + if ok { + pluginConfig.KMSOptions.EncryptedStatus = kmsEncStatus + isSet = true + } + + return isSet +} + +func getAuthPluginFromEnv(idx int, pluginConfig *plugin.Config) bool { + isSet := false + + authScope, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__AUTH_OPTIONS__SCOPE", idx), 32) + if ok { + pluginConfig.AuthOptions.Scope = int(authScope) + isSet = true + } + + return isSet +} + +func getNotifierPluginFromEnv(idx int, pluginConfig *plugin.Config) bool { + isSet := false + + notifierFsEvents, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__FS_EVENTS", idx)) + if ok { + pluginConfig.NotifierOptions.FsEvents = notifierFsEvents + isSet = true + } + + notifierProviderEvents, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__PROVIDER_EVENTS", idx)) + if ok { + pluginConfig.NotifierOptions.ProviderEvents = notifierProviderEvents + isSet = true + } + + notifierProviderObjects, ok := 
lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__PROVIDER_OBJECTS", idx)) + if ok { + pluginConfig.NotifierOptions.ProviderObjects = notifierProviderObjects + isSet = true + } + + notifierLogEventsString, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__LOG_EVENTS", idx)) + if ok { + var notifierLogEvents []int + for _, e := range notifierLogEventsString { + ev, err := strconv.Atoi(e) + if err == nil { + notifierLogEvents = append(notifierLogEvents, ev) + } + } + if len(notifierLogEvents) > 0 { + pluginConfig.NotifierOptions.LogEvents = notifierLogEvents + isSet = true + } + } + + notifierRetryMaxTime, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__RETRY_MAX_TIME", idx), 32) + if ok { + pluginConfig.NotifierOptions.RetryMaxTime = int(notifierRetryMaxTime) + isSet = true + } + + notifierRetryQueueMaxSize, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__NOTIFIER_OPTIONS__RETRY_QUEUE_MAX_SIZE", idx), 32) + if ok { + pluginConfig.NotifierOptions.RetryQueueMaxSize = int(notifierRetryQueueMaxSize) + isSet = true + } + + return isSet +} + +func getPluginsFromEnv(idx int) { + pluginConfig := plugin.Config{} + if len(globalConf.PluginsConfig) > idx { + pluginConfig = globalConf.PluginsConfig[idx] + } + + isSet := false + + pluginType, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__TYPE", idx)) + if ok { + pluginConfig.Type = pluginType + isSet = true + } + + if getNotifierPluginFromEnv(idx, &pluginConfig) { + isSet = true + } + + if getKMSPluginFromEnv(idx, &pluginConfig) { + isSet = true + } + + if getAuthPluginFromEnv(idx, &pluginConfig) { + isSet = true + } + + cmd, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__CMD", idx)) + if ok { + pluginConfig.Cmd = cmd + isSet = true + } + + cmdArgs, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__ARGS", idx)) + if ok { + pluginConfig.Args = cmdArgs + isSet = true + } + + pluginHash, ok := 
os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__SHA256SUM", idx)) + if ok { + pluginConfig.SHA256Sum = pluginHash + isSet = true + } + + autoMTLS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__AUTO_MTLS", idx)) + if ok { + pluginConfig.AutoMTLS = autoMTLS + isSet = true + } + + envPrefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__ENV_PREFIX", idx)) + if ok { + pluginConfig.EnvPrefix = envPrefix + isSet = true + } + + envVars, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_PLUGINS__%v__ENV_VARS", idx)) + if ok { + pluginConfig.EnvVars = envVars + isSet = true + } + + if isSet { + if len(globalConf.PluginsConfig) > idx { + globalConf.PluginsConfig[idx] = pluginConfig + } else { + globalConf.PluginsConfig = append(globalConf.PluginsConfig, pluginConfig) + } + } +} + +func getSFTPDBindindFromEnv(idx int) { + binding := defaultSFTPDBinding + if len(globalConf.SFTPD.Bindings) > idx { + binding = globalConf.SFTPD.Bindings[idx] + } + + isSet := false + + port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_SFTPD__BINDINGS__%v__PORT", idx), 32) + if ok { + binding.Port = int(port) + isSet = true + } + + address, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_SFTPD__BINDINGS__%v__ADDRESS", idx)) + if ok { + binding.Address = address + isSet = true + } + + applyProxyConfig, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_SFTPD__BINDINGS__%v__APPLY_PROXY_CONFIG", idx)) + if ok { + binding.ApplyProxyConfig = applyProxyConfig + isSet = true + } + + if isSet { + if len(globalConf.SFTPD.Bindings) > idx { + globalConf.SFTPD.Bindings[idx] = binding + } else { + globalConf.SFTPD.Bindings = append(globalConf.SFTPD.Bindings, binding) + } + } +} + +func getFTPDPassiveIPOverridesFromEnv(idx int) []ftpd.PassiveIPOverride { + var overrides []ftpd.PassiveIPOverride + if len(globalConf.FTPD.Bindings) > idx { + overrides = globalConf.FTPD.Bindings[idx].PassiveIPOverrides + } + + for subIdx := 0; subIdx < 10; subIdx++ { + var override ftpd.PassiveIPOverride + var replace bool + if 
len(globalConf.FTPD.Bindings) > idx && len(globalConf.FTPD.Bindings[idx].PassiveIPOverrides) > subIdx { + override = globalConf.FTPD.Bindings[idx].PassiveIPOverrides[subIdx] + replace = true + } + + ip, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_IP_OVERRIDES__%v__IP", idx, subIdx)) + if ok { + override.IP = ip + } + + networks, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_IP_OVERRIDES__%v__NETWORKS", + idx, subIdx)) + if ok { + override.Networks = networks + } + + if len(override.Networks) > 0 { + if replace { + overrides[subIdx] = override + } else { + overrides = append(overrides, override) + } + } + } + + return overrides +} + +func getDefaultFTPDBinding(idx int) ftpd.Binding { + binding := defaultFTPDBinding + if len(globalConf.FTPD.Bindings) > idx { + binding = globalConf.FTPD.Bindings[idx] + } + return binding +} + +func getFTPDBindingSecurityFromEnv(idx int, binding *ftpd.Binding) bool { + isSet := false + + certificateFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__CERTIFICATE_FILE", idx)) + if ok { + binding.CertificateFile = certificateFile + isSet = true + } + + certificateKeyFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__CERTIFICATE_KEY_FILE", idx)) + if ok { + binding.CertificateKeyFile = certificateKeyFile + isSet = true + } + + tlsMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__TLS_MODE", idx), 32) + if ok { + binding.TLSMode = int(tlsMode) + isSet = true + } + + tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__MIN_TLS_VERSION", idx), 32) + if ok { + binding.MinTLSVersion = int(tlsVer) + isSet = true + } + + tlsCiphers, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__TLS_CIPHER_SUITES", idx)) + if ok { + binding.TLSCipherSuites = tlsCiphers + isSet = true + } + + clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 32) + if ok { + 
binding.ClientAuthType = int(clientAuthType) + isSet = true + } + + pasvSecurity, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_CONNECTIONS_SECURITY", idx), 32) + if ok { + binding.PassiveConnectionsSecurity = int(pasvSecurity) + isSet = true + } + + activeSecurity, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__ACTIVE_CONNECTIONS_SECURITY", idx), 32) + if ok { + binding.ActiveConnectionsSecurity = int(activeSecurity) + isSet = true + } + + return isSet +} + +func getFTPDBindingFromEnv(idx int) { + binding := getDefaultFTPDBinding(idx) + isSet := false + + port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PORT", idx), 32) + if ok { + binding.Port = int(port) + isSet = true + } + + address, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__ADDRESS", idx)) + if ok { + binding.Address = address + isSet = true + } + + applyProxyConfig, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__APPLY_PROXY_CONFIG", idx)) + if ok { + binding.ApplyProxyConfig = applyProxyConfig + isSet = true + } + + passiveIP, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__FORCE_PASSIVE_IP", idx)) + if ok { + binding.ForcePassiveIP = passiveIP + isSet = true + } + + passiveIPOverrides := getFTPDPassiveIPOverridesFromEnv(idx) + if len(passiveIPOverrides) > 0 { + binding.PassiveIPOverrides = passiveIPOverrides + isSet = true + } + + passiveHost, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__PASSIVE_HOST", idx)) + if ok { + binding.PassiveHost = passiveHost + isSet = true + } + + debug, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_FTPD__BINDINGS__%v__DEBUG", idx)) + if ok { + binding.Debug = debug + isSet = true + } + + if getFTPDBindingSecurityFromEnv(idx, &binding) { + isSet = true + } + + applyFTPDBindingFromEnv(idx, isSet, binding) +} + +func applyFTPDBindingFromEnv(idx int, isSet bool, binding ftpd.Binding) { + if isSet { + if len(globalConf.FTPD.Bindings) > idx { + 
globalConf.FTPD.Bindings[idx] = binding
		} else {
			globalConf.FTPD.Bindings = append(globalConf.FTPD.Bindings, binding)
		}
	}
}

// getWebDAVBindingHTTPSConfigsFromEnv reads the HTTPS/TLS related settings for
// the WebDAV binding at the given index from
// "SFTPGO_WEBDAVD__BINDINGS__<idx>__*" environment variables and applies them
// to the provided binding. It returns true if at least one of the settings was
// found in the environment.
func getWebDAVBindingHTTPSConfigsFromEnv(idx int, binding *webdavd.Binding) bool {
	isSet := false

	enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__ENABLE_HTTPS", idx))
	if ok {
		binding.EnableHTTPS = enableHTTPS
		isSet = true
	}

	certificateFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CERTIFICATE_FILE", idx))
	if ok {
		binding.CertificateFile = certificateFile
		isSet = true
	}

	certificateKeyFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CERTIFICATE_KEY_FILE", idx))
	if ok {
		binding.CertificateKeyFile = certificateKeyFile
		isSet = true
	}

	tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__MIN_TLS_VERSION", idx), 32)
	if ok {
		binding.MinTLSVersion = int(tlsVer)
		isSet = true
	}

	clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 32)
	if ok {
		binding.ClientAuthType = int(clientAuthType)
		isSet = true
	}

	tlsCiphers, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__TLS_CIPHER_SUITES", idx))
	if ok {
		binding.TLSCipherSuites = tlsCiphers
		isSet = true
	}

	// use %v like every other lookup in this function for consistency, the
	// rendered key is identical since idx is an int
	protocols, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__TLS_PROTOCOLS", idx))
	if ok {
		binding.Protocols = protocols
		isSet = true
	}

	return isSet
}

func getWebDAVDBindingProxyConfigsFromEnv(idx int, binding *webdavd.Binding) bool {
	isSet := false

	proxyMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PROXY_MODE", idx), 32)
	if ok {
		binding.ProxyMode = int(proxyMode)
		isSet = true
	}

	proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PROXY_ALLOWED", idx))
	if ok {
		binding.ProxyAllowed = proxyAllowed
		isSet = true
	}

	clientIPProxyHeader, 
ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_IP_PROXY_HEADER", idx))
	if ok {
		binding.ClientIPProxyHeader = clientIPProxyHeader
		isSet = true
	}

	clientIPHeaderDepth, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CLIENT_IP_HEADER_DEPTH", idx), 32)
	if ok {
		binding.ClientIPHeaderDepth = int(clientIPHeaderDepth)
		isSet = true
	}

	return isSet
}

// loadWebDAVCacheMappingsFromEnv reads custom MIME type mappings from
// "SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__<idx>__EXT/__MIME"
// environment variable pairs, merges them into the global configuration and
// returns the resulting mappings. A mapping is applied only when both the EXT
// and the MIME variable are defined for the same index.
func loadWebDAVCacheMappingsFromEnv() []webdavd.CustomMimeMapping {
	for idx := 0; idx < 30; idx++ {
		extKey := fmt.Sprintf("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__%d__EXT", idx)
		mimeKey := fmt.Sprintf("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__%d__MIME", idx)
		ext, extOK := os.LookupEnv(extKey)
		mime, mimeOK := os.LookupEnv(mimeKey)
		if !extOK || !mimeOK {
			continue
		}
		if len(globalConf.WebDAVD.Cache.MimeTypes.CustomMappings) > idx {
			// override the mapping loaded from the configuration file
			globalConf.WebDAVD.Cache.MimeTypes.CustomMappings[idx].Ext = ext
			globalConf.WebDAVD.Cache.MimeTypes.CustomMappings[idx].Mime = mime
			continue
		}
		globalConf.WebDAVD.Cache.MimeTypes.CustomMappings = append(globalConf.WebDAVD.Cache.MimeTypes.CustomMappings,
			webdavd.CustomMimeMapping{
				Ext:  ext,
				Mime: mime,
			})
	}

	return globalConf.WebDAVD.Cache.MimeTypes.CustomMappings
}

func getWebDAVDBindingFromEnv(idx int) {
	binding := defaultWebDAVDBinding
	if len(globalConf.WebDAVD.Bindings) > idx {
		binding = globalConf.WebDAVD.Bindings[idx]
	}

	isSet := false

	port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PORT", idx), 32)
	if ok {
		binding.Port = int(port)
		isSet = true
	}

	address, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__ADDRESS", idx))
	if ok {
		binding.Address = address
		isSet = true
	}

	if getWebDAVBindingHTTPSConfigsFromEnv(idx, &binding) {
		isSet = true
	}

	prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx))
	if ok {
		binding.Prefix = prefix
		isSet = true
	}

	if getWebDAVDBindingProxyConfigsFromEnv(idx, &binding) {
isSet = true + } + + disableWWWAuth, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__DISABLE_WWW_AUTH_HEADER", idx)) + if ok { + binding.DisableWWWAuthHeader = disableWWWAuth + isSet = true + } + + if isSet { + if len(globalConf.WebDAVD.Bindings) > idx { + globalConf.WebDAVD.Bindings[idx] = binding + } else { + globalConf.WebDAVD.Bindings = append(globalConf.WebDAVD.Bindings, binding) + } + } +} + +func getHTTPDSecurityProxyHeadersFromEnv(idx int) []httpd.HTTPSProxyHeader { + var httpsProxyHeaders []httpd.HTTPSProxyHeader + if len(globalConf.HTTPDConfig.Bindings) > idx { + httpsProxyHeaders = globalConf.HTTPDConfig.Bindings[idx].Security.HTTPSProxyHeaders + } + + for subIdx := 0; subIdx < 10; subIdx++ { + var httpsProxyHeader httpd.HTTPSProxyHeader + var replace bool + if len(globalConf.HTTPDConfig.Bindings) > idx && + len(globalConf.HTTPDConfig.Bindings[idx].Security.HTTPSProxyHeaders) > subIdx { + httpsProxyHeader = httpsProxyHeaders[subIdx] + replace = true + } + proxyKey, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HTTPS_PROXY_HEADERS__%v__KEY", + idx, subIdx)) + if ok { + httpsProxyHeader.Key = proxyKey + } + proxyVal, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HTTPS_PROXY_HEADERS__%v__VALUE", + idx, subIdx)) + if ok { + httpsProxyHeader.Value = proxyVal + } + if httpsProxyHeader.Key != "" && httpsProxyHeader.Value != "" { + if replace { + httpsProxyHeaders[subIdx] = httpsProxyHeader + } else { + httpsProxyHeaders = append(httpsProxyHeaders, httpsProxyHeader) + } + } + } + return httpsProxyHeaders +} + +func getHTTPDSecurityConfFromEnv(idx int) (httpd.SecurityConf, bool) { //nolint:gocyclo + result := defaultHTTPDBinding.Security + if len(globalConf.HTTPDConfig.Bindings) > idx { + result = globalConf.HTTPDConfig.Bindings[idx].Security + } + isSet := false + + enabled, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__ENABLED", idx)) + if ok { + result.Enabled = 
enabled
		isSet = true
	}

	// each SECURITY__* variable overrides the corresponding field of the
	// per-binding security configuration; fields with no matching variable
	// keep the value loaded from the configuration file
	allowedHosts, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__ALLOWED_HOSTS", idx))
	if ok {
		result.AllowedHosts = allowedHosts
		isSet = true
	}

	allowedHostsAreRegex, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__ALLOWED_HOSTS_ARE_REGEX", idx))
	if ok {
		result.AllowedHostsAreRegex = allowedHostsAreRegex
		isSet = true
	}

	hostsProxyHeaders, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HOSTS_PROXY_HEADERS", idx))
	if ok {
		result.HostsProxyHeaders = hostsProxyHeaders
		isSet = true
	}

	httpsRedirect, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HTTPS_REDIRECT", idx))
	if ok {
		result.HTTPSRedirect = httpsRedirect
		isSet = true
	}

	httpsHost, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__HTTPS_HOST", idx))
	if ok {
		result.HTTPSHost = httpsHost
		isSet = true
	}

	// proxy headers come from indexed KEY/VALUE variable pairs, see
	// getHTTPDSecurityProxyHeadersFromEnv; an empty result means no override
	httpsProxyHeaders := getHTTPDSecurityProxyHeadersFromEnv(idx)
	if len(httpsProxyHeaders) > 0 {
		result.HTTPSProxyHeaders = httpsProxyHeaders
		isSet = true
	}

	// parsed as 64-bit, unlike most other integer lookups in this file,
	// and assigned without an int conversion
	stsSeconds, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__STS_SECONDS", idx), 64)
	if ok {
		result.STSSeconds = stsSeconds
		isSet = true
	}

	stsIncludeSubDomains, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__STS_INCLUDE_SUBDOMAINS", idx))
	if ok {
		result.STSIncludeSubdomains = stsIncludeSubDomains
		isSet = true
	}

	stsPreload, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__STS_PRELOAD", idx))
	if ok {
		result.STSPreload = stsPreload
		isSet = true
	}

	contentTypeNosniff, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CONTENT_TYPE_NOSNIFF", idx))
	if ok {
		result.ContentTypeNosniff = contentTypeNosniff
		isSet = true
	}

	contentSecurityPolicy, ok := 
os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CONTENT_SECURITY_POLICY", idx)) + if ok { + result.ContentSecurityPolicy = contentSecurityPolicy + isSet = true + } + + permissionsPolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__PERMISSIONS_POLICY", idx)) + if ok { + result.PermissionsPolicy = permissionsPolicy + isSet = true + } + + crossOriginOpenerPolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CROSS_ORIGIN_OPENER_POLICY", idx)) + if ok { + result.CrossOriginOpenerPolicy = crossOriginOpenerPolicy + isSet = true + } + + crossOriginResourcePolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CROSS_ORIGIN_RESOURCE_POLICY", idx)) + if ok { + result.CrossOriginResourcePolicy = crossOriginResourcePolicy + isSet = true + } + + crossOriginEmbedderPolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CROSS_ORIGIN_EMBEDDER_POLICY", idx)) + if ok { + result.CrossOriginEmbedderPolicy = crossOriginEmbedderPolicy + isSet = true + } + + referredPolicy, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__REFERRER_POLICY", idx)) + if ok { + result.ReferrerPolicy = referredPolicy + isSet = true + } + + cacheControl, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__SECURITY__CACHE_CONTROL", idx)) + if ok { + result.CacheControl = cacheControl + isSet = true + } + + return result, isSet +} + +func getHTTPDOIDCFromEnv(idx int) (httpd.OIDC, bool) { + result := defaultHTTPDBinding.OIDC + if len(globalConf.HTTPDConfig.Bindings) > idx { + result = globalConf.HTTPDConfig.Bindings[idx].OIDC + } + isSet := false + + clientID, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CLIENT_ID", idx)) + if ok { + result.ClientID = clientID + isSet = true + } + + clientSecret, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CLIENT_SECRET", idx)) + if ok { + result.ClientSecret = clientSecret + isSet = true + } + + 
clientSecretFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CLIENT_SECRET_FILE", idx)) + if ok { + result.ClientSecretFile = clientSecretFile + isSet = true + } + + configURL, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CONFIG_URL", idx)) + if ok { + result.ConfigURL = configURL + isSet = true + } + + redirectBaseURL, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__REDIRECT_BASE_URL", idx)) + if ok { + result.RedirectBaseURL = redirectBaseURL + isSet = true + } + + usernameField, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__USERNAME_FIELD", idx)) + if ok { + result.UsernameField = usernameField + isSet = true + } + + scopes, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__SCOPES", idx)) + if ok { + result.Scopes = scopes + isSet = true + } + + roleField, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__ROLE_FIELD", idx)) + if ok { + result.RoleField = roleField + isSet = true + } + + implicitRoles, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__IMPLICIT_ROLES", idx)) + if ok { + result.ImplicitRoles = implicitRoles + isSet = true + } + + customFields, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__CUSTOM_FIELDS", idx)) + if ok { + result.CustomFields = customFields + isSet = true + } + + skipSignatureCheck, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__INSECURE_SKIP_SIGNATURE_CHECK", idx)) + if ok { + result.InsecureSkipSignatureCheck = skipSignatureCheck + isSet = true + } + + debug, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__DEBUG", idx)) + if ok { + result.Debug = debug + isSet = true + } + + return result, isSet +} + +func getHTTPDUIBrandingFromEnv(prefix string, branding httpd.UIBranding) (httpd.UIBranding, bool) { + isSet := false + + name, ok := os.LookupEnv(fmt.Sprintf("%s__NAME", prefix)) + if ok { + branding.Name = name + 
isSet = true + } + + shortName, ok := os.LookupEnv(fmt.Sprintf("%s__SHORT_NAME", prefix)) + if ok { + branding.ShortName = shortName + isSet = true + } + + faviconPath, ok := os.LookupEnv(fmt.Sprintf("%s__FAVICON_PATH", prefix)) + if ok { + branding.FaviconPath = faviconPath + isSet = true + } + + logoPath, ok := os.LookupEnv(fmt.Sprintf("%s__LOGO_PATH", prefix)) + if ok { + branding.LogoPath = logoPath + isSet = true + } + + disclaimerName, ok := os.LookupEnv(fmt.Sprintf("%s__DISCLAIMER_NAME", prefix)) + if ok { + branding.DisclaimerName = disclaimerName + isSet = true + } + + disclaimerPath, ok := os.LookupEnv(fmt.Sprintf("%s__DISCLAIMER_PATH", prefix)) + if ok { + branding.DisclaimerPath = disclaimerPath + isSet = true + } + + defaultCSSPath, ok := lookupStringListFromEnv(fmt.Sprintf("%s__DEFAULT_CSS", prefix)) + if ok { + branding.DefaultCSS = defaultCSSPath + isSet = true + } + + extraCSS, ok := lookupStringListFromEnv(fmt.Sprintf("%s__EXTRA_CSS", prefix)) + if ok { + branding.ExtraCSS = extraCSS + isSet = true + } + + return branding, isSet +} + +func getHTTPDBrandingFromEnv(idx int) (httpd.Branding, bool) { + result := defaultHTTPDBinding.Branding + if len(globalConf.HTTPDConfig.Bindings) > idx { + result = globalConf.HTTPDConfig.Bindings[idx].Branding + } + isSet := false + + webAdmin, ok := getHTTPDUIBrandingFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__BRANDING__WEB_ADMIN", idx), + result.WebAdmin) + if ok { + result.WebAdmin = webAdmin + isSet = true + } + + webClient, ok := getHTTPDUIBrandingFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__BRANDING__WEB_CLIENT", idx), + result.WebClient) + if ok { + result.WebClient = webClient + isSet = true + } + + return result, isSet +} + +func getDefaultHTTPBinding(idx int) httpd.Binding { + binding := defaultHTTPDBinding + if len(globalConf.HTTPDConfig.Bindings) > idx { + binding = globalConf.HTTPDConfig.Bindings[idx] + } + return binding +} + +func getHTTPDNestedObjectsFromEnv(idx int, binding *httpd.Binding) 
bool { + isSet := false + + oidc, ok := getHTTPDOIDCFromEnv(idx) + if ok { + binding.OIDC = oidc + isSet = true + } + + securityConf, ok := getHTTPDSecurityConfFromEnv(idx) + if ok { + binding.Security = securityConf + isSet = true + } + + brandingConf, ok := getHTTPDBrandingFromEnv(idx) + if ok { + binding.Branding = brandingConf + isSet = true + } + + return isSet +} + +func getHTTPDBindingProxyConfigsFromEnv(idx int, binding *httpd.Binding) bool { + isSet := false + + proxyMode, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PROXY_MODE", idx), 32) + if ok { + binding.ProxyMode = int(proxyMode) + isSet = true + } + + proxyAllowed, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PROXY_ALLOWED", idx)) + if ok { + binding.ProxyAllowed = proxyAllowed + isSet = true + } + + clientIPProxyHeader, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_IP_PROXY_HEADER", idx)) + if ok { + binding.ClientIPProxyHeader = clientIPProxyHeader + isSet = true + } + + clientIPHeaderDepth, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_IP_HEADER_DEPTH", idx), 32) + if ok { + binding.ClientIPHeaderDepth = int(clientIPHeaderDepth) + isSet = true + } + + return isSet +} + +func getHTTPDBindingFromEnv(idx int) { //nolint:gocyclo + binding := getDefaultHTTPBinding(idx) + isSet := false + + port, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__PORT", idx), 32) + if ok { + binding.Port = int(port) + isSet = true + } + + address, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ADDRESS", idx)) + if ok { + binding.Address = address + isSet = true + } + + certificateFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CERTIFICATE_FILE", idx)) + if ok { + binding.CertificateFile = certificateFile + isSet = true + } + + certificateKeyFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CERTIFICATE_KEY_FILE", idx)) + if ok { + binding.CertificateKeyFile = 
certificateKeyFile
		isSet = true
	}

	// feature toggles for this HTTPD binding: web admin UI, web client UI
	// and the REST API can be enabled independently
	enableWebAdmin, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_WEB_ADMIN", idx))
	if ok {
		binding.EnableWebAdmin = enableWebAdmin
		isSet = true
	}

	enableWebClient, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_WEB_CLIENT", idx))
	if ok {
		binding.EnableWebClient = enableWebClient
		isSet = true
	}

	enableRESTAPI, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_REST_API", idx))
	if ok {
		binding.EnableRESTAPI = enableRESTAPI
		isSet = true
	}

	// login methods are integer bit masks; their semantics are defined by the
	// httpd package — NOTE(review): both an enabled and a disabled mask exist,
	// confirm precedence in the httpd package
	enabledLoginMethods, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLED_LOGIN_METHODS", idx), 32)
	if ok {
		binding.EnabledLoginMethods = int(enabledLoginMethods)
		isSet = true
	}

	disabledLoginMethods, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__DISABLED_LOGIN_METHODS", idx), 32)
	if ok {
		binding.DisabledLoginMethods = int(disabledLoginMethods)
		isSet = true
	}

	renderOpenAPI, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__RENDER_OPENAPI", idx))
	if ok {
		binding.RenderOpenAPI = renderOpenAPI
		isSet = true
	}

	baseURL, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%d__BASE_URL", idx))
	if ok {
		binding.BaseURL = baseURL
		isSet = true
	}

	languages, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%d__LANGUAGES", idx))
	if ok {
		binding.Languages = languages
		isSet = true
	}

	enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_HTTPS", idx))
	if ok {
		binding.EnableHTTPS = enableHTTPS
		isSet = true
	}

	tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__MIN_TLS_VERSION", idx), 32)
	if ok {
		binding.MinTLSVersion = int(tlsVer)
		isSet = true
	}

	clientAuthType, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__CLIENT_AUTH_TYPE", idx), 32)
	if ok {
		binding.ClientAuthType = int(clientAuthType)
+ isSet = true + } + + tlsCiphers, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__TLS_CIPHER_SUITES", idx)) + if ok { + binding.TLSCipherSuites = tlsCiphers + isSet = true + } + + protocols, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%d__TLS_PROTOCOLS", idx)) + if ok { + binding.Protocols = protocols + isSet = true + } + + if getHTTPDBindingProxyConfigsFromEnv(idx, &binding) { + isSet = true + } + + hideLoginURL, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__HIDE_LOGIN_URL", idx), 32) + if ok { + binding.HideLoginURL = int(hideLoginURL) + isSet = true + } + + if getHTTPDNestedObjectsFromEnv(idx, &binding) { + isSet = true + } + + setHTTPDBinding(isSet, binding, idx) +} + +func setHTTPDBinding(isSet bool, binding httpd.Binding, idx int) { + if isSet { + if len(globalConf.HTTPDConfig.Bindings) > idx { + globalConf.HTTPDConfig.Bindings[idx] = binding + } else { + globalConf.HTTPDConfig.Bindings = append(globalConf.HTTPDConfig.Bindings, binding) + } + } +} + +func getHTTPClientCertificatesFromEnv(idx int) { + tlsCert := httpclient.TLSKeyPair{} + if len(globalConf.HTTPConfig.Certificates) > idx { + tlsCert = globalConf.HTTPConfig.Certificates[idx] + } + + cert, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__CERTIFICATES__%v__CERT", idx)) + if ok { + tlsCert.Cert = cert + } + + key, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__CERTIFICATES__%v__KEY", idx)) + if ok { + tlsCert.Key = key + } + + if tlsCert.Cert != "" && tlsCert.Key != "" { + if len(globalConf.HTTPConfig.Certificates) > idx { + globalConf.HTTPConfig.Certificates[idx] = tlsCert + } else { + globalConf.HTTPConfig.Certificates = append(globalConf.HTTPConfig.Certificates, tlsCert) + } + } +} + +func getHTTPClientHeadersFromEnv(idx int) { + header := httpclient.Header{} + if len(globalConf.HTTPConfig.Headers) > idx { + header = globalConf.HTTPConfig.Headers[idx] + } + + key, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__KEY", idx)) + if 
ok {
		header.Key = key
	}

	value, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__VALUE", idx))
	if ok {
		header.Value = value
	}

	url, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_HTTP__HEADERS__%v__URL", idx))
	if ok {
		header.URL = url
	}

	if header.Key != "" && header.Value != "" {
		if len(globalConf.HTTPConfig.Headers) > idx {
			globalConf.HTTPConfig.Headers[idx] = header
		} else {
			globalConf.HTTPConfig.Headers = append(globalConf.HTTPConfig.Headers, header)
		}
	}
}

// getCommandConfigsFromEnv reads the command configuration at the given index
// from "SFTPGO_COMMAND__COMMANDS__<idx>__*" environment variables. The result
// is stored in the global configuration only when a command path is defined.
func getCommandConfigsFromEnv(idx int) {
	cfg := command.Command{}
	if len(globalConf.CommandConfig.Commands) > idx {
		cfg = globalConf.CommandConfig.Commands[idx]
	}

	if path, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__PATH", idx)); ok {
		cfg.Path = path
	}

	if timeout, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__TIMEOUT", idx), 32); ok {
		cfg.Timeout = int(timeout)
	}

	if env, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__ENV", idx)); ok {
		cfg.Env = env
	}

	if args, ok := lookupStringListFromEnv(fmt.Sprintf("SFTPGO_COMMAND__COMMANDS__%v__ARGS", idx)); ok {
		cfg.Args = args
	}

	// a command without a path is not usable: leave the stored config as is
	if cfg.Path == "" {
		return
	}
	if len(globalConf.CommandConfig.Commands) > idx {
		globalConf.CommandConfig.Commands[idx] = cfg
	} else {
		globalConf.CommandConfig.Commands = append(globalConf.CommandConfig.Commands, cfg)
	}
}

func setViperDefaults() {
	viper.SetDefault("common.idle_timeout", globalConf.Common.IdleTimeout)
	viper.SetDefault("common.upload_mode", globalConf.Common.UploadMode)
	viper.SetDefault("common.actions.execute_on", globalConf.Common.Actions.ExecuteOn)
	viper.SetDefault("common.actions.execute_sync", globalConf.Common.Actions.ExecuteSync)
	viper.SetDefault("common.actions.hook", globalConf.Common.Actions.Hook)
	viper.SetDefault("common.setstat_mode", globalConf.Common.SetstatMode)
	viper.SetDefault("common.rename_mode", globalConf.Common.RenameMode)
viper.SetDefault("common.resume_max_size", globalConf.Common.ResumeMaxSize) + viper.SetDefault("common.temp_path", globalConf.Common.TempPath) + viper.SetDefault("common.proxy_protocol", globalConf.Common.ProxyProtocol) + viper.SetDefault("common.proxy_allowed", globalConf.Common.ProxyAllowed) + viper.SetDefault("common.proxy_skipped", globalConf.Common.ProxySkipped) + viper.SetDefault("common.post_connect_hook", globalConf.Common.PostConnectHook) + viper.SetDefault("common.post_disconnect_hook", globalConf.Common.PostDisconnectHook) + viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections) + viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections) + viper.SetDefault("common.allowlist_status", globalConf.Common.AllowListStatus) + viper.SetDefault("common.allow_self_connections", globalConf.Common.AllowSelfConnections) + viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled) + viper.SetDefault("common.defender.driver", globalConf.Common.DefenderConfig.Driver) + viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime) + viper.SetDefault("common.defender.ban_time_increment", globalConf.Common.DefenderConfig.BanTimeIncrement) + viper.SetDefault("common.defender.threshold", globalConf.Common.DefenderConfig.Threshold) + viper.SetDefault("common.defender.score_invalid", globalConf.Common.DefenderConfig.ScoreInvalid) + viper.SetDefault("common.defender.score_valid", globalConf.Common.DefenderConfig.ScoreValid) + viper.SetDefault("common.defender.score_limit_exceeded", globalConf.Common.DefenderConfig.ScoreLimitExceeded) + viper.SetDefault("common.defender.score_no_auth", globalConf.Common.DefenderConfig.ScoreNoAuth) + viper.SetDefault("common.defender.observation_time", globalConf.Common.DefenderConfig.ObservationTime) + viper.SetDefault("common.defender.entries_soft_limit", globalConf.Common.DefenderConfig.EntriesSoftLimit) + 
viper.SetDefault("common.defender.entries_hard_limit", globalConf.Common.DefenderConfig.EntriesHardLimit) + viper.SetDefault("common.defender.login_delay.success", globalConf.Common.DefenderConfig.LoginDelay.Success) + viper.SetDefault("common.defender.login_delay.password_failed", globalConf.Common.DefenderConfig.LoginDelay.PasswordFailed) + viper.SetDefault("common.umask", globalConf.Common.Umask) + viper.SetDefault("common.server_version", globalConf.Common.ServerVersion) + viper.SetDefault("common.tz", globalConf.Common.TZ) + viper.SetDefault("common.metadata.read", globalConf.Common.Metadata.Read) + viper.SetDefault("common.event_manager.enabled_commands", globalConf.Common.EventManager.EnabledCommands) + viper.SetDefault("acme.email", globalConf.ACME.Email) + viper.SetDefault("acme.key_type", globalConf.ACME.KeyType) + viper.SetDefault("acme.certs_path", globalConf.ACME.CertsPath) + viper.SetDefault("acme.ca_endpoint", globalConf.ACME.CAEndpoint) + viper.SetDefault("acme.domains", globalConf.ACME.Domains) + viper.SetDefault("acme.renew_days", globalConf.ACME.RenewDays) + viper.SetDefault("acme.http01_challenge.port", globalConf.ACME.HTTP01Challenge.Port) + viper.SetDefault("acme.http01_challenge.webroot", globalConf.ACME.HTTP01Challenge.WebRoot) + viper.SetDefault("acme.http01_challenge.proxy_header", globalConf.ACME.HTTP01Challenge.ProxyHeader) + viper.SetDefault("acme.tls_alpn01_challenge.port", globalConf.ACME.TLSALPN01Challenge.Port) + viper.SetDefault("sftpd.max_auth_tries", globalConf.SFTPD.MaxAuthTries) + viper.SetDefault("sftpd.host_keys", globalConf.SFTPD.HostKeys) + viper.SetDefault("sftpd.host_certificates", globalConf.SFTPD.HostCertificates) + viper.SetDefault("sftpd.host_key_algorithms", globalConf.SFTPD.HostKeyAlgorithms) + viper.SetDefault("sftpd.kex_algorithms", globalConf.SFTPD.KexAlgorithms) + viper.SetDefault("sftpd.ciphers", globalConf.SFTPD.Ciphers) + viper.SetDefault("sftpd.macs", globalConf.SFTPD.MACs) + 
viper.SetDefault("sftpd.public_key_algorithms", globalConf.SFTPD.PublicKeyAlgorithms) + viper.SetDefault("sftpd.trusted_user_ca_keys", globalConf.SFTPD.TrustedUserCAKeys) + viper.SetDefault("sftpd.revoked_user_certs_file", globalConf.SFTPD.RevokedUserCertsFile) + viper.SetDefault("sftpd.opkssh_path", globalConf.SFTPD.OPKSSHPath) + viper.SetDefault("sftpd.opkssh_checksum", globalConf.SFTPD.OPKSSHChecksum) + viper.SetDefault("sftpd.login_banner_file", globalConf.SFTPD.LoginBannerFile) + viper.SetDefault("sftpd.enabled_ssh_commands", sftpd.GetDefaultSSHCommands()) + viper.SetDefault("sftpd.keyboard_interactive_authentication", globalConf.SFTPD.KeyboardInteractiveAuthentication) + viper.SetDefault("sftpd.keyboard_interactive_auth_hook", globalConf.SFTPD.KeyboardInteractiveHook) + viper.SetDefault("sftpd.password_authentication", globalConf.SFTPD.PasswordAuthentication) + viper.SetDefault("ftpd.banner_file", globalConf.FTPD.BannerFile) + viper.SetDefault("ftpd.active_transfers_port_non_20", globalConf.FTPD.ActiveTransfersPortNon20) + viper.SetDefault("ftpd.passive_port_range.start", globalConf.FTPD.PassivePortRange.Start) + viper.SetDefault("ftpd.passive_port_range.end", globalConf.FTPD.PassivePortRange.End) + viper.SetDefault("ftpd.disable_active_mode", globalConf.FTPD.DisableActiveMode) + viper.SetDefault("ftpd.enable_site", globalConf.FTPD.EnableSite) + viper.SetDefault("ftpd.hash_support", globalConf.FTPD.HASHSupport) + viper.SetDefault("ftpd.combine_support", globalConf.FTPD.CombineSupport) + viper.SetDefault("ftpd.certificate_file", globalConf.FTPD.CertificateFile) + viper.SetDefault("ftpd.certificate_key_file", globalConf.FTPD.CertificateKeyFile) + viper.SetDefault("ftpd.ca_certificates", globalConf.FTPD.CACertificates) + viper.SetDefault("ftpd.ca_revocation_lists", globalConf.FTPD.CARevocationLists) + viper.SetDefault("webdavd.certificate_file", globalConf.WebDAVD.CertificateFile) + viper.SetDefault("webdavd.certificate_key_file", 
globalConf.WebDAVD.CertificateKeyFile) + viper.SetDefault("webdavd.ca_certificates", globalConf.WebDAVD.CACertificates) + viper.SetDefault("webdavd.ca_revocation_lists", globalConf.WebDAVD.CARevocationLists) + viper.SetDefault("webdavd.cors.enabled", globalConf.WebDAVD.Cors.Enabled) + viper.SetDefault("webdavd.cors.allowed_origins", globalConf.WebDAVD.Cors.AllowedOrigins) + viper.SetDefault("webdavd.cors.allowed_methods", globalConf.WebDAVD.Cors.AllowedMethods) + viper.SetDefault("webdavd.cors.allowed_headers", globalConf.WebDAVD.Cors.AllowedHeaders) + viper.SetDefault("webdavd.cors.exposed_headers", globalConf.WebDAVD.Cors.ExposedHeaders) + viper.SetDefault("webdavd.cors.allow_credentials", globalConf.WebDAVD.Cors.AllowCredentials) + viper.SetDefault("webdavd.cors.options_passthrough", globalConf.WebDAVD.Cors.OptionsPassthrough) + viper.SetDefault("webdavd.cors.options_success_status", globalConf.WebDAVD.Cors.OptionsSuccessStatus) + viper.SetDefault("webdavd.cors.allow_private_network", globalConf.WebDAVD.Cors.AllowPrivateNetwork) + viper.SetDefault("webdavd.cors.max_age", globalConf.WebDAVD.Cors.MaxAge) + viper.SetDefault("webdavd.cache.users.expiration_time", globalConf.WebDAVD.Cache.Users.ExpirationTime) + viper.SetDefault("webdavd.cache.users.max_size", globalConf.WebDAVD.Cache.Users.MaxSize) + viper.SetDefault("webdavd.cache.mime_types.enabled", globalConf.WebDAVD.Cache.MimeTypes.Enabled) + viper.SetDefault("webdavd.cache.mime_types.max_size", globalConf.WebDAVD.Cache.MimeTypes.MaxSize) + viper.SetDefault("webdavd.cache.mime_types.custom_mappings", globalConf.WebDAVD.Cache.MimeTypes.CustomMappings) + viper.SetDefault("data_provider.driver", globalConf.ProviderConf.Driver) + viper.SetDefault("data_provider.name", globalConf.ProviderConf.Name) + viper.SetDefault("data_provider.host", globalConf.ProviderConf.Host) + viper.SetDefault("data_provider.port", globalConf.ProviderConf.Port) + viper.SetDefault("data_provider.username", globalConf.ProviderConf.Username) 
+ viper.SetDefault("data_provider.password", globalConf.ProviderConf.Password) + viper.SetDefault("data_provider.sslmode", globalConf.ProviderConf.SSLMode) + viper.SetDefault("data_provider.disable_sni", globalConf.ProviderConf.DisableSNI) + viper.SetDefault("data_provider.target_session_attrs", globalConf.ProviderConf.TargetSessionAttrs) + viper.SetDefault("data_provider.root_cert", globalConf.ProviderConf.RootCert) + viper.SetDefault("data_provider.client_cert", globalConf.ProviderConf.ClientCert) + viper.SetDefault("data_provider.client_key", globalConf.ProviderConf.ClientKey) + viper.SetDefault("data_provider.connection_string", globalConf.ProviderConf.ConnectionString) + viper.SetDefault("data_provider.sql_tables_prefix", globalConf.ProviderConf.SQLTablesPrefix) + viper.SetDefault("data_provider.track_quota", globalConf.ProviderConf.TrackQuota) + viper.SetDefault("data_provider.pool_size", globalConf.ProviderConf.PoolSize) + viper.SetDefault("data_provider.users_base_dir", globalConf.ProviderConf.UsersBaseDir) + viper.SetDefault("data_provider.actions.execute_on", globalConf.ProviderConf.Actions.ExecuteOn) + viper.SetDefault("data_provider.actions.execute_for", globalConf.ProviderConf.Actions.ExecuteFor) + viper.SetDefault("data_provider.actions.hook", globalConf.ProviderConf.Actions.Hook) + viper.SetDefault("data_provider.external_auth_hook", globalConf.ProviderConf.ExternalAuthHook) + viper.SetDefault("data_provider.external_auth_scope", globalConf.ProviderConf.ExternalAuthScope) + viper.SetDefault("data_provider.pre_login_hook", globalConf.ProviderConf.PreLoginHook) + viper.SetDefault("data_provider.post_login_hook", globalConf.ProviderConf.PostLoginHook) + viper.SetDefault("data_provider.post_login_scope", globalConf.ProviderConf.PostLoginScope) + viper.SetDefault("data_provider.check_password_hook", globalConf.ProviderConf.CheckPasswordHook) + viper.SetDefault("data_provider.check_password_scope", globalConf.ProviderConf.CheckPasswordScope) + 
viper.SetDefault("data_provider.password_hashing.bcrypt_options.cost", globalConf.ProviderConf.PasswordHashing.BcryptOptions.Cost) + viper.SetDefault("data_provider.password_hashing.argon2_options.memory", globalConf.ProviderConf.PasswordHashing.Argon2Options.Memory) + viper.SetDefault("data_provider.password_hashing.argon2_options.iterations", globalConf.ProviderConf.PasswordHashing.Argon2Options.Iterations) + viper.SetDefault("data_provider.password_hashing.argon2_options.parallelism", globalConf.ProviderConf.PasswordHashing.Argon2Options.Parallelism) + viper.SetDefault("data_provider.password_hashing.algo", globalConf.ProviderConf.PasswordHashing.Algo) + viper.SetDefault("data_provider.password_validation.admins.min_entropy", globalConf.ProviderConf.PasswordValidation.Admins.MinEntropy) + viper.SetDefault("data_provider.password_validation.users.min_entropy", globalConf.ProviderConf.PasswordValidation.Users.MinEntropy) + viper.SetDefault("data_provider.password_caching", globalConf.ProviderConf.PasswordCaching) + viper.SetDefault("data_provider.update_mode", globalConf.ProviderConf.UpdateMode) + viper.SetDefault("data_provider.delayed_quota_update", globalConf.ProviderConf.DelayedQuotaUpdate) + viper.SetDefault("data_provider.create_default_admin", globalConf.ProviderConf.CreateDefaultAdmin) + viper.SetDefault("data_provider.naming_rules", globalConf.ProviderConf.NamingRules) + viper.SetDefault("data_provider.is_shared", globalConf.ProviderConf.IsShared) + viper.SetDefault("data_provider.node.host", globalConf.ProviderConf.Node.Host) + viper.SetDefault("data_provider.node.port", globalConf.ProviderConf.Node.Port) + viper.SetDefault("data_provider.node.proto", globalConf.ProviderConf.Node.Proto) + viper.SetDefault("data_provider.backups_path", globalConf.ProviderConf.BackupsPath) + viper.SetDefault("httpd.templates_path", globalConf.HTTPDConfig.TemplatesPath) + viper.SetDefault("httpd.static_files_path", globalConf.HTTPDConfig.StaticFilesPath) + 
viper.SetDefault("httpd.openapi_path", globalConf.HTTPDConfig.OpenAPIPath) + viper.SetDefault("httpd.web_root", globalConf.HTTPDConfig.WebRoot) + viper.SetDefault("httpd.certificate_file", globalConf.HTTPDConfig.CertificateFile) + viper.SetDefault("httpd.certificate_key_file", globalConf.HTTPDConfig.CertificateKeyFile) + viper.SetDefault("httpd.ca_certificates", globalConf.HTTPDConfig.CACertificates) + viper.SetDefault("httpd.ca_revocation_lists", globalConf.HTTPDConfig.CARevocationLists) + viper.SetDefault("httpd.signing_passphrase", globalConf.HTTPDConfig.SigningPassphrase) + viper.SetDefault("httpd.signing_passphrase_file", globalConf.HTTPDConfig.SigningPassphraseFile) + viper.SetDefault("httpd.token_validation", globalConf.HTTPDConfig.TokenValidation) + viper.SetDefault("httpd.cookie_lifetime", globalConf.HTTPDConfig.CookieLifetime) + viper.SetDefault("httpd.share_cookie_lifetime", globalConf.HTTPDConfig.ShareCookieLifetime) + viper.SetDefault("httpd.jwt_lifetime", globalConf.HTTPDConfig.JWTLifetime) + viper.SetDefault("httpd.max_upload_file_size", globalConf.HTTPDConfig.MaxUploadFileSize) + viper.SetDefault("httpd.cors.enabled", globalConf.HTTPDConfig.Cors.Enabled) + viper.SetDefault("httpd.cors.allowed_origins", globalConf.HTTPDConfig.Cors.AllowedOrigins) + viper.SetDefault("httpd.cors.allowed_methods", globalConf.HTTPDConfig.Cors.AllowedMethods) + viper.SetDefault("httpd.cors.allowed_headers", globalConf.HTTPDConfig.Cors.AllowedHeaders) + viper.SetDefault("httpd.cors.exposed_headers", globalConf.HTTPDConfig.Cors.ExposedHeaders) + viper.SetDefault("httpd.cors.allow_credentials", globalConf.HTTPDConfig.Cors.AllowCredentials) + viper.SetDefault("httpd.cors.max_age", globalConf.HTTPDConfig.Cors.MaxAge) + viper.SetDefault("httpd.cors.options_passthrough", globalConf.HTTPDConfig.Cors.OptionsPassthrough) + viper.SetDefault("httpd.cors.options_success_status", globalConf.HTTPDConfig.Cors.OptionsSuccessStatus) + viper.SetDefault("httpd.cors.allow_private_network", 
globalConf.HTTPDConfig.Cors.AllowPrivateNetwork) + viper.SetDefault("httpd.setup.installation_code", globalConf.HTTPDConfig.Setup.InstallationCode) + viper.SetDefault("httpd.setup.installation_code_hint", globalConf.HTTPDConfig.Setup.InstallationCodeHint) + viper.SetDefault("httpd.hide_support_link", globalConf.HTTPDConfig.HideSupportLink) + viper.SetDefault("http.timeout", globalConf.HTTPConfig.Timeout) + viper.SetDefault("http.retry_wait_min", globalConf.HTTPConfig.RetryWaitMin) + viper.SetDefault("http.retry_wait_max", globalConf.HTTPConfig.RetryWaitMax) + viper.SetDefault("http.retry_max", globalConf.HTTPConfig.RetryMax) + viper.SetDefault("http.ca_certificates", globalConf.HTTPConfig.CACertificates) + viper.SetDefault("http.skip_tls_verify", globalConf.HTTPConfig.SkipTLSVerify) + viper.SetDefault("command.timeout", globalConf.CommandConfig.Timeout) + viper.SetDefault("command.env", globalConf.CommandConfig.Env) + viper.SetDefault("kms.secrets.url", globalConf.KMSConfig.Secrets.URL) + viper.SetDefault("kms.secrets.master_key", globalConf.KMSConfig.Secrets.MasterKeyString) + viper.SetDefault("kms.secrets.master_key_path", globalConf.KMSConfig.Secrets.MasterKeyPath) + viper.SetDefault("telemetry.bind_port", globalConf.TelemetryConfig.BindPort) + viper.SetDefault("telemetry.bind_address", globalConf.TelemetryConfig.BindAddress) + viper.SetDefault("telemetry.enable_profiler", globalConf.TelemetryConfig.EnableProfiler) + viper.SetDefault("telemetry.auth_user_file", globalConf.TelemetryConfig.AuthUserFile) + viper.SetDefault("telemetry.certificate_file", globalConf.TelemetryConfig.CertificateFile) + viper.SetDefault("telemetry.certificate_key_file", globalConf.TelemetryConfig.CertificateKeyFile) + viper.SetDefault("telemetry.min_tls_version", globalConf.TelemetryConfig.MinTLSVersion) + viper.SetDefault("telemetry.tls_cipher_suites", globalConf.TelemetryConfig.TLSCipherSuites) + viper.SetDefault("telemetry.tls_protocols", globalConf.TelemetryConfig.Protocols) + 
viper.SetDefault("smtp.host", globalConf.SMTPConfig.Host) + viper.SetDefault("smtp.port", globalConf.SMTPConfig.Port) + viper.SetDefault("smtp.from", globalConf.SMTPConfig.From) + viper.SetDefault("smtp.user", globalConf.SMTPConfig.User) + viper.SetDefault("smtp.password", globalConf.SMTPConfig.Password) + viper.SetDefault("smtp.auth_type", globalConf.SMTPConfig.AuthType) + viper.SetDefault("smtp.encryption", globalConf.SMTPConfig.Encryption) + viper.SetDefault("smtp.domain", globalConf.SMTPConfig.Domain) + viper.SetDefault("smtp.templates_path", globalConf.SMTPConfig.TemplatesPath) +} + +func lookupBoolFromEnv(envName string) (bool, bool) { + value, ok := os.LookupEnv(envName) + if ok { + converted, err := strconv.ParseBool(strings.TrimSpace(value)) + if err == nil { + return converted, ok + } + } + + return false, false +} + +func lookupIntFromEnv(envName string, bitSize int) (int64, bool) { + value, ok := os.LookupEnv(envName) + if ok { + converted, err := strconv.ParseInt(strings.TrimSpace(value), 10, bitSize) + if err == nil { + return converted, ok + } + } + + return 0, false +} + +func lookupStringListFromEnv(envName string) ([]string, bool) { + value, ok := os.LookupEnv(envName) + if ok { + var result []string + for v := range strings.SplitSeq(value, ",") { + val := strings.TrimSpace(v) + if val != "" { + result = append(result, val) + } + } + return result, true + } + return nil, false +} diff --git a/internal/config/config_darwin.go b/internal/config/config_darwin.go new file mode 100644 index 00000000..808c1cd3 --- /dev/null +++ b/internal/config/config_darwin.go @@ -0,0 +1,24 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. + +//go:build darwin + +package config + +import "github.com/spf13/viper" + +// macOS specific config search path +func setViperAdditionalConfigPaths() { + viper.AddConfigPath("/usr/local/etc/sftpgo") +} diff --git a/internal/config/config_fallback.go b/internal/config/config_fallback.go new file mode 100644 index 00000000..f841410a --- /dev/null +++ b/internal/config/config_fallback.go @@ -0,0 +1,19 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. + +//go:build !linux && !darwin + +package config + +func setViperAdditionalConfigPaths() {} diff --git a/internal/config/config_linux.go b/internal/config/config_linux.go new file mode 100644 index 00000000..21e6eed5 --- /dev/null +++ b/internal/config/config_linux.go @@ -0,0 +1,26 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. + +//go:build linux + +package config + +import "github.com/spf13/viper" + +// linux specific config search path +func setViperAdditionalConfigPaths() { + viper.AddConfigPath("$HOME/.config/sftpgo") + viper.AddConfigPath("/etc/sftpgo") + viper.AddConfigPath("/usr/local/etc/sftpgo") +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 00000000..668d6769 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,1641 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. 
+ +package config_test + +import ( + "crypto/rand" + "encoding/json" + "os" + "path/filepath" + "slices" + "testing" + + "github.com/sftpgo/sdk/kms" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/command" + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/httpd" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +const ( + tempConfigName = "temp" +) + +var ( + configDir = filepath.Join(".", "..", "..") +) + +func reset() { + viper.Reset() + config.Init() +} + +func TestLoadConfigTest(t *testing.T) { + reset() + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig()) + assert.NotEqual(t, dataprovider.Config{}, config.GetProviderConf()) + assert.NotEqual(t, sftpd.Configuration{}, config.GetSFTPDConfig()) + assert.NotEqual(t, httpclient.Config{}, config.GetHTTPConfig()) + assert.NotEqual(t, smtp.Config{}, config.GetSMTPConfig()) + confName := tempConfigName + ".json" //nolint:goconst + configFilePath := filepath.Join(configDir, confName) + err = config.LoadConfig(configDir, confName) + assert.Error(t, err) + err = os.WriteFile(configFilePath, []byte("{invalid json}"), os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.Error(t, err) + err = os.WriteFile(configFilePath, []byte(`{"sftpd": {"max_auth_tries": "a"}}`), os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.Error(t, err) + err = os.Remove(configFilePath) + assert.NoError(t, 
err) +} + +func TestLoadConfigFileNotFound(t *testing.T) { + reset() + + viper.SetConfigName("configfile") + err := config.LoadConfig(os.TempDir(), "") + require.NoError(t, err) + mfaConf := config.GetMFAConfig() + require.Len(t, mfaConf.TOTP, 1) + require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1) + require.Len(t, config.GetCommonConfig().RateLimitersConfig[0].Protocols, 4) + require.Len(t, config.GetHTTPDConfig().Bindings, 1) + require.Len(t, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes, 3) +} + +func TestReadEnvFiles(t *testing.T) { + reset() + + envd := filepath.Join(configDir, "env.d") + err := os.Mkdir(envd, os.ModePerm) + assert.NoError(t, err) + + content := make([]byte, 1048576+1) + _, err = rand.Read(content) + assert.NoError(t, err) + + err = os.WriteFile(filepath.Join(envd, "env1"), []byte("SFTPGO_SFTPD__MAX_AUTH_TRIES = 10"), 0666) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(envd, "env2"), []byte(`{"invalid env": "value"}`), 0666) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(envd, "env3"), content, 0666) + assert.NoError(t, err) + + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + assert.Equal(t, 10, config.GetSFTPDConfig().MaxAuthTries) + + _, ok := os.LookupEnv("SFTPGO_SFTPD__MAX_AUTH_TRIES") + assert.True(t, ok) + err = os.Unsetenv("SFTPGO_SFTPD__MAX_AUTH_TRIES") + assert.NoError(t, err) + os.RemoveAll(envd) +} + +func TestEnabledSSHCommands(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + + reset() + + sftpdConf := config.GetSFTPDConfig() + sftpdConf.EnabledSSHCommands = []string{"scp"} + c := make(map[string]sftpd.Configuration) + c["sftpd"] = sftpdConf + jsonConf, err := json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + 
assert.NoError(t, err) + sftpdConf = config.GetSFTPDConfig() + if assert.Len(t, sftpdConf.EnabledSSHCommands, 1) { + assert.Equal(t, "scp", sftpdConf.EnabledSSHCommands[0]) + } + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestInvalidExternalAuthScope(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.ExternalAuthScope = 100 + c := make(map[string]dataprovider.Config) + c["data_provider"] = providerConf + jsonConf, err := json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + assert.Equal(t, 0, config.GetProviderConf().ExternalAuthScope) + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestInvalidProxyProtocol(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + commonConf := config.GetCommonConfig() + commonConf.ProxyProtocol = 10 + c := make(map[string]common.Configuration) + c["common"] = commonConf + jsonConf, err := json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + assert.Equal(t, 0, config.GetCommonConfig().ProxyProtocol) + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestInvalidUsersBaseDir(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.UsersBaseDir = "." 
+ c := make(map[string]dataprovider.Config) + c["data_provider"] = providerConf + jsonConf, err := json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + assert.Empty(t, config.GetProviderConf().UsersBaseDir) + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestInvalidInstallationHint(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + httpdConfig := config.GetHTTPDConfig() + httpdConfig.Setup = httpd.SetupConfig{ + InstallationCode: "abc", + InstallationCodeHint: " ", + } + c := make(map[string]httpd.Conf) + c["httpd"] = httpdConfig + jsonConf, err := json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + httpdConfig = config.GetHTTPDConfig() + assert.Equal(t, "abc", httpdConfig.Setup.InstallationCode) + assert.Equal(t, "Installation code", httpdConfig.Setup.InstallationCodeHint) + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestInvalidRenameMode(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + commonConfig := config.GetCommonConfig() + commonConfig.RenameMode = 10 + c := make(map[string]any) + c["common"] = commonConfig + jsonConf, err := json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + assert.Equal(t, 0, config.GetCommonConfig().RenameMode) + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestDefenderProviderDriver(t *testing.T) { + if 
config.GetProviderConf().Driver != dataprovider.SQLiteDataProviderName { + t.Skip("this test is not supported with the current database provider") + } + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + providerConf := config.GetProviderConf() + providerConf.Driver = dataprovider.BoltDataProviderName + commonConfig := config.GetCommonConfig() + commonConfig.DefenderConfig.Enabled = true + commonConfig.DefenderConfig.Driver = common.DefenderDriverProvider + c := make(map[string]any) + c["common"] = commonConfig + c["data_provider"] = providerConf + jsonConf, err := json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + assert.Equal(t, dataprovider.BoltDataProviderName, config.GetProviderConf().Driver) + assert.Equal(t, common.DefenderDriverMemory, config.GetCommonConfig().DefenderConfig.Driver) + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestSetGetConfig(t *testing.T) { + reset() + + sftpdConf := config.GetSFTPDConfig() + sftpdConf.MaxAuthTries = 10 + config.SetSFTPDConfig(sftpdConf) + assert.Equal(t, sftpdConf.MaxAuthTries, config.GetSFTPDConfig().MaxAuthTries) + dataProviderConf := config.GetProviderConf() + dataProviderConf.Host = "test host" + config.SetProviderConf(dataProviderConf) + assert.Equal(t, dataProviderConf.Host, config.GetProviderConf().Host) + httpdConf := config.GetHTTPDConfig() + httpdConf.Bindings = append(httpdConf.Bindings, httpd.Binding{Address: "0.0.0.0"}) + config.SetHTTPDConfig(httpdConf) + assert.Equal(t, httpdConf.Bindings[0].Address, config.GetHTTPDConfig().Bindings[0].Address) + commonConf := config.GetCommonConfig() + commonConf.IdleTimeout = 10 + config.SetCommonConfig(commonConf) + assert.Equal(t, commonConf.IdleTimeout, config.GetCommonConfig().IdleTimeout) + ftpdConf := config.GetFTPDConfig() + 
ftpdConf.CertificateFile = "cert" + ftpdConf.CertificateKeyFile = "key" + config.SetFTPDConfig(ftpdConf) + assert.Equal(t, ftpdConf.CertificateFile, config.GetFTPDConfig().CertificateFile) + assert.Equal(t, ftpdConf.CertificateKeyFile, config.GetFTPDConfig().CertificateKeyFile) + webDavConf := config.GetWebDAVDConfig() + webDavConf.CertificateFile = "dav_cert" + webDavConf.CertificateKeyFile = "dav_key" + config.SetWebDAVDConfig(webDavConf) + assert.Equal(t, webDavConf.CertificateFile, config.GetWebDAVDConfig().CertificateFile) + assert.Equal(t, webDavConf.CertificateKeyFile, config.GetWebDAVDConfig().CertificateKeyFile) + kmsConf := config.GetKMSConfig() + kmsConf.Secrets.MasterKeyPath = "apath" + kmsConf.Secrets.URL = "aurl" + config.SetKMSConfig(kmsConf) + assert.Equal(t, kmsConf.Secrets.MasterKeyPath, config.GetKMSConfig().Secrets.MasterKeyPath) + assert.Equal(t, kmsConf.Secrets.URL, config.GetKMSConfig().Secrets.URL) + telemetryConf := config.GetTelemetryConfig() + telemetryConf.BindPort = 10001 + telemetryConf.BindAddress = "0.0.0.0" + config.SetTelemetryConfig(telemetryConf) + assert.Equal(t, telemetryConf.BindPort, config.GetTelemetryConfig().BindPort) + assert.Equal(t, telemetryConf.BindAddress, config.GetTelemetryConfig().BindAddress) + pluginConf := []plugin.Config{ + { + Type: "eventsearcher", + }, + } + config.SetPluginsConfig(pluginConf) + if assert.Len(t, config.GetPluginsConfig(), 1) { + assert.Equal(t, pluginConf[0].Type, config.GetPluginsConfig()[0].Type) + } + assert.False(t, config.HasKMSPlugin()) + pluginConf = []plugin.Config{ + { + Type: "notifier", + }, + { + Type: "kms", + }, + } + config.SetPluginsConfig(pluginConf) + assert.Len(t, config.GetPluginsConfig(), 2) + assert.True(t, config.HasKMSPlugin()) +} + +func TestServiceToStart(t *testing.T) { + reset() + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + assert.True(t, config.HasServicesToStart()) + sftpdConf := config.GetSFTPDConfig() + sftpdConf.Bindings[0].Port = 0 
+ config.SetSFTPDConfig(sftpdConf) + // httpd service is enabled + assert.True(t, config.HasServicesToStart()) + httpdConf := config.GetHTTPDConfig() + httpdConf.Bindings[0].Port = 0 + assert.False(t, config.HasServicesToStart()) + ftpdConf := config.GetFTPDConfig() + ftpdConf.Bindings[0].Port = 2121 + config.SetFTPDConfig(ftpdConf) + assert.True(t, config.HasServicesToStart()) + ftpdConf.Bindings[0].Port = 0 + config.SetFTPDConfig(ftpdConf) + webdavdConf := config.GetWebDAVDConfig() + webdavdConf.Bindings[0].Port = 9000 + config.SetWebDAVDConfig(webdavdConf) + assert.True(t, config.HasServicesToStart()) + webdavdConf.Bindings[0].Port = 0 + config.SetWebDAVDConfig(webdavdConf) + assert.False(t, config.HasServicesToStart()) + sftpdConf.Bindings[0].Port = 2022 + config.SetSFTPDConfig(sftpdConf) + assert.True(t, config.HasServicesToStart()) +} + +func TestSSHCommandsFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_SFTPD__ENABLED_SSH_COMMANDS", "cd,scp") + t.Cleanup(func() { + os.Unsetenv("SFTPGO_SFTPD__ENABLED_SSH_COMMANDS") + }) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + + sftpdConf := config.GetSFTPDConfig() + if assert.Len(t, sftpdConf.EnabledSSHCommands, 2) { + assert.Equal(t, "cd", sftpdConf.EnabledSSHCommands[0]) + assert.Equal(t, "scp", sftpdConf.EnabledSSHCommands[1]) + } +} + +func TestSMTPFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_SMTP__HOST", "smtp.example.com") + os.Setenv("SFTPGO_SMTP__PORT", "587") + t.Cleanup(func() { + os.Unsetenv("SFTPGO_SMTP__HOST") + os.Unsetenv("SFTPGO_SMTP__PORT") + }) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + smtpConfig := config.GetSMTPConfig() + assert.Equal(t, "smtp.example.com", smtpConfig.Host) + assert.Equal(t, 587, smtpConfig.Port) +} + +func TestMFAFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_MFA__TOTP__0__NAME", "main") + os.Setenv("SFTPGO_MFA__TOTP__1__NAME", "additional_name") + os.Setenv("SFTPGO_MFA__TOTP__1__ISSUER", 
"additional_issuer") + os.Setenv("SFTPGO_MFA__TOTP__1__ALGO", "sha256") + t.Cleanup(func() { + os.Unsetenv("SFTPGO_MFA__TOTP__0__NAME") + os.Unsetenv("SFTPGO_MFA__TOTP__1__NAME") + os.Unsetenv("SFTPGO_MFA__TOTP__1__ISSUER") + os.Unsetenv("SFTPGO_MFA__TOTP__1__ALGO") + }) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + mfaConf := config.GetMFAConfig() + require.Len(t, mfaConf.TOTP, 2) + require.Equal(t, "main", mfaConf.TOTP[0].Name) + require.Equal(t, "SFTPGo", mfaConf.TOTP[0].Issuer) + require.Equal(t, "sha1", mfaConf.TOTP[0].Algo) + require.Equal(t, "additional_name", mfaConf.TOTP[1].Name) + require.Equal(t, "additional_issuer", mfaConf.TOTP[1].Issuer) + require.Equal(t, "sha256", mfaConf.TOTP[1].Algo) +} + +func TestDisabledMFAConfig(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + mfaConf := config.GetMFAConfig() + assert.Len(t, mfaConf.TOTP, 1) + + reset() + + c := make(map[string]mfa.Config) + c["mfa"] = mfa.Config{} + jsonConf, err := json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + mfaConf = config.GetMFAConfig() + assert.Len(t, mfaConf.TOTP, 0) + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestOverrideSliceValues(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + c := make(map[string]any) + c["common"] = common.Configuration{ + RateLimitersConfig: []common.RateLimiterConfig{ + { + Type: 1, + Protocols: []string{"HTTP"}, + }, + }, + } + jsonConf, err := json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + 
require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1) + require.Equal(t, []string{"HTTP"}, config.GetCommonConfig().RateLimitersConfig[0].Protocols) + + reset() + + // empty ratelimiters, default value should be used + c["common"] = common.Configuration{} + jsonConf, err = json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1) + rl := config.GetCommonConfig().RateLimitersConfig[0] + require.Equal(t, []string{"SSH", "FTP", "DAV", "HTTP"}, rl.Protocols) + require.Equal(t, int64(1000), rl.Period) + + reset() + + c = make(map[string]any) + c["httpd"] = httpd.Conf{ + Bindings: []httpd.Binding{ + { + OIDC: httpd.OIDC{ + Scopes: []string{"scope1"}, + }, + }, + }, + } + jsonConf, err = json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + require.Len(t, config.GetHTTPDConfig().Bindings, 1) + require.Equal(t, []string{"scope1"}, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes) + + reset() + + c = make(map[string]any) + c["httpd"] = httpd.Conf{ + Bindings: nil, + } + jsonConf, err = json.Marshal(c) + assert.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + assert.NoError(t, err) + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + require.Len(t, config.GetHTTPDConfig().Bindings, 1) + require.Equal(t, []string{"openid", "profile", "email"}, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes) +} + +func TestFTPDOverridesFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__IP", "192.168.1.1") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__NETWORKS", "192.168.1.0/24, 192.168.3.0/25") + 
os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__IP", "192.168.2.1") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__NETWORKS", "192.168.2.0/24") + cleanup := func() { + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__IP") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__NETWORKS") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__IP") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__NETWORKS") + } + t.Cleanup(cleanup) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + ftpdConf := config.GetFTPDConfig() + require.Len(t, ftpdConf.Bindings, 1) + require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides, 2) + require.Equal(t, "192.168.1.1", ftpdConf.Bindings[0].PassiveIPOverrides[0].IP) + require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides[0].Networks, 2) + require.Equal(t, "192.168.2.1", ftpdConf.Bindings[0].PassiveIPOverrides[1].IP) + require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides[1].Networks, 1) + + cleanup() + cfg := make(map[string]any) + cfg["ftpd"] = ftpdConf + configAsJSON, err := json.Marshal(cfg) + require.NoError(t, err) + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err = os.WriteFile(configFilePath, configAsJSON, os.ModePerm) + assert.NoError(t, err) + os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__IP", "192.168.1.2") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__1__NETWORKS", "192.168.2.0/24,192.168.4.0/25") + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + ftpdConf = config.GetFTPDConfig() + require.Len(t, ftpdConf.Bindings, 1) + require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides, 2) + require.Equal(t, "192.168.1.2", ftpdConf.Bindings[0].PassiveIPOverrides[0].IP) + require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides[0].Networks, 2) + require.Equal(t, "192.168.2.1", ftpdConf.Bindings[0].PassiveIPOverrides[1].IP) + 
require.Len(t, ftpdConf.Bindings[0].PassiveIPOverrides[1].Networks, 2) + + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestHTTPDSubObjectsFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__KEY", "X-Forwarded-Proto") + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__VALUE", "https") + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_ID", "client_id") + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET", "client_secret") + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET_FILE", "client_secret_file") + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CONFIG_URL", "config_url") + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__REDIRECT_BASE_URL", "redirect_base_url") + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__USERNAME_FIELD", "email") + cleanup := func() { + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__KEY") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__VALUE") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_ID") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET_FILE") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CONFIG_URL") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__REDIRECT_BASE_URL") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__USERNAME_FIELD") + } + t.Cleanup(cleanup) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + httpdConf := config.GetHTTPDConfig() + require.Len(t, httpdConf.Bindings, 1) + require.Len(t, httpdConf.Bindings[0].Security.HTTPSProxyHeaders, 1) + require.Equal(t, "client_id", httpdConf.Bindings[0].OIDC.ClientID) + require.Equal(t, "client_secret", httpdConf.Bindings[0].OIDC.ClientSecret) + require.Equal(t, "client_secret_file", httpdConf.Bindings[0].OIDC.ClientSecretFile) + require.Equal(t, "config_url", httpdConf.Bindings[0].OIDC.ConfigURL) + 
require.Equal(t, "redirect_base_url", httpdConf.Bindings[0].OIDC.RedirectBaseURL) + require.Equal(t, "email", httpdConf.Bindings[0].OIDC.UsernameField) + + cleanup() + cfg := make(map[string]any) + cfg["httpd"] = httpdConf + configAsJSON, err := json.Marshal(cfg) + require.NoError(t, err) + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err = os.WriteFile(configFilePath, configAsJSON, os.ModePerm) + assert.NoError(t, err) + + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__SECURITY__HTTPS_PROXY_HEADERS__0__VALUE", "http") + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__OIDC__CLIENT_SECRET", "new_client_secret") + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + httpdConf = config.GetHTTPDConfig() + require.Len(t, httpdConf.Bindings, 1) + require.Len(t, httpdConf.Bindings[0].Security.HTTPSProxyHeaders, 1) + require.Equal(t, "http", httpdConf.Bindings[0].Security.HTTPSProxyHeaders[0].Value) + require.Equal(t, "client_id", httpdConf.Bindings[0].OIDC.ClientID) + require.Equal(t, "new_client_secret", httpdConf.Bindings[0].OIDC.ClientSecret) + require.Equal(t, "config_url", httpdConf.Bindings[0].OIDC.ConfigURL) + require.Equal(t, "redirect_base_url", httpdConf.Bindings[0].OIDC.RedirectBaseURL) + require.Equal(t, "email", httpdConf.Bindings[0].OIDC.UsernameField) + + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestPluginsFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_PLUGINS__0__TYPE", "notifier") + os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__FS_EVENTS", "upload,download") + os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__PROVIDER_EVENTS", "add,update") + os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__PROVIDER_OBJECTS", "user,admin") + os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__LOG_EVENTS", "a,1,2") + os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__RETRY_MAX_TIME", "2") + os.Setenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__RETRY_QUEUE_MAX_SIZE", "1000") + 
os.Setenv("SFTPGO_PLUGINS__0__CMD", "plugin_start_cmd") + os.Setenv("SFTPGO_PLUGINS__0__ARGS", "arg1,arg2") + os.Setenv("SFTPGO_PLUGINS__0__SHA256SUM", "0a71ded61fccd59c4f3695b51c1b3d180da8d2d77ea09ccee20dac242675c193") + os.Setenv("SFTPGO_PLUGINS__0__AUTO_MTLS", "1") + os.Setenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__SCHEME", kms.SchemeAWS) + os.Setenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__ENCRYPTED_STATUS", kms.SecretStatusAWS) + os.Setenv("SFTPGO_PLUGINS__0__AUTH_OPTIONS__SCOPE", "14") + os.Setenv("SFTPGO_PLUGINS__0__ENV_PREFIX", "prefix_") + os.Setenv("SFTPGO_PLUGINS__0__ENV_VARS", "a, b") + + t.Cleanup(func() { + os.Unsetenv("SFTPGO_PLUGINS__0__TYPE") + os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__FS_EVENTS") + os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__PROVIDER_EVENTS") + os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__PROVIDER_OBJECTS") + os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__LOG_EVENTS") + os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__RETRY_MAX_TIME") + os.Unsetenv("SFTPGO_PLUGINS__0__NOTIFIER_OPTIONS__RETRY_QUEUE_MAX_SIZE") + os.Unsetenv("SFTPGO_PLUGINS__0__CMD") + os.Unsetenv("SFTPGO_PLUGINS__0__ARGS") + os.Unsetenv("SFTPGO_PLUGINS__0__SHA256SUM") + os.Unsetenv("SFTPGO_PLUGINS__0__AUTO_MTLS") + os.Unsetenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__SCHEME") + os.Unsetenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__ENCRYPTED_STATUS") + os.Unsetenv("SFTPGO_PLUGINS__0__AUTH_OPTIONS__SCOPE") + os.Unsetenv("SFTPGO_PLUGINS__0__ENV_PREFIX") + os.Unsetenv("SFTPGO_PLUGINS__0__ENV_VARS") + }) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + pluginsConf := config.GetPluginsConfig() + require.Len(t, pluginsConf, 1) + pluginConf := pluginsConf[0] + require.Equal(t, "notifier", pluginConf.Type) + require.Len(t, pluginConf.NotifierOptions.FsEvents, 2) + require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "upload")) + require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "download")) + require.Len(t, 
pluginConf.NotifierOptions.ProviderEvents, 2) + require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0]) + require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1]) + require.Len(t, pluginConf.NotifierOptions.ProviderObjects, 2) + require.Equal(t, "user", pluginConf.NotifierOptions.ProviderObjects[0]) + require.Equal(t, "admin", pluginConf.NotifierOptions.ProviderObjects[1]) + require.Len(t, pluginConf.NotifierOptions.LogEvents, 2) + require.Equal(t, 1, pluginConf.NotifierOptions.LogEvents[0]) + require.Equal(t, 2, pluginConf.NotifierOptions.LogEvents[1]) + require.Equal(t, 2, pluginConf.NotifierOptions.RetryMaxTime) + require.Equal(t, 1000, pluginConf.NotifierOptions.RetryQueueMaxSize) + require.Equal(t, "plugin_start_cmd", pluginConf.Cmd) + require.Len(t, pluginConf.Args, 2) + require.Equal(t, "arg1", pluginConf.Args[0]) + require.Equal(t, "arg2", pluginConf.Args[1]) + require.Equal(t, "0a71ded61fccd59c4f3695b51c1b3d180da8d2d77ea09ccee20dac242675c193", pluginConf.SHA256Sum) + require.True(t, pluginConf.AutoMTLS) + require.Equal(t, kms.SchemeAWS, pluginConf.KMSOptions.Scheme) + require.Equal(t, kms.SecretStatusAWS, pluginConf.KMSOptions.EncryptedStatus) + require.Equal(t, 14, pluginConf.AuthOptions.Scope) + require.Equal(t, "prefix_", pluginConf.EnvPrefix) + require.Len(t, pluginConf.EnvVars, 2) + assert.Equal(t, "a", pluginConf.EnvVars[0]) + assert.Equal(t, "b", pluginConf.EnvVars[1]) + + cfg := make(map[string]any) + cfg["plugins"] = pluginConf + configAsJSON, err := json.Marshal(cfg) + require.NoError(t, err) + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err = os.WriteFile(configFilePath, configAsJSON, os.ModePerm) + assert.NoError(t, err) + + os.Setenv("SFTPGO_PLUGINS__0__CMD", "plugin_start_cmd1") + os.Setenv("SFTPGO_PLUGINS__0__ARGS", "") + os.Setenv("SFTPGO_PLUGINS__0__AUTO_MTLS", "0") + os.Setenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__SCHEME", kms.SchemeVaultTransit) + 
os.Setenv("SFTPGO_PLUGINS__0__KMS_OPTIONS__ENCRYPTED_STATUS", kms.SecretStatusVaultTransit) + os.Setenv("SFTPGO_PLUGINS__0__ENV_PREFIX", "") + os.Setenv("SFTPGO_PLUGINS__0__ENV_VARS", "") + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + pluginsConf = config.GetPluginsConfig() + require.Len(t, pluginsConf, 1) + pluginConf = pluginsConf[0] + require.Equal(t, "notifier", pluginConf.Type) + require.Len(t, pluginConf.NotifierOptions.FsEvents, 2) + require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "upload")) + require.True(t, slices.Contains(pluginConf.NotifierOptions.FsEvents, "download")) + require.Len(t, pluginConf.NotifierOptions.ProviderEvents, 2) + require.Equal(t, "add", pluginConf.NotifierOptions.ProviderEvents[0]) + require.Equal(t, "update", pluginConf.NotifierOptions.ProviderEvents[1]) + require.Len(t, pluginConf.NotifierOptions.ProviderObjects, 2) + require.Equal(t, "user", pluginConf.NotifierOptions.ProviderObjects[0]) + require.Equal(t, "admin", pluginConf.NotifierOptions.ProviderObjects[1]) + require.Equal(t, 2, pluginConf.NotifierOptions.RetryMaxTime) + require.Equal(t, 1000, pluginConf.NotifierOptions.RetryQueueMaxSize) + require.Equal(t, "plugin_start_cmd1", pluginConf.Cmd) + require.Len(t, pluginConf.Args, 0) + require.Equal(t, "0a71ded61fccd59c4f3695b51c1b3d180da8d2d77ea09ccee20dac242675c193", pluginConf.SHA256Sum) + require.False(t, pluginConf.AutoMTLS) + require.Equal(t, kms.SchemeVaultTransit, pluginConf.KMSOptions.Scheme) + require.Equal(t, kms.SecretStatusVaultTransit, pluginConf.KMSOptions.EncryptedStatus) + require.Equal(t, 14, pluginConf.AuthOptions.Scope) + assert.Empty(t, pluginConf.EnvPrefix) + assert.Len(t, pluginConf.EnvVars, 0) + + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestRateLimitersFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__AVERAGE", "100") + os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__PERIOD", "2000") + 
os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__BURST", "10") + os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__TYPE", "2") + os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__PROTOCOLS", "SSH, FTP") + os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__GENERATE_DEFENDER_EVENTS", "1") + os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_SOFT_LIMIT", "50") + os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_HARD_LIMIT", "100") + os.Setenv("SFTPGO_COMMON__RATE_LIMITERS__8__AVERAGE", "50") + t.Cleanup(func() { + os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__AVERAGE") + os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__PERIOD") + os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__BURST") + os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__TYPE") + os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__PROTOCOLS") + os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__GENERATE_DEFENDER_EVENTS") + os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_SOFT_LIMIT") + os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__0__ENTRIES_HARD_LIMIT") + os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__8__AVERAGE") + }) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + limiters := config.GetCommonConfig().RateLimitersConfig + require.Len(t, limiters, 2) + require.Equal(t, int64(100), limiters[0].Average) + require.Equal(t, int64(2000), limiters[0].Period) + require.Equal(t, 10, limiters[0].Burst) + require.Equal(t, 2, limiters[0].Type) + protocols := limiters[0].Protocols + require.Len(t, protocols, 2) + require.True(t, slices.Contains(protocols, common.ProtocolFTP)) + require.True(t, slices.Contains(protocols, common.ProtocolSSH)) + require.True(t, limiters[0].GenerateDefenderEvents) + require.Equal(t, 50, limiters[0].EntriesSoftLimit) + require.Equal(t, 100, limiters[0].EntriesHardLimit) + require.Equal(t, int64(50), limiters[1].Average) + // we check the default values here + require.Equal(t, int64(1000), limiters[1].Period) + require.Equal(t, 1, limiters[1].Burst) + require.Equal(t, 2, limiters[1].Type) + protocols = 
limiters[1].Protocols + require.Len(t, protocols, 4) + require.True(t, slices.Contains(protocols, common.ProtocolFTP)) + require.True(t, slices.Contains(protocols, common.ProtocolSSH)) + require.True(t, slices.Contains(protocols, common.ProtocolWebDAV)) + require.True(t, slices.Contains(protocols, common.ProtocolHTTP)) + require.False(t, limiters[1].GenerateDefenderEvents) + require.Equal(t, 100, limiters[1].EntriesSoftLimit) + require.Equal(t, 150, limiters[1].EntriesHardLimit) +} + +func TestSFTPDBindingsFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_SFTPD__BINDINGS__0__ADDRESS", "127.0.0.1") + os.Setenv("SFTPGO_SFTPD__BINDINGS__0__PORT", "2200") + os.Setenv("SFTPGO_SFTPD__BINDINGS__0__APPLY_PROXY_CONFIG", "false") + os.Setenv("SFTPGO_SFTPD__BINDINGS__3__ADDRESS", "127.0.1.1") + os.Setenv("SFTPGO_SFTPD__BINDINGS__3__PORT", "2203") + t.Cleanup(func() { + os.Unsetenv("SFTPGO_SFTPD__BINDINGS__0__ADDRESS") + os.Unsetenv("SFTPGO_SFTPD__BINDINGS__0__PORT") + os.Unsetenv("SFTPGO_SFTPD__BINDINGS__0__APPLY_PROXY_CONFIG") + os.Unsetenv("SFTPGO_SFTPD__BINDINGS__3__ADDRESS") + os.Unsetenv("SFTPGO_SFTPD__BINDINGS__3__PORT") + }) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + bindings := config.GetSFTPDConfig().Bindings + require.Len(t, bindings, 2) + require.Equal(t, 2200, bindings[0].Port) + require.Equal(t, "127.0.0.1", bindings[0].Address) + require.False(t, bindings[0].ApplyProxyConfig) + require.Equal(t, 2203, bindings[1].Port) + require.Equal(t, "127.0.1.1", bindings[1].Address) + require.True(t, bindings[1].ApplyProxyConfig) // default value +} + +func TestCommandsFromEnv(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + commandConfig := config.GetCommandConfig() + commandConfig.Commands = append(commandConfig.Commands, command.Command{ + Path: "cmd", + Timeout: 10, + Env: []string{"a=a"}, + }) + c 
:= make(map[string]command.Config) + c["command"] = commandConfig + jsonConf, err := json.Marshal(c) + require.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + require.NoError(t, err) + err = config.LoadConfig(configDir, confName) + require.NoError(t, err) + commandConfig = config.GetCommandConfig() + require.Equal(t, 30, commandConfig.Timeout) + require.Len(t, commandConfig.Env, 0) + require.Len(t, commandConfig.Commands, 1) + require.Equal(t, "cmd", commandConfig.Commands[0].Path) + require.Equal(t, 10, commandConfig.Commands[0].Timeout) + require.Equal(t, []string{"a=a"}, commandConfig.Commands[0].Env) + + os.Setenv("SFTPGO_COMMAND__TIMEOUT", "25") + os.Setenv("SFTPGO_COMMAND__ENV", "a=b,c=d") + os.Setenv("SFTPGO_COMMAND__COMMANDS__0__PATH", "cmd1") + os.Setenv("SFTPGO_COMMAND__COMMANDS__0__TIMEOUT", "11") + os.Setenv("SFTPGO_COMMAND__COMMANDS__1__PATH", "cmd2") + os.Setenv("SFTPGO_COMMAND__COMMANDS__1__TIMEOUT", "20") + os.Setenv("SFTPGO_COMMAND__COMMANDS__1__ENV", "e=f") + os.Setenv("SFTPGO_COMMAND__COMMANDS__1__ARGS", "arg1, arg2") + + t.Cleanup(func() { + os.Unsetenv("SFTPGO_COMMAND__TIMEOUT") + os.Unsetenv("SFTPGO_COMMAND__ENV") + os.Unsetenv("SFTPGO_COMMAND__COMMANDS__0__PATH") + os.Unsetenv("SFTPGO_COMMAND__COMMANDS__0__TIMEOUT") + os.Unsetenv("SFTPGO_COMMAND__COMMANDS__1__PATH") + os.Unsetenv("SFTPGO_COMMAND__COMMANDS__1__TIMEOUT") + os.Unsetenv("SFTPGO_COMMAND__COMMANDS__1__ENV") + os.Unsetenv("SFTPGO_COMMAND__COMMANDS__1__ARGS") + }) + + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + commandConfig = config.GetCommandConfig() + require.Equal(t, 25, commandConfig.Timeout) + require.Equal(t, []string{"a=b", "c=d"}, commandConfig.Env) + require.Len(t, commandConfig.Commands, 2) + require.Equal(t, "cmd1", commandConfig.Commands[0].Path) + require.Equal(t, 11, commandConfig.Commands[0].Timeout) + require.Equal(t, []string{"a=a"}, commandConfig.Commands[0].Env) + require.Equal(t, "cmd2", 
commandConfig.Commands[1].Path) + require.Equal(t, 20, commandConfig.Commands[1].Timeout) + require.Equal(t, []string{"e=f"}, commandConfig.Commands[1].Env) + require.Equal(t, []string{"arg1", "arg2"}, commandConfig.Commands[1].Args) + + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestFTPDBindingsFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_FTPD__BINDINGS__0__ADDRESS", "127.0.0.1") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__PORT", "2200") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__APPLY_PROXY_CONFIG", "f") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__TLS_MODE", "2") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__FORCE_PASSIVE_IP", "127.0.1.2") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__IP", "172.16.1.1") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_HOST", "127.0.1.3") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__TLS_CIPHER_SUITES", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") + os.Setenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_CONNECTIONS_SECURITY", "1") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__ADDRESS", "127.0.1.1") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__PORT", "2203") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__TLS_MODE", "1") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__MIN_TLS_VERSION", "13") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__FORCE_PASSIVE_IP", "127.0.1.1") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__PASSIVE_IP_OVERRIDES__3__IP", "192.168.1.1") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__PASSIVE_IP_OVERRIDES__3__NETWORKS", "192.168.1.0/24, 192.168.3.0/25") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__CLIENT_AUTH_TYPE", "2") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__DEBUG", "1") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__ACTIVE_CONNECTIONS_SECURITY", "1") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__IGNORE_ASCII_TRANSFER_TYPE", "1") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__CERTIFICATE_FILE", "cert.crt") + os.Setenv("SFTPGO_FTPD__BINDINGS__9__CERTIFICATE_KEY_FILE", "cert.key") + + t.Cleanup(func() { + 
os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__ADDRESS") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PORT") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__APPLY_PROXY_CONFIG") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__TLS_MODE") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__FORCE_PASSIVE_IP") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_IP_OVERRIDES__0__IP") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__PASSIVE_HOST") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__TLS_CIPHER_SUITES") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__0__ACTIVE_CONNECTIONS_SECURITY") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__ADDRESS") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__PORT") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__TLS_MODE") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__MIN_TLS_VERSION") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__FORCE_PASSIVE_IP") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__PASSIVE_IP_OVERRIDES__3__IP") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__PASSIVE_IP_OVERRIDES__3__NETWORKS") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__CLIENT_AUTH_TYPE") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__DEBUG") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__ACTIVE_CONNECTIONS_SECURITY") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__IGNORE_ASCII_TRANSFER_TYPE") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__CERTIFICATE_FILE") + os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__CERTIFICATE_KEY_FILE") + }) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + bindings := config.GetFTPDConfig().Bindings + require.Len(t, bindings, 2) + require.Equal(t, 2200, bindings[0].Port) + require.Equal(t, "127.0.0.1", bindings[0].Address) + require.False(t, bindings[0].ApplyProxyConfig) + require.Equal(t, 2, bindings[0].TLSMode) + require.Equal(t, 12, bindings[0].MinTLSVersion) + require.Equal(t, "127.0.1.2", bindings[0].ForcePassiveIP) + require.Len(t, bindings[0].PassiveIPOverrides, 0) + require.Equal(t, "127.0.1.3", bindings[0].PassiveHost) + require.Equal(t, 0, bindings[0].ClientAuthType) + require.Len(t, bindings[0].TLSCipherSuites, 2) + require.Equal(t, 
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", bindings[0].TLSCipherSuites[0]) + require.Equal(t, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", bindings[0].TLSCipherSuites[1]) + require.False(t, bindings[0].Debug) + require.Equal(t, 1, bindings[0].PassiveConnectionsSecurity) + require.Equal(t, 0, bindings[0].ActiveConnectionsSecurity) + require.Equal(t, 2203, bindings[1].Port) + require.Equal(t, "127.0.1.1", bindings[1].Address) + require.True(t, bindings[1].ApplyProxyConfig) // default value + require.Equal(t, 1, bindings[1].TLSMode) + require.Equal(t, 13, bindings[1].MinTLSVersion) + require.Equal(t, "127.0.1.1", bindings[1].ForcePassiveIP) + require.Empty(t, bindings[1].PassiveHost) + require.Len(t, bindings[1].PassiveIPOverrides, 1) + require.Equal(t, "192.168.1.1", bindings[1].PassiveIPOverrides[0].IP) + require.Len(t, bindings[1].PassiveIPOverrides[0].Networks, 2) + require.Equal(t, "192.168.1.0/24", bindings[1].PassiveIPOverrides[0].Networks[0]) + require.Equal(t, "192.168.3.0/25", bindings[1].PassiveIPOverrides[0].Networks[1]) + require.Equal(t, 2, bindings[1].ClientAuthType) + require.Nil(t, bindings[1].TLSCipherSuites) + require.Equal(t, 0, bindings[1].PassiveConnectionsSecurity) + require.Equal(t, 1, bindings[1].ActiveConnectionsSecurity) + require.True(t, bindings[1].Debug) + require.Equal(t, "cert.crt", bindings[1].CertificateFile) + require.Equal(t, "cert.key", bindings[1].CertificateKeyFile) +} + +func TestWebDAVMimeCache(t *testing.T) { + reset() + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + webdavdConf := config.GetWebDAVDConfig() + webdavdConf.Cache.MimeTypes.CustomMappings = []webdavd.CustomMimeMapping{ + { + Ext: ".custom", + Mime: "application/custom", + }, + } + cfg := map[string]any{ + "webdavd": webdavdConf, + } + data, err := json.Marshal(cfg) + assert.NoError(t, err) + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err = os.WriteFile(configFilePath, data, 0666) + 
assert.NoError(t, err) + + reset() + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + mappings := config.GetWebDAVDConfig().Cache.MimeTypes.CustomMappings + if assert.Len(t, mappings, 1) { + assert.Equal(t, ".custom", mappings[0].Ext) + assert.Equal(t, "application/custom", mappings[0].Mime) + } + // now add from env + os.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__1__EXT", ".custom1") + os.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__1__MIME", "application/custom1") + t.Cleanup(func() { + os.Unsetenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__EXT") + os.Unsetenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__MIME") + os.Unsetenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__1__EXT") + os.Unsetenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__1__MIME") + }) + reset() + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + mappings = config.GetWebDAVDConfig().Cache.MimeTypes.CustomMappings + if assert.Len(t, mappings, 2) { + assert.Equal(t, ".custom", mappings[0].Ext) + assert.Equal(t, "application/custom", mappings[0].Mime) + assert.Equal(t, ".custom1", mappings[1].Ext) + assert.Equal(t, "application/custom1", mappings[1].Mime) + } + // override from env + os.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__EXT", ".custom0") + os.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__MIME", "application/custom0") + reset() + err = config.LoadConfig(configDir, confName) + assert.NoError(t, err) + mappings = config.GetWebDAVDConfig().Cache.MimeTypes.CustomMappings + if assert.Len(t, mappings, 2) { + assert.Equal(t, ".custom0", mappings[0].Ext) + assert.Equal(t, "application/custom0", mappings[0].Mime) + assert.Equal(t, ".custom1", mappings[1].Ext) + assert.Equal(t, "application/custom1", mappings[1].Mime) + } + err = os.Remove(configFilePath) + assert.NoError(t, err) +} + +func TestWebDAVBindingsFromEnv(t *testing.T) { + reset() + + 
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__ADDRESS", "127.0.0.1") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT", "8000") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS", "0") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES", "TLS_RSA_WITH_AES_128_CBC_SHA ") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_PROTOCOLS", "http/1.1 ") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_MODE", "1") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED", "192.168.10.1") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_PROXY_HEADER", "X-Forwarded-For") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_HEADER_DEPTH", "2") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS", "127.0.1.1") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT", "9000") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS", "1") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__MIN_TLS_VERSION", "13") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CLIENT_AUTH_TYPE", "1") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PREFIX", "/dav2") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_FILE", "webdav.crt") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_KEY_FILE", "webdav.key") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__DISABLE_WWW_AUTH_HEADER", "1") + + t.Cleanup(func() { + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ADDRESS") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ENABLE_HTTPS") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_CIPHER_SUITES") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__TLS_PROTOCOLS") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_MODE") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PROXY_ALLOWED") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_PROXY_HEADER") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__CLIENT_IP_HEADER_DEPTH") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ADDRESS") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__PORT") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__ENABLE_HTTPS") + 
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__MIN_TLS_VERSION") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__CLIENT_AUTH_TYPE") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__PREFIX") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_FILE") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_KEY_FILE") + os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__DISABLE_WWW_AUTH_HEADER") + }) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + bindings := config.GetWebDAVDConfig().Bindings + require.Len(t, bindings, 3) + require.Equal(t, 0, bindings[0].Port) + require.Empty(t, bindings[0].Address) + require.False(t, bindings[0].EnableHTTPS) + require.Equal(t, 12, bindings[0].MinTLSVersion) + require.Len(t, bindings[0].TLSCipherSuites, 0) + require.Len(t, bindings[0].Protocols, 0) + require.Equal(t, 0, bindings[0].ProxyMode) + require.Empty(t, bindings[0].Prefix) + require.Equal(t, 0, bindings[0].ClientIPHeaderDepth) + require.False(t, bindings[0].DisableWWWAuthHeader) + require.Equal(t, 8000, bindings[1].Port) + require.Equal(t, "127.0.0.1", bindings[1].Address) + require.False(t, bindings[1].EnableHTTPS) + require.Equal(t, 12, bindings[1].MinTLSVersion) + require.Equal(t, 0, bindings[1].ClientAuthType) + require.Len(t, bindings[1].TLSCipherSuites, 1) + require.Equal(t, "TLS_RSA_WITH_AES_128_CBC_SHA", bindings[1].TLSCipherSuites[0]) + require.Len(t, bindings[1].Protocols, 1) + assert.Equal(t, "http/1.1", bindings[1].Protocols[0]) + require.Equal(t, 1, bindings[1].ProxyMode) + require.Equal(t, "192.168.10.1", bindings[1].ProxyAllowed[0]) + require.Equal(t, "X-Forwarded-For", bindings[1].ClientIPProxyHeader) + require.Equal(t, 2, bindings[1].ClientIPHeaderDepth) + require.Empty(t, bindings[1].Prefix) + require.False(t, bindings[1].DisableWWWAuthHeader) + require.Equal(t, 9000, bindings[2].Port) + require.Equal(t, "127.0.1.1", bindings[2].Address) + require.True(t, bindings[2].EnableHTTPS) + require.Equal(t, 13, bindings[2].MinTLSVersion) + require.Equal(t, 1, 
bindings[2].ClientAuthType) + require.Equal(t, 0, bindings[2].ProxyMode) + require.Nil(t, bindings[2].TLSCipherSuites) + require.Equal(t, "/dav2", bindings[2].Prefix) + require.Equal(t, "webdav.crt", bindings[2].CertificateFile) + require.Equal(t, "webdav.key", bindings[2].CertificateKeyFile) + require.Equal(t, 0, bindings[2].ClientIPHeaderDepth) + require.True(t, bindings[2].DisableWWWAuthHeader) +} + +func TestHTTPDBindingsFromEnv(t *testing.T) { + reset() + + sockPath := filepath.Clean(os.TempDir()) + + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__ADDRESS", sockPath) + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__PORT", "0") + os.Setenv("SFTPGO_HTTPD__BINDINGS__0__TLS_CIPHER_SUITES", " TLS_AES_128_GCM_SHA256") + os.Setenv("SFTPGO_HTTPD__BINDINGS__1__ADDRESS", "127.0.0.1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__1__PORT", "8000") + os.Setenv("SFTPGO_HTTPD__BINDINGS__1__ENABLE_HTTPS", "0") + os.Setenv("SFTPGO_HTTPD__BINDINGS__1__HIDE_LOGIN_URL", " 1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__1__BRANDING__WEB_ADMIN__NAME", "Web Admin") + os.Setenv("SFTPGO_HTTPD__BINDINGS__1__BRANDING__WEB_CLIENT__SHORT_NAME", "WebClient") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ADDRESS", "127.0.1.1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PORT", "9000") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_ADMIN", "0") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT", "0") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_REST_API", "0") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS", "3") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__DISABLED_LOGIN_METHODS", "12") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI", "0") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__BASE_URL", "https://example.com") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__LANGUAGES", "en,es") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1 ") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__MIN_TLS_VERSION", "13") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE", "1") + 
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES", " TLS_AES_256_GCM_SHA384 , TLS_CHACHA20_POLY1305_SHA256") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__TLS_PROTOCOLS", "h2, http/1.1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_MODE", "1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED", " 192.168.9.1 , 172.16.25.0/24") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_PROXY_HEADER", "X-Real-IP") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_HEADER_DEPTH", "2") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__HIDE_LOGIN_URL", "3") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CLIENT_ID", "client id") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CLIENT_SECRET", "client secret") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CONFIG_URL", "config url") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__REDIRECT_BASE_URL", "redirect base url") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__USERNAME_FIELD", "preferred_username") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__ROLE_FIELD", "sftpgo_role") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__SCOPES", "openid") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__IMPLICIT_ROLES", "1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CUSTOM_FIELDS", "field1,field2") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__INSECURE_SKIP_SIGNATURE_CHECK", "1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__DEBUG", "1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ENABLED", "true") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS", "*.example.com,*.example.net") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS_ARE_REGEX", "1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HOSTS_PROXY_HEADERS", "X-Forwarded-Host") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_REDIRECT", "1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_HOST", "www.example.com") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_PROXY_HEADERS__1__KEY", "X-Forwarded-Proto") + 
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_PROXY_HEADERS__1__VALUE", "https") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__STS_SECONDS", "31536000") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__STS_INCLUDE_SUBDOMAINS", "false") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__STS_PRELOAD", "0") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CONTENT_TYPE_NOSNIFF", "t") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CONTENT_SECURITY_POLICY", "script-src $NONCE") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__PERMISSIONS_POLICY", "fullscreen=(), geolocation=()") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CROSS_ORIGIN_OPENER_POLICY", "same-origin") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CROSS_ORIGIN_RESOURCE_POLICY", "same-site") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CROSS_ORIGIN_EMBEDDER_POLICY", "require-corp") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CACHE_CONTROL", "private") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__REFERRER_POLICY", "no-referrer") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__EXTRA_CSS__0__PATH", "path1") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__EXTRA_CSS__1__PATH", "path2") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_ADMIN__FAVICON_PATH", "favicon.ico") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__LOGO_PATH", "logo.png") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__DISCLAIMER_NAME", "disclaimer") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_ADMIN__DISCLAIMER_PATH", "disclaimer.html") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__DEFAULT_CSS", "default.css") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__EXTRA_CSS", "1.css,2.css") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CERTIFICATE_FILE", "httpd.crt") + os.Setenv("SFTPGO_HTTPD__BINDINGS__2__CERTIFICATE_KEY_FILE", "httpd.key") + + t.Cleanup(func() { + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__ADDRESS") + 
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__PORT") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__0__TLS_CIPHER_SUITES") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__1__ADDRESS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__1__PORT") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__1__ENABLE_HTTPS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__1__HIDE_LOGIN_URL") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__1__BRANDING__WEB_ADMIN__NAME") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__1__BRANDING__WEB_CLIENT__SHORT_NAME") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__1__EXTRA_CSS__0__PATH") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ADDRESS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__PORT") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__MIN_TLS_VERSION") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_ADMIN") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_REST_API") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__DISABLED_LOGIN_METHODS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__BASE_URL") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__LANGUAGES") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__TLS_CIPHER_SUITES") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__TLS_PROTOCOLS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_MODE") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__PROXY_ALLOWED") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_PROXY_HEADER") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_IP_HEADER_DEPTH") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__HIDE_LOGIN_URL") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CLIENT_ID") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CLIENT_SECRET") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CONFIG_URL") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__REDIRECT_BASE_URL") + 
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__USERNAME_FIELD") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__ROLE_FIELD") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__SCOPES") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__IMPLICIT_ROLES") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CUSTOM_FIELDS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__INSECURE_SKIP_SIGNATURE_CHECK") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__DEBUG") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ENABLED") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS_ARE_REGEX") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HOSTS_PROXY_HEADERS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_REDIRECT") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_HOST") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_PROXY_HEADERS__1__KEY") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__HTTPS_PROXY_HEADERS__1__VALUE") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__STS_SECONDS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__STS_INCLUDE_SUBDOMAINS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__STS_PRELOAD") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CONTENT_TYPE_NOSNIFF") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CONTENT_SECURITY_POLICY") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__PERMISSIONS_POLICY") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CROSS_ORIGIN_OPENER_POLICY") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CROSS_ORIGIN_RESOURCE_POLICY") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CROSS_ORIGIN_EMBEDDER_POLICY") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__CACHE_CONTROL") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__REFERRER_POLICY") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__EXTRA_CSS__0__PATH") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__EXTRA_CSS__1__PATH") + 
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_ADMIN__FAVICON_PATH") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__LOGO_PATH") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__DISCLAIMER_NAME") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_ADMIN__DISCLAIMER_PATH") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__DEFAULT_CSS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__BRANDING__WEB_CLIENT__EXTRA_CSS") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CERTIFICATE_FILE") + os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CERTIFICATE_KEY_FILE") + }) + + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + bindings := config.GetHTTPDConfig().Bindings + require.Len(t, bindings, 3) + require.Equal(t, 0, bindings[0].Port) + require.Equal(t, sockPath, bindings[0].Address) + require.False(t, bindings[0].EnableHTTPS) + require.Len(t, bindings[0].Protocols, 0) + require.Equal(t, 12, bindings[0].MinTLSVersion) + require.True(t, bindings[0].EnableWebAdmin) + require.True(t, bindings[0].EnableWebClient) + require.True(t, bindings[0].EnableRESTAPI) + require.Equal(t, 0, bindings[0].EnabledLoginMethods) + require.Equal(t, 0, bindings[0].DisabledLoginMethods) + require.True(t, bindings[0].RenderOpenAPI) + require.Empty(t, bindings[0].BaseURL) + require.Len(t, bindings[0].Languages, 1) + assert.Contains(t, bindings[0].Languages, "en") + require.Len(t, bindings[0].TLSCipherSuites, 1) + require.Equal(t, 0, bindings[0].ProxyMode) + require.Empty(t, bindings[0].OIDC.ConfigURL) + require.Equal(t, "TLS_AES_128_GCM_SHA256", bindings[0].TLSCipherSuites[0]) + require.Equal(t, 0, bindings[0].HideLoginURL) + require.False(t, bindings[0].Security.Enabled) + require.Equal(t, 0, bindings[0].ClientIPHeaderDepth) + require.Len(t, bindings[0].OIDC.Scopes, 3) + require.False(t, bindings[0].OIDC.InsecureSkipSignatureCheck) + require.False(t, bindings[0].OIDC.Debug) + require.Empty(t, bindings[0].Security.ReferrerPolicy) + require.Equal(t, 
8000, bindings[1].Port) + require.Equal(t, "127.0.0.1", bindings[1].Address) + require.False(t, bindings[1].EnableHTTPS) + require.Equal(t, 12, bindings[0].MinTLSVersion) + require.True(t, bindings[1].EnableWebAdmin) + require.True(t, bindings[1].EnableWebClient) + require.True(t, bindings[1].EnableRESTAPI) + require.Equal(t, 0, bindings[1].EnabledLoginMethods) + require.Equal(t, 0, bindings[1].DisabledLoginMethods) + require.True(t, bindings[1].RenderOpenAPI) + require.Empty(t, bindings[1].BaseURL) + require.Len(t, bindings[1].Languages, 1) + assert.Contains(t, bindings[1].Languages, "en") + require.Nil(t, bindings[1].TLSCipherSuites) + require.Equal(t, 1, bindings[1].HideLoginURL) + require.Empty(t, bindings[1].OIDC.ClientID) + require.Len(t, bindings[1].OIDC.Scopes, 3) + require.False(t, bindings[1].OIDC.InsecureSkipSignatureCheck) + require.False(t, bindings[1].OIDC.Debug) + require.False(t, bindings[1].Security.Enabled) + require.Equal(t, "Web Admin", bindings[1].Branding.WebAdmin.Name) + require.Equal(t, "WebClient", bindings[1].Branding.WebClient.ShortName) + require.Equal(t, 0, bindings[1].ProxyMode) + require.Equal(t, 0, bindings[1].ClientIPHeaderDepth) + require.Equal(t, 9000, bindings[2].Port) + require.Equal(t, "127.0.1.1", bindings[2].Address) + require.True(t, bindings[2].EnableHTTPS) + require.Equal(t, 13, bindings[2].MinTLSVersion) + require.False(t, bindings[2].EnableWebAdmin) + require.False(t, bindings[2].EnableWebClient) + require.False(t, bindings[2].EnableRESTAPI) + require.Equal(t, 3, bindings[2].EnabledLoginMethods) + require.Equal(t, 12, bindings[2].DisabledLoginMethods) + require.False(t, bindings[2].RenderOpenAPI) + require.Equal(t, "https://example.com", bindings[2].BaseURL) + require.Len(t, bindings[2].Languages, 2) + assert.Contains(t, bindings[2].Languages, "en") + assert.Contains(t, bindings[2].Languages, "es") + require.Equal(t, 1, bindings[2].ClientAuthType) + require.Len(t, bindings[2].TLSCipherSuites, 2) + require.Equal(t, 
"TLS_AES_256_GCM_SHA384", bindings[2].TLSCipherSuites[0]) + require.Equal(t, "TLS_CHACHA20_POLY1305_SHA256", bindings[2].TLSCipherSuites[1]) + require.Len(t, bindings[2].Protocols, 2) + require.Equal(t, "h2", bindings[2].Protocols[0]) + require.Equal(t, "http/1.1", bindings[2].Protocols[1]) + require.Equal(t, 1, bindings[2].ProxyMode) + require.Len(t, bindings[2].ProxyAllowed, 2) + require.Equal(t, "192.168.9.1", bindings[2].ProxyAllowed[0]) + require.Equal(t, "172.16.25.0/24", bindings[2].ProxyAllowed[1]) + require.Equal(t, "X-Real-IP", bindings[2].ClientIPProxyHeader) + require.Equal(t, 2, bindings[2].ClientIPHeaderDepth) + require.Equal(t, 3, bindings[2].HideLoginURL) + require.Equal(t, "client id", bindings[2].OIDC.ClientID) + require.Equal(t, "client secret", bindings[2].OIDC.ClientSecret) + require.Equal(t, "config url", bindings[2].OIDC.ConfigURL) + require.Equal(t, "redirect base url", bindings[2].OIDC.RedirectBaseURL) + require.Equal(t, "preferred_username", bindings[2].OIDC.UsernameField) + require.Equal(t, "sftpgo_role", bindings[2].OIDC.RoleField) + require.Len(t, bindings[2].OIDC.Scopes, 1) + require.Equal(t, "openid", bindings[2].OIDC.Scopes[0]) + require.True(t, bindings[2].OIDC.ImplicitRoles) + require.Len(t, bindings[2].OIDC.CustomFields, 2) + require.Equal(t, "field1", bindings[2].OIDC.CustomFields[0]) + require.Equal(t, "field2", bindings[2].OIDC.CustomFields[1]) + require.True(t, bindings[2].OIDC.InsecureSkipSignatureCheck) + require.True(t, bindings[2].OIDC.Debug) + require.True(t, bindings[2].Security.Enabled) + require.Len(t, bindings[2].Security.AllowedHosts, 2) + require.Equal(t, "*.example.com", bindings[2].Security.AllowedHosts[0]) + require.Equal(t, "*.example.net", bindings[2].Security.AllowedHosts[1]) + require.True(t, bindings[2].Security.AllowedHostsAreRegex) + require.Len(t, bindings[2].Security.HostsProxyHeaders, 1) + require.Equal(t, "X-Forwarded-Host", bindings[2].Security.HostsProxyHeaders[0]) + require.True(t, 
bindings[2].Security.HTTPSRedirect) + require.Equal(t, "www.example.com", bindings[2].Security.HTTPSHost) + require.Len(t, bindings[2].Security.HTTPSProxyHeaders, 1) + require.Equal(t, "X-Forwarded-Proto", bindings[2].Security.HTTPSProxyHeaders[0].Key) + require.Equal(t, "https", bindings[2].Security.HTTPSProxyHeaders[0].Value) + require.Equal(t, int64(31536000), bindings[2].Security.STSSeconds) + require.False(t, bindings[2].Security.STSIncludeSubdomains) + require.False(t, bindings[2].Security.STSPreload) + require.True(t, bindings[2].Security.ContentTypeNosniff) + require.Equal(t, "script-src $NONCE", bindings[2].Security.ContentSecurityPolicy) + require.Equal(t, "fullscreen=(), geolocation=()", bindings[2].Security.PermissionsPolicy) + require.Equal(t, "same-origin", bindings[2].Security.CrossOriginOpenerPolicy) + require.Equal(t, "same-site", bindings[2].Security.CrossOriginResourcePolicy) + require.Equal(t, "require-corp", bindings[2].Security.CrossOriginEmbedderPolicy) + require.Equal(t, "private", bindings[2].Security.CacheControl) + require.Equal(t, "no-referrer", bindings[2].Security.ReferrerPolicy) + require.Equal(t, "favicon.ico", bindings[2].Branding.WebAdmin.FaviconPath) + require.Equal(t, "logo.png", bindings[2].Branding.WebClient.LogoPath) + require.Equal(t, "disclaimer", bindings[2].Branding.WebClient.DisclaimerName) + require.Equal(t, "disclaimer.html", bindings[2].Branding.WebAdmin.DisclaimerPath) + require.Equal(t, []string{"default.css"}, bindings[2].Branding.WebClient.DefaultCSS) + require.Len(t, bindings[2].Branding.WebClient.ExtraCSS, 2) + require.Equal(t, "1.css", bindings[2].Branding.WebClient.ExtraCSS[0]) + require.Equal(t, "2.css", bindings[2].Branding.WebClient.ExtraCSS[1]) + require.Equal(t, "httpd.crt", bindings[2].CertificateFile) + require.Equal(t, "httpd.key", bindings[2].CertificateKeyFile) +} + +func TestHTTPClientCertificatesFromEnv(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := 
filepath.Join(configDir, confName) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + httpConf := config.GetHTTPConfig() + httpConf.Certificates = append(httpConf.Certificates, httpclient.TLSKeyPair{ + Cert: "cert", + Key: "key", + }) + c := make(map[string]httpclient.Config) + c["http"] = httpConf + jsonConf, err := json.Marshal(c) + require.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + require.NoError(t, err) + err = config.LoadConfig(configDir, confName) + require.NoError(t, err) + require.Len(t, config.GetHTTPConfig().Certificates, 1) + require.Equal(t, "cert", config.GetHTTPConfig().Certificates[0].Cert) + require.Equal(t, "key", config.GetHTTPConfig().Certificates[0].Key) + + os.Setenv("SFTPGO_HTTP__CERTIFICATES__0__CERT", "cert0") + os.Setenv("SFTPGO_HTTP__CERTIFICATES__0__KEY", "key0") + os.Setenv("SFTPGO_HTTP__CERTIFICATES__8__CERT", "cert8") + os.Setenv("SFTPGO_HTTP__CERTIFICATES__9__CERT", "cert9") + os.Setenv("SFTPGO_HTTP__CERTIFICATES__9__KEY", "key9") + + t.Cleanup(func() { + os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__0__CERT") + os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__0__KEY") + os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__8__CERT") + os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__9__CERT") + os.Unsetenv("SFTPGO_HTTP__CERTIFICATES__9__KEY") + }) + + err = config.LoadConfig(configDir, confName) + require.NoError(t, err) + require.Len(t, config.GetHTTPConfig().Certificates, 2) + require.Equal(t, "cert0", config.GetHTTPConfig().Certificates[0].Cert) + require.Equal(t, "key0", config.GetHTTPConfig().Certificates[0].Key) + require.Equal(t, "cert9", config.GetHTTPConfig().Certificates[1].Cert) + require.Equal(t, "key9", config.GetHTTPConfig().Certificates[1].Key) + + err = os.Remove(configFilePath) + assert.NoError(t, err) + + config.Init() + + err = config.LoadConfig(configDir, "") + require.NoError(t, err) + require.Len(t, config.GetHTTPConfig().Certificates, 2) + require.Equal(t, "cert0", 
config.GetHTTPConfig().Certificates[0].Cert) + require.Equal(t, "key0", config.GetHTTPConfig().Certificates[0].Key) + require.Equal(t, "cert9", config.GetHTTPConfig().Certificates[1].Cert) + require.Equal(t, "key9", config.GetHTTPConfig().Certificates[1].Key) +} + +func TestHTTPClientHeadersFromEnv(t *testing.T) { + reset() + + confName := tempConfigName + ".json" + configFilePath := filepath.Join(configDir, confName) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + httpConf := config.GetHTTPConfig() + httpConf.Headers = append(httpConf.Headers, httpclient.Header{ + Key: "key", + Value: "value", + URL: "url", + }) + c := make(map[string]httpclient.Config) + c["http"] = httpConf + jsonConf, err := json.Marshal(c) + require.NoError(t, err) + err = os.WriteFile(configFilePath, jsonConf, os.ModePerm) + require.NoError(t, err) + err = config.LoadConfig(configDir, confName) + require.NoError(t, err) + require.Len(t, config.GetHTTPConfig().Headers, 1) + require.Equal(t, "key", config.GetHTTPConfig().Headers[0].Key) + require.Equal(t, "value", config.GetHTTPConfig().Headers[0].Value) + require.Equal(t, "url", config.GetHTTPConfig().Headers[0].URL) + + os.Setenv("SFTPGO_HTTP__HEADERS__0__KEY", "key0") + os.Setenv("SFTPGO_HTTP__HEADERS__0__VALUE", "value0") + os.Setenv("SFTPGO_HTTP__HEADERS__0__URL", "url0") + os.Setenv("SFTPGO_HTTP__HEADERS__8__KEY", "key8") + os.Setenv("SFTPGO_HTTP__HEADERS__9__KEY", "key9") + os.Setenv("SFTPGO_HTTP__HEADERS__9__VALUE", "value9") + os.Setenv("SFTPGO_HTTP__HEADERS__9__URL", "url9") + + t.Cleanup(func() { + os.Unsetenv("SFTPGO_HTTP__HEADERS__0__KEY") + os.Unsetenv("SFTPGO_HTTP__HEADERS__0__VALUE") + os.Unsetenv("SFTPGO_HTTP__HEADERS__0__URL") + os.Unsetenv("SFTPGO_HTTP__HEADERS__8__KEY") + os.Unsetenv("SFTPGO_HTTP__HEADERS__9__KEY") + os.Unsetenv("SFTPGO_HTTP__HEADERS__9__VALUE") + os.Unsetenv("SFTPGO_HTTP__HEADERS__9__URL") + }) + + err = config.LoadConfig(configDir, confName) + require.NoError(t, err) + require.Len(t, 
config.GetHTTPConfig().Headers, 2) + require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key) + require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value) + require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL) + require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key) + require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value) + require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL) + + err = os.Remove(configFilePath) + assert.NoError(t, err) + + config.Init() + + err = config.LoadConfig(configDir, "") + require.NoError(t, err) + require.Len(t, config.GetHTTPConfig().Headers, 2) + require.Equal(t, "key0", config.GetHTTPConfig().Headers[0].Key) + require.Equal(t, "value0", config.GetHTTPConfig().Headers[0].Value) + require.Equal(t, "url0", config.GetHTTPConfig().Headers[0].URL) + require.Equal(t, "key9", config.GetHTTPConfig().Headers[1].Key) + require.Equal(t, "value9", config.GetHTTPConfig().Headers[1].Value) + require.Equal(t, "url9", config.GetHTTPConfig().Headers[1].URL) +} + +func TestConfigFromEnv(t *testing.T) { + reset() + + os.Setenv("SFTPGO_SFTPD__BINDINGS__0__ADDRESS", "127.0.0.1") + os.Setenv("SFTPGO_WEBDAVD__BINDINGS__0__PORT", "12000") + os.Setenv("SFTPGO_DATA_PROVIDER__PASSWORD_HASHING__ARGON2_OPTIONS__ITERATIONS", "41") + os.Setenv("SFTPGO_DATA_PROVIDER__POOL_SIZE", "10") + os.Setenv("SFTPGO_DATA_PROVIDER__IS_SHARED", "1") + os.Setenv("SFTPGO_DATA_PROVIDER__ACTIONS__EXECUTE_ON", "add") + os.Setenv("SFTPGO_KMS__SECRETS__URL", "local") + os.Setenv("SFTPGO_KMS__SECRETS__MASTER_KEY_PATH", "path") + os.Setenv("SFTPGO_TELEMETRY__TLS_CIPHER_SUITES", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA") + os.Setenv("SFTPGO_TELEMETRY__TLS_PROTOCOLS", "h2") + os.Setenv("SFTPGO_HTTPD__SETUP__INSTALLATION_CODE", "123") + os.Setenv("SFTPGO_ACME__HTTP01_CHALLENGE__PORT", "5002") + t.Cleanup(func() { + os.Unsetenv("SFTPGO_SFTPD__BINDINGS__0__ADDRESS") + 
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__0__PORT") + os.Unsetenv("SFTPGO_DATA_PROVIDER__PASSWORD_HASHING__ARGON2_OPTIONS__ITERATIONS") + os.Unsetenv("SFTPGO_DATA_PROVIDER__POOL_SIZE") + os.Unsetenv("SFTPGO_DATA_PROVIDER__IS_SHARED") + os.Unsetenv("SFTPGO_DATA_PROVIDER__ACTIONS__EXECUTE_ON") + os.Unsetenv("SFTPGO_KMS__SECRETS__URL") + os.Unsetenv("SFTPGO_KMS__SECRETS__MASTER_KEY_PATH") + os.Unsetenv("SFTPGO_TELEMETRY__TLS_CIPHER_SUITES") + os.Unsetenv("SFTPGO_TELEMETRY__TLS_PROTOCOLS") + os.Unsetenv("SFTPGO_HTTPD__SETUP__INSTALLATION_CODE") + os.Unsetenv("SFTPGO_ACME__HTTP01_CHALLENGE_PORT") + }) + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + sftpdConfig := config.GetSFTPDConfig() + assert.Equal(t, "127.0.0.1", sftpdConfig.Bindings[0].Address) + assert.Equal(t, 12000, config.GetWebDAVDConfig().Bindings[0].Port) + dataProviderConf := config.GetProviderConf() + assert.Equal(t, uint32(41), dataProviderConf.PasswordHashing.Argon2Options.Iterations) + assert.Equal(t, 10, dataProviderConf.PoolSize) + assert.Equal(t, 1, dataProviderConf.IsShared) + assert.Len(t, dataProviderConf.Actions.ExecuteOn, 1) + assert.Contains(t, dataProviderConf.Actions.ExecuteOn, "add") + kmsConfig := config.GetKMSConfig() + assert.Equal(t, "local", kmsConfig.Secrets.URL) + assert.Equal(t, "path", kmsConfig.Secrets.MasterKeyPath) + telemetryConfig := config.GetTelemetryConfig() + require.Len(t, telemetryConfig.TLSCipherSuites, 2) + assert.Equal(t, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", telemetryConfig.TLSCipherSuites[0]) + assert.Equal(t, "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", telemetryConfig.TLSCipherSuites[1]) + require.Len(t, telemetryConfig.Protocols, 1) + assert.Equal(t, "h2", telemetryConfig.Protocols[0]) + assert.Equal(t, "123", config.GetHTTPDConfig().Setup.InstallationCode) + acmeConfig := config.GetACMEConfig() + assert.Equal(t, 5002, acmeConfig.HTTP01Challenge.Port) +} diff --git a/internal/dataprovider/actions.go b/internal/dataprovider/actions.go new file mode 
100644 index 00000000..b61e5b0a --- /dev/null +++ b/internal/dataprovider/actions.go @@ -0,0 +1,159 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package dataprovider + +import ( + "bytes" + "context" + "fmt" + "net/url" + "os/exec" + "path/filepath" + "slices" + "strings" + "time" + + "github.com/sftpgo/sdk/plugin/notifier" + + "github.com/drakkan/sftpgo/v2/internal/command" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + // ActionExecutorSelf is used as username for self action, for example a user/admin that updates itself + ActionExecutorSelf = "__self__" + // ActionExecutorSystem is used as username for actions with no explicit executor associated, for example + // adding/updating a user/admin by loading initial data + ActionExecutorSystem = "__system__" +) + +const ( + actionObjectUser = "user" + actionObjectFolder = "folder" + actionObjectGroup = "group" + actionObjectAdmin = "admin" + actionObjectAPIKey = "api_key" + actionObjectShare = "share" + actionObjectEventAction = "event_action" + actionObjectEventRule = "event_rule" + actionObjectRole = "role" + actionObjectIPListEntry = "ip_list_entry" + actionObjectConfigs = "configs" +) + +var ( + actionsConcurrencyGuard = make(chan struct{}, 100) + reservedUsers = 
[]string{ActionExecutorSelf, ActionExecutorSystem} +) + +func executeAction(operation, executor, ip, objectType, objectName, role string, object plugin.Renderer) { + if plugin.Handler.HasNotifiers() { + plugin.Handler.NotifyProviderEvent(¬ifier.ProviderEvent{ + Action: operation, + Username: executor, + ObjectType: objectType, + ObjectName: objectName, + IP: ip, + Role: role, + Timestamp: time.Now().UnixNano(), + }, object) + } + if fnHandleRuleForProviderEvent != nil { + fnHandleRuleForProviderEvent(operation, executor, ip, objectType, objectName, role, object) + } + if config.Actions.Hook == "" { + return + } + if !slices.Contains(config.Actions.ExecuteOn, operation) || + !slices.Contains(config.Actions.ExecuteFor, objectType) { + return + } + + go func() { + actionsConcurrencyGuard <- struct{}{} + defer func() { + <-actionsConcurrencyGuard + }() + + dataAsJSON, err := object.RenderAsJSON(operation != operationDelete) + if err != nil { + providerLog(logger.LevelError, "unable to serialize user as JSON for operation %q: %v", operation, err) + return + } + if strings.HasPrefix(config.Actions.Hook, "http") { + var url *url.URL + url, err := url.Parse(config.Actions.Hook) + if err != nil { + providerLog(logger.LevelError, "Invalid http_notification_url %q for operation %q: %v", + config.Actions.Hook, operation, err) + return + } + q := url.Query() + q.Add("action", operation) + q.Add("username", executor) + q.Add("ip", ip) + q.Add("object_type", objectType) + q.Add("object_name", objectName) + if role != "" { + q.Add("role", role) + } + q.Add("timestamp", fmt.Sprintf("%d", time.Now().UnixNano())) + url.RawQuery = q.Encode() + startTime := time.Now() + resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(dataAsJSON)) + respCode := 0 + if err == nil { + respCode = resp.StatusCode + resp.Body.Close() + } + providerLog(logger.LevelDebug, "notified operation %q to URL: %s status code: %d, elapsed: %s err: %v", + operation, 
url.Redacted(), respCode, time.Since(startTime), err) + return + } + executeNotificationCommand(operation, executor, ip, objectType, objectName, role, dataAsJSON) //nolint:errcheck // the error is used in test cases only + }() +} + +func executeNotificationCommand(operation, executor, ip, objectType, objectName, role string, objectAsJSON []byte) error { + if !filepath.IsAbs(config.Actions.Hook) { + err := fmt.Errorf("invalid notification command %q", config.Actions.Hook) + logger.Warn(logSender, "", "unable to execute notification command: %v", err) + return err + } + + timeout, env, args := command.GetConfig(config.Actions.Hook, command.HookProviderActions) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, config.Actions.Hook, args...) + cmd.Env = append(env, + fmt.Sprintf("SFTPGO_PROVIDER_ACTION=%s", operation), + fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_TYPE=%s", objectType), + fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_NAME=%s", objectName), + fmt.Sprintf("SFTPGO_PROVIDER_USERNAME=%s", executor), + fmt.Sprintf("SFTPGO_PROVIDER_IP=%s", ip), + fmt.Sprintf("SFTPGO_PROVIDER_ROLE=%s", role), + fmt.Sprintf("SFTPGO_PROVIDER_TIMESTAMP=%d", util.GetTimeAsMsSinceEpoch(time.Now())), + fmt.Sprintf("SFTPGO_PROVIDER_OBJECT=%s", objectAsJSON)) + + startTime := time.Now() + err := cmd.Run() + providerLog(logger.LevelDebug, "executed command %q, elapsed: %s, error: %v", config.Actions.Hook, + time.Since(startTime), err) + return err +} diff --git a/internal/dataprovider/admin.go b/internal/dataprovider/admin.go new file mode 100644 index 00000000..dea4a20b --- /dev/null +++ b/internal/dataprovider/admin.go @@ -0,0 +1,660 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package dataprovider + +import ( + "encoding/json" + "errors" + "fmt" + "net" + "os" + "slices" + "strconv" + "strings" + + "github.com/alexedwards/argon2id" + "github.com/sftpgo/sdk" + passwordvalidator "github.com/wagslane/go-password-validator" + "golang.org/x/crypto/bcrypt" + + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +// Available permissions for SFTPGo admins +const ( + PermAdminAny = "*" + PermAdminAddUsers = "add_users" + PermAdminChangeUsers = "edit_users" + PermAdminDeleteUsers = "del_users" + PermAdminViewUsers = "view_users" + PermAdminViewConnections = "view_conns" + PermAdminCloseConnections = "close_conns" + PermAdminViewServerStatus = "view_status" + PermAdminManageGroups = "manage_groups" + PermAdminManageFolders = "manage_folders" + PermAdminQuotaScans = "quota_scans" + PermAdminManageDefender = "manage_defender" + PermAdminViewDefender = "view_defender" + PermAdminViewEvents = "view_events" + PermAdminDisableMFA = "disable_mfa" +) + +const ( + // GroupAddToUsersAsMembership defines that the admin's group will be added as membership group for new users + GroupAddToUsersAsMembership = iota + // GroupAddToUsersAsPrimary defines that the admin's group will be added as primary group for new users + GroupAddToUsersAsPrimary + // GroupAddToUsersAsSecondary defines that the admin's group will be added as secondary group for new users + GroupAddToUsersAsSecondary +) + +var ( + validAdminPerms = []string{PermAdminAny, 
PermAdminAddUsers, PermAdminChangeUsers, PermAdminDeleteUsers, + PermAdminViewUsers, PermAdminManageFolders, PermAdminManageGroups, PermAdminViewConnections, + PermAdminCloseConnections, PermAdminViewServerStatus, PermAdminQuotaScans, + PermAdminManageDefender, PermAdminViewDefender, PermAdminViewEvents, PermAdminDisableMFA} + forbiddenPermsForRoleAdmins = []string{PermAdminAny} +) + +// AdminTOTPConfig defines the time-based one time password configuration +type AdminTOTPConfig struct { + Enabled bool `json:"enabled,omitempty"` + ConfigName string `json:"config_name,omitempty"` + Secret *kms.Secret `json:"secret,omitempty"` +} + +func (c *AdminTOTPConfig) validate(username string) error { + if !c.Enabled { + c.ConfigName = "" + c.Secret = kms.NewEmptySecret() + return nil + } + if c.ConfigName == "" { + return util.NewValidationError("totp: config name is mandatory") + } + if !slices.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) { + return util.NewValidationError(fmt.Sprintf("totp: config name %q not found", c.ConfigName)) + } + if c.Secret.IsEmpty() { + return util.NewValidationError("totp: secret is mandatory") + } + if c.Secret.IsPlain() { + c.Secret.SetAdditionalData(username) + if err := c.Secret.Encrypt(); err != nil { + return util.NewValidationError(fmt.Sprintf("totp: unable to encrypt secret: %v", err)) + } + } + return nil +} + +// AdminPreferences defines the admin preferences +type AdminPreferences struct { + // Allow to hide some sections from the user page. + // These are not security settings and are not enforced server side + // in any way. They are only intended to simplify the user page in + // the WebAdmin UI. 
+ // + // 1 means hide groups section + // 2 means hide filesystem section, "users_base_dir" must be set in the config file otherwise this setting is ignored + // 4 means hide virtual folders section + // 8 means hide profile section + // 16 means hide ACLs section + // 32 means hide disk and bandwidth quota limits section + // 64 means hide advanced settings section + // + // The settings can be combined + HideUserPageSections int `json:"hide_user_page_sections,omitempty"` + // Defines the default expiration for newly created users as number of days. + // 0 means no expiration + DefaultUsersExpiration int `json:"default_users_expiration,omitempty"` +} + +// HideGroups returns true if the groups section should be hidden +func (p *AdminPreferences) HideGroups() bool { + return p.HideUserPageSections&1 != 0 +} + +// HideFilesystem returns true if the filesystem section should be hidden +func (p *AdminPreferences) HideFilesystem() bool { + return config.UsersBaseDir != "" && p.HideUserPageSections&2 != 0 +} + +// HideVirtualFolders returns true if the virtual folder section should be hidden +func (p *AdminPreferences) HideVirtualFolders() bool { + return p.HideUserPageSections&4 != 0 +} + +// HideProfile returns true if the profile section should be hidden +func (p *AdminPreferences) HideProfile() bool { + return p.HideUserPageSections&8 != 0 +} + +// HideACLs returns true if the ACLs section should be hidden +func (p *AdminPreferences) HideACLs() bool { + return p.HideUserPageSections&16 != 0 +} + +// HideDiskQuotaAndBandwidthLimits returns true if the disk quota and bandwidth limits +// section should be hidden +func (p *AdminPreferences) HideDiskQuotaAndBandwidthLimits() bool { + return p.HideUserPageSections&32 != 0 +} + +// HideAdvancedSettings returns true if the advanced settings section should be hidden +func (p *AdminPreferences) HideAdvancedSettings() bool { + return p.HideUserPageSections&64 != 0 +} + +// VisibleUserPageSections returns the number of 
visible sections +// in the user page +func (p *AdminPreferences) VisibleUserPageSections() int { + var result int + + if !p.HideProfile() { + result++ + } + if !p.HideACLs() { + result++ + } + if !p.HideDiskQuotaAndBandwidthLimits() { + result++ + } + if !p.HideAdvancedSettings() { + result++ + } + + return result +} + +// AdminFilters defines additional restrictions for SFTPGo admins +// TODO: rename to AdminOptions in v3 +type AdminFilters struct { + // only clients connecting from these IP/Mask are allowed. + // IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291 + // for example "192.0.2.0/24" or "2001:db8::/32" + AllowList []string `json:"allow_list,omitempty"` + // API key auth allows to impersonate this administrator with an API key + AllowAPIKeyAuth bool `json:"allow_api_key_auth,omitempty"` + // A password change is required at the next login + RequirePasswordChange bool `json:"require_password_change,omitempty"` + // Require two factor authentication + RequireTwoFactor bool `json:"require_two_factor"` + // Time-based one time passwords configuration + TOTPConfig AdminTOTPConfig `json:"totp_config,omitempty"` + // Recovery codes to use if the user loses access to their second factor auth device. 
+ // Each code can only be used once, you should use these codes to login and disable or + // reset 2FA for your account + RecoveryCodes []RecoveryCode `json:"recovery_codes,omitempty"` + Preferences AdminPreferences `json:"preferences"` +} + +// AdminGroupMappingOptions defines the options for admin/group mapping +type AdminGroupMappingOptions struct { + AddToUsersAs int `json:"add_to_users_as,omitempty"` +} + +func (o *AdminGroupMappingOptions) validate() error { + if o.AddToUsersAs < GroupAddToUsersAsMembership || o.AddToUsersAs > GroupAddToUsersAsSecondary { + return util.NewValidationError(fmt.Sprintf("Invalid mode to add groups to new users: %d", o.AddToUsersAs)) + } + return nil +} + +// GetUserGroupType returns the type for the matching user group +func (o *AdminGroupMappingOptions) GetUserGroupType() int { + switch o.AddToUsersAs { + case GroupAddToUsersAsPrimary: + return sdk.GroupTypePrimary + case GroupAddToUsersAsSecondary: + return sdk.GroupTypeSecondary + default: + return sdk.GroupTypeMembership + } +} + +// AdminGroupMapping defines the mapping between an SFTPGo admin and a group +type AdminGroupMapping struct { + Name string `json:"name"` + Options AdminGroupMappingOptions `json:"options"` +} + +// Admin defines a SFTPGo admin +type Admin struct { + // Database unique identifier + ID int64 `json:"id"` + // 1 enabled, 0 disabled (login is not allowed) + Status int `json:"status"` + // Username + Username string `json:"username"` + Password string `json:"password,omitempty"` + Email string `json:"email,omitempty"` + Permissions []string `json:"permissions"` + Filters AdminFilters `json:"filters,omitempty"` + Description string `json:"description,omitempty"` + AdditionalInfo string `json:"additional_info,omitempty"` + // Groups membership + Groups []AdminGroupMapping `json:"groups,omitempty"` + // Creation time as unix timestamp in milliseconds. 
It will be 0 for admins created before v2.2.0 + CreatedAt int64 `json:"created_at"` + // last update time as unix timestamp in milliseconds + UpdatedAt int64 `json:"updated_at"` + // Last login as unix timestamp in milliseconds + LastLogin int64 `json:"last_login"` + // Role name. If set the admin can only administer users with the same role. + // Role admins cannot be super administrators + Role string `json:"role,omitempty"` +} + +// CountUnusedRecoveryCodes returns the number of unused recovery codes +func (a *Admin) CountUnusedRecoveryCodes() int { + unused := 0 + for _, code := range a.Filters.RecoveryCodes { + if !code.Used { + unused++ + } + } + return unused +} + +func (a *Admin) hashPassword() error { + if a.Password != "" && !util.IsStringPrefixInSlice(a.Password, internalHashPwdPrefixes) { + if config.PasswordValidation.Admins.MinEntropy > 0 { + if err := passwordvalidator.Validate(a.Password, config.PasswordValidation.Admins.MinEntropy); err != nil { + return util.NewI18nError(util.NewValidationError(err.Error()), util.I18nErrorPasswordComplexity) + } + } + if config.PasswordHashing.Algo == HashingAlgoBcrypt { + pwd, err := bcrypt.GenerateFromPassword([]byte(a.Password), config.PasswordHashing.BcryptOptions.Cost) + if err != nil { + return err + } + a.Password = util.BytesToString(pwd) + } else { + pwd, err := argon2id.CreateHash(a.Password, argon2Params) + if err != nil { + return err + } + a.Password = pwd + } + } + return nil +} + +func (a *Admin) hasRedactedSecret() bool { + return a.Filters.TOTPConfig.Secret.IsRedacted() +} + +func (a *Admin) validateRecoveryCodes() error { + for i := 0; i < len(a.Filters.RecoveryCodes); i++ { + code := &a.Filters.RecoveryCodes[i] + if code.Secret.IsEmpty() { + return util.NewValidationError("mfa: recovery code cannot be empty") + } + if code.Secret.IsPlain() { + code.Secret.SetAdditionalData(a.Username) + if err := code.Secret.Encrypt(); err != nil { + return util.NewValidationError(fmt.Sprintf("mfa: unable to 
encrypt recovery code: %v", err)) + } + } + } + return nil +} + +func (a *Admin) validatePermissions() error { + a.Permissions = util.RemoveDuplicates(a.Permissions, false) + if len(a.Permissions) == 0 { + return util.NewI18nError( + util.NewValidationError("please grant some permissions to this admin"), + util.I18nErrorPermissionsRequired, + ) + } + if slices.Contains(a.Permissions, PermAdminAny) { + a.Permissions = []string{PermAdminAny} + } + for _, perm := range a.Permissions { + if !slices.Contains(validAdminPerms, perm) { + return util.NewValidationError(fmt.Sprintf("invalid permission: %q", perm)) + } + if a.Role != "" { + if slices.Contains(forbiddenPermsForRoleAdmins, perm) { + return util.NewI18nError( + util.NewValidationError("a role admin cannot be a super admin"), + util.I18nErrorRoleAdminPerms, + ) + } + } + } + return nil +} + +func (a *Admin) validateGroups() error { + hasPrimary := false + for _, g := range a.Groups { + if g.Name == "" { + return util.NewValidationError("group name is mandatory") + } + if err := g.Options.validate(); err != nil { + return err + } + if g.Options.AddToUsersAs == GroupAddToUsersAsPrimary { + if hasPrimary { + return util.NewI18nError( + util.NewValidationError("only one primary group is allowed"), + util.I18nErrorPrimaryGroup, + ) + } + hasPrimary = true + } + } + return nil +} + +func (a *Admin) applyNamingRules() { + a.Username = config.convertName(a.Username) + a.Role = config.convertName(a.Role) + for idx := range a.Groups { + a.Groups[idx].Name = config.convertName(a.Groups[idx].Name) + } +} + +func (a *Admin) validate() error { //nolint:gocyclo + a.SetEmptySecretsIfNil() + a.applyNamingRules() + a.Password = strings.TrimSpace(a.Password) + if a.Username == "" { + return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired) + } + if !util.IsNameValid(a.Username) { + return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) + } + if err := 
checkReservedUsernames(a.Username); err != nil { + return util.NewI18nError(err, util.I18nErrorReservedUsername) + } + if a.Password == "" { + return util.NewI18nError(util.NewValidationError("please set a password"), util.I18nErrorPasswordRequired) + } + if a.hasRedactedSecret() { + return util.NewValidationError("cannot save an admin with a redacted secret") + } + if err := a.Filters.TOTPConfig.validate(a.Username); err != nil { + return util.NewI18nError(err, util.I18nError2FAInvalid) + } + if err := a.validateRecoveryCodes(); err != nil { + return util.NewI18nError(err, util.I18nErrorRecoveryCodesInvalid) + } + if config.NamingRules&1 == 0 && !usernameRegex.MatchString(a.Username) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("username %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", a.Username)), + util.I18nErrorInvalidUser, + ) + } + if err := a.hashPassword(); err != nil { + return err + } + if err := a.validatePermissions(); err != nil { + return err + } + if a.Email != "" && !util.IsEmailValid(a.Email) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("email %q is not valid", a.Email)), + util.I18nErrorInvalidEmail, + ) + } + a.Filters.AllowList = util.RemoveDuplicates(a.Filters.AllowList, false) + for _, IPMask := range a.Filters.AllowList { + _, _, err := net.ParseCIDR(IPMask) + if err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("could not parse allow list entry %q : %v", IPMask, err)), + util.I18nErrorInvalidIPMask, + ) + } + } + + return a.validateGroups() +} + +// CheckPassword verifies the admin password +func (a *Admin) CheckPassword(password string) (bool, error) { + if config.PasswordCaching { + found, match := cachedAdminPasswords.Check(a.Username, password, a.Password) + if found { + if !match { + return false, ErrInvalidCredentials + } + return match, nil + } + } + if strings.HasPrefix(a.Password, bcryptPwdPrefix) { + if err := 
bcrypt.CompareHashAndPassword([]byte(a.Password), []byte(password)); err != nil { + return false, ErrInvalidCredentials + } + cachedAdminPasswords.Add(a.Username, password, a.Password) + return true, nil + } + match, err := argon2id.ComparePasswordAndHash(password, a.Password) + if !match || err != nil { + return false, ErrInvalidCredentials + } + if match { + cachedAdminPasswords.Add(a.Username, password, a.Password) + } + return match, err +} + +// CanLoginFromIP returns true if login from the given IP is allowed +func (a *Admin) CanLoginFromIP(ip string) bool { + if len(a.Filters.AllowList) == 0 { + return true + } + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return len(a.Filters.AllowList) == 0 + } + + for _, ipMask := range a.Filters.AllowList { + _, network, err := net.ParseCIDR(ipMask) + if err != nil { + continue + } + if network.Contains(parsedIP) { + return true + } + } + return false +} + +// CanLogin returns an error if the login is not allowed +func (a *Admin) CanLogin(ip string) error { + if a.Status != 1 { + return fmt.Errorf("admin %q is disabled", a.Username) + } + if !a.CanLoginFromIP(ip) { + return fmt.Errorf("login from IP %v not allowed", ip) + } + return nil +} + +func (a *Admin) checkUserAndPass(password, ip string) error { + if err := a.CanLogin(ip); err != nil { + return err + } + if a.Password == "" || strings.TrimSpace(password) == "" { + return errors.New("credentials cannot be null or empty") + } + match, err := a.CheckPassword(password) + if err != nil { + return err + } + if !match { + return ErrInvalidCredentials + } + return nil +} + +// RenderAsJSON implements the renderer interface used within plugins +func (a *Admin) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + admin, err := provider.adminExists(a.Username) + if err != nil { + providerLog(logger.LevelError, "unable to reload admin before rendering as json: %v", err) + return nil, err + } + admin.HideConfidentialData() + return json.Marshal(admin) + } + 
a.HideConfidentialData() + return json.Marshal(a) +} + +// HideConfidentialData hides admin confidential data +func (a *Admin) HideConfidentialData() { + a.Password = "" + if a.Filters.TOTPConfig.Secret != nil { + a.Filters.TOTPConfig.Secret.Hide() + } + for _, code := range a.Filters.RecoveryCodes { + if code.Secret != nil { + code.Secret.Hide() + } + } + a.SetNilSecretsIfEmpty() +} + +// SetEmptySecretsIfNil sets the secrets to empty if nil +func (a *Admin) SetEmptySecretsIfNil() { + if a.Filters.TOTPConfig.Secret == nil { + a.Filters.TOTPConfig.Secret = kms.NewEmptySecret() + } +} + +// SetNilSecretsIfEmpty set the secrets to nil if empty. +// This is useful before rendering as JSON so the empty fields +// will not be serialized. +func (a *Admin) SetNilSecretsIfEmpty() { + if a.Filters.TOTPConfig.Secret != nil && a.Filters.TOTPConfig.Secret.IsEmpty() { + a.Filters.TOTPConfig.Secret = nil + } +} + +// HasPermission returns true if the admin has the specified permission +func (a *Admin) HasPermission(perm string) bool { + if slices.Contains(a.Permissions, PermAdminAny) { + return true + } + return slices.Contains(a.Permissions, perm) +} + +// HasPermissions returns true if the admin has all the specified permissions +func (a *Admin) HasPermissions(perms ...string) bool { + for _, perm := range perms { + if !a.HasPermission(perm) { + return false + } + } + return len(perms) > 0 +} + +// GetAllowedIPAsString returns the allowed IP as comma separated string +func (a *Admin) GetAllowedIPAsString() string { + return strings.Join(a.Filters.AllowList, ",") +} + +// GetValidPerms returns the allowed admin permissions +func (a *Admin) GetValidPerms() []string { + return validAdminPerms +} + +// CanManageMFA returns true if the admin can add a multi-factor authentication configuration +func (a *Admin) CanManageMFA() bool { + return len(mfa.GetAvailableTOTPConfigs()) > 0 +} + +// GetSignature returns a signature for this admin. 
+// It will change after an update +func (a *Admin) GetSignature() string { + return strconv.FormatInt(a.UpdatedAt, 10) +} + +func (a *Admin) getACopy() Admin { + a.SetEmptySecretsIfNil() + permissions := make([]string, len(a.Permissions)) + copy(permissions, a.Permissions) + filters := AdminFilters{} + filters.AllowList = make([]string, len(a.Filters.AllowList)) + filters.AllowAPIKeyAuth = a.Filters.AllowAPIKeyAuth + filters.RequirePasswordChange = a.Filters.RequirePasswordChange + filters.RequireTwoFactor = a.Filters.RequireTwoFactor + filters.TOTPConfig.Enabled = a.Filters.TOTPConfig.Enabled + filters.TOTPConfig.ConfigName = a.Filters.TOTPConfig.ConfigName + filters.TOTPConfig.Secret = a.Filters.TOTPConfig.Secret.Clone() + copy(filters.AllowList, a.Filters.AllowList) + filters.RecoveryCodes = make([]RecoveryCode, 0) + for _, code := range a.Filters.RecoveryCodes { + if code.Secret == nil { + code.Secret = kms.NewEmptySecret() + } + filters.RecoveryCodes = append(filters.RecoveryCodes, RecoveryCode{ + Secret: code.Secret.Clone(), + Used: code.Used, + }) + } + filters.Preferences = AdminPreferences{ + HideUserPageSections: a.Filters.Preferences.HideUserPageSections, + DefaultUsersExpiration: a.Filters.Preferences.DefaultUsersExpiration, + } + groups := make([]AdminGroupMapping, 0, len(a.Groups)) + for _, g := range a.Groups { + groups = append(groups, AdminGroupMapping{ + Name: g.Name, + Options: AdminGroupMappingOptions{ + AddToUsersAs: g.Options.AddToUsersAs, + }, + }) + } + + return Admin{ + ID: a.ID, + Status: a.Status, + Username: a.Username, + Password: a.Password, + Email: a.Email, + Permissions: permissions, + Groups: groups, + Filters: filters, + AdditionalInfo: a.AdditionalInfo, + Description: a.Description, + LastLogin: a.LastLogin, + CreatedAt: a.CreatedAt, + UpdatedAt: a.UpdatedAt, + Role: a.Role, + } +} + +func (a *Admin) setFromEnv() error { + envUsername := strings.TrimSpace(os.Getenv("SFTPGO_DEFAULT_ADMIN_USERNAME")) + envPassword := 
strings.TrimSpace(os.Getenv("SFTPGO_DEFAULT_ADMIN_PASSWORD")) + if envUsername == "" || envPassword == "" { + return errors.New(`to create the default admin you need to set the env vars "SFTPGO_DEFAULT_ADMIN_USERNAME" and "SFTPGO_DEFAULT_ADMIN_PASSWORD"`) + } + a.Username = envUsername + a.Password = envPassword + a.Status = 1 + a.Permissions = []string{PermAdminAny} + return nil +} diff --git a/internal/dataprovider/apikey.go b/internal/dataprovider/apikey.go new file mode 100644 index 00000000..b7424dd6 --- /dev/null +++ b/internal/dataprovider/apikey.go @@ -0,0 +1,213 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package dataprovider + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/alexedwards/argon2id" + "golang.org/x/crypto/bcrypt" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +// APIKeyScope defines the supported API key scopes +type APIKeyScope int + +// Supported API key scopes +const ( + // the API key will be used for an admin + APIKeyScopeAdmin APIKeyScope = iota + 1 + // the API key will be used for a user + APIKeyScopeUser +) + +// APIKey defines a SFTPGo API key. +// API keys can be used as authentication alternative to short lived tokens +// for REST API +type APIKey struct { + // Database unique identifier + ID int64 `json:"-"` + // Unique key identifier, used for key lookups. 
+ // The generated key is in the format `KeyID.hash(Key)` so we can split + // and lookup by KeyID and then verify if the key matches the recorded hash + KeyID string `json:"id"` + // User friendly key name + Name string `json:"name"` + // we store the hash of the key, this is just like a password + Key string `json:"key,omitempty"` + Scope APIKeyScope `json:"scope"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + // 0 means never used + LastUseAt int64 `json:"last_use_at,omitempty"` + // 0 means never expire + ExpiresAt int64 `json:"expires_at,omitempty"` + Description string `json:"description,omitempty"` + // Username associated with this API key. + // If empty and the scope is APIKeyScopeUser the key is valid for any user + User string `json:"user,omitempty"` + // Admin username associated with this API key. + // If empty and the scope is APIKeyScopeAdmin the key is valid for any admin + Admin string `json:"admin,omitempty"` + // these fields are for internal use + userID int64 + adminID int64 + plainKey string +} + +func (k *APIKey) getACopy() APIKey { + return APIKey{ + ID: k.ID, + KeyID: k.KeyID, + Name: k.Name, + Key: k.Key, + Scope: k.Scope, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + LastUseAt: k.LastUseAt, + ExpiresAt: k.ExpiresAt, + Description: k.Description, + User: k.User, + Admin: k.Admin, + userID: k.userID, + adminID: k.adminID, + } +} + +// RenderAsJSON implements the renderer interface used within plugins +func (k *APIKey) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + apiKey, err := provider.apiKeyExists(k.KeyID) + if err != nil { + providerLog(logger.LevelError, "unable to reload api key before rendering as json: %v", err) + return nil, err + } + apiKey.HideConfidentialData() + return json.Marshal(apiKey) + } + k.HideConfidentialData() + return json.Marshal(k) +} + +// HideConfidentialData hides API key confidential data +func (k *APIKey) HideConfidentialData() { + k.Key = "" +} + +func (k 
*APIKey) hashKey() error { + if k.Key != "" && !util.IsStringPrefixInSlice(k.Key, internalHashPwdPrefixes) { + if config.PasswordHashing.Algo == HashingAlgoBcrypt { + hashed, err := bcrypt.GenerateFromPassword([]byte(k.Key), config.PasswordHashing.BcryptOptions.Cost) + if err != nil { + return err + } + k.Key = util.BytesToString(hashed) + } else { + hashed, err := argon2id.CreateHash(k.Key, argon2Params) + if err != nil { + return err + } + k.Key = hashed + } + } + return nil +} + +func (k *APIKey) generateKey() { + if k.KeyID != "" || k.Key != "" { + return + } + k.KeyID = util.GenerateUniqueID() + k.Key = util.GenerateUniqueID() + k.plainKey = k.Key +} + +// DisplayKey returns the key to show to the user +func (k *APIKey) DisplayKey() string { + return fmt.Sprintf("%v.%v", k.KeyID, k.plainKey) +} + +func (k *APIKey) validate() error { + if k.Name == "" { + return util.NewValidationError("name is mandatory") + } + if !util.IsNameValid(k.Name) { + return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) + } + if k.Scope != APIKeyScopeAdmin && k.Scope != APIKeyScopeUser { + return util.NewValidationError(fmt.Sprintf("invalid scope: %v", k.Scope)) + } + k.generateKey() + if err := k.hashKey(); err != nil { + return err + } + if k.User != "" && k.Admin != "" { + return util.NewValidationError("an API key can be related to a user or an admin, not both") + } + if k.Scope == APIKeyScopeAdmin { + k.User = "" + } + if k.Scope == APIKeyScopeUser { + k.Admin = "" + } + if k.User != "" { + _, err := provider.userExists(k.User, "") + if err != nil { + return util.NewValidationError(fmt.Sprintf("unable to check API key user %v: %v", k.User, err)) + } + } + if k.Admin != "" { + _, err := provider.adminExists(k.Admin) + if err != nil { + return util.NewValidationError(fmt.Sprintf("unable to check API key admin %v: %v", k.Admin, err)) + } + } + return nil +} + +// Authenticate tries to authenticate the provided plain key +func (k *APIKey) Authenticate(plainKey 
string) error {
+	if k.ExpiresAt > 0 && k.ExpiresAt < util.GetTimeAsMsSinceEpoch(time.Now()) {
+		return fmt.Errorf("API key %q is expired, expiration timestamp: %v current timestamp: %v", k.KeyID,
+			k.ExpiresAt, util.GetTimeAsMsSinceEpoch(time.Now()))
+	}
+	if config.PasswordCaching {
+		found, match := cachedAPIKeys.Check(k.KeyID, plainKey, k.Key)
+		if found {
+			if !match {
+				return ErrInvalidCredentials
+			}
+			return nil
+		}
+	}
+	if strings.HasPrefix(k.Key, bcryptPwdPrefix) {
+		if err := bcrypt.CompareHashAndPassword([]byte(k.Key), []byte(plainKey)); err != nil {
+			return ErrInvalidCredentials
+		}
+	} else {
+		match, err := argon2id.ComparePasswordAndHash(plainKey, k.Key)
+		if err != nil || !match {
+			return ErrInvalidCredentials
+		}
+	}
+
+	cachedAPIKeys.Add(k.KeyID, plainKey, k.Key)
+	return nil
+}
diff --git a/internal/dataprovider/bolt.go b/internal/dataprovider/bolt.go
new file mode 100644
index 00000000..28adb25f
--- /dev/null
+++ b/internal/dataprovider/bolt.go
@@ -0,0 +1,3951 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+ +//go:build !nobolt + +package dataprovider + +import ( + "bytes" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "net/netip" + "path/filepath" + "slices" + "sort" + "strconv" + "time" + + bolt "go.etcd.io/bbolt" + bolterrors "go.etcd.io/bbolt/errors" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + boltDatabaseVersion = 34 +) + +var ( + usersBucket = []byte("users") + groupsBucket = []byte("groups") + foldersBucket = []byte("folders") + adminsBucket = []byte("admins") + apiKeysBucket = []byte("api_keys") + sharesBucket = []byte("shares") + actionsBucket = []byte("events_actions") + rulesBucket = []byte("events_rules") + rolesBucket = []byte("roles") + ipListsBucket = []byte("ip_lists") + configsBucket = []byte("configs") + dbVersionBucket = []byte("db_version") + dbVersionKey = []byte("version") + configsKey = []byte("configs") + boltBuckets = [][]byte{usersBucket, groupsBucket, foldersBucket, adminsBucket, apiKeysBucket, + sharesBucket, actionsBucket, rulesBucket, rolesBucket, ipListsBucket, configsBucket, dbVersionBucket} +) + +// BoltProvider defines the auth provider for bolt key/value store +type BoltProvider struct { + dbHandle *bolt.DB +} + +func init() { + version.AddFeature("+bolt") +} + +func initializeBoltProvider(basePath string) error { + var err error + + dbPath := config.Name + if !util.IsFileInputValid(dbPath) { + return fmt.Errorf("invalid database path: %q", dbPath) + } + if !filepath.IsAbs(dbPath) { + dbPath = filepath.Join(basePath, dbPath) + } + dbHandle, err := bolt.Open(dbPath, 0600, &bolt.Options{ + NoGrowSync: false, + FreelistType: bolt.FreelistArrayType, + Timeout: 5 * time.Second}) + if err == nil { + providerLog(logger.LevelDebug, "bolt key store handle created") + + for _, bucket := range boltBuckets { + if err := dbHandle.Update(func(tx *bolt.Tx) error { + _, e := 
tx.CreateBucketIfNotExists(bucket) + return e + }); err != nil { + providerLog(logger.LevelError, "error creating bucket %q: %v", string(bucket), err) + } + } + + provider = &BoltProvider{dbHandle: dbHandle} + } else { + providerLog(logger.LevelError, "error creating bolt key/value store handler: %v", err) + } + return err +} + +func (p *BoltProvider) checkAvailability() error { + _, err := getBoltDatabaseVersion(p.dbHandle) + return err +} + +func (p *BoltProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) { + var user User + if tlsCert == nil { + return user, errors.New("TLS certificate cannot be null or empty") + } + user, err := p.userExists(username, "") + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) + return user, err + } + return checkUserAndTLSCertificate(&user, protocol, tlsCert) +} + +func (p *BoltProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) { + user, err := p.userExists(username, "") + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) + return user, err + } + return checkUserAndPass(&user, password, ip, protocol) +} + +func (p *BoltProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { + admin, err := p.adminExists(username) + if err != nil { + providerLog(logger.LevelWarn, "error authenticating admin %q: %v", username, err) + return admin, err + } + err = admin.checkUserAndPass(password, ip) + return admin, err +} + +func (p *BoltProvider) validateUserAndPubKey(username string, pubKey []byte, isSSHCert bool) (User, string, error) { + var user User + if len(pubKey) == 0 { + return user, "", errors.New("credentials cannot be null or empty") + } + user, err := p.userExists(username, "") + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) + return user, "", err + } + return 
checkUserAndPubKey(&user, pubKey, isSSHCert) +} + +func (p *BoltProvider) updateAPIKeyLastUse(keyID string) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getAPIKeysBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(keyID)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("key %q does not exist, unable to update last use", keyID)) + } + var apiKey APIKey + err = json.Unmarshal(u, &apiKey) + if err != nil { + return err + } + apiKey.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(apiKey) + if err != nil { + return err + } + err = bucket.Put([]byte(keyID), buf) + if err != nil { + providerLog(logger.LevelWarn, "error updating last use for key %q: %v", keyID, err) + return err + } + providerLog(logger.LevelDebug, "last use updated for key %q", keyID) + return nil + }) +} + +func (p *BoltProvider) getAdminSignature(username string) (string, error) { + var updatedAt int64 + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + u := bucket.Get([]byte(username)) + var admin Admin + err = json.Unmarshal(u, &admin) + if err != nil { + return err + } + updatedAt = admin.UpdatedAt + return nil + }) + if err != nil { + return "", err + } + return strconv.FormatInt(updatedAt, 10), nil +} + +func (p *BoltProvider) getUserSignature(username string) (string, error) { + var updatedAt int64 + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + u := bucket.Get([]byte(username)) + var user User + err = json.Unmarshal(u, &user) + if err != nil { + return err + } + updatedAt = user.UpdatedAt + return nil + }) + if err != nil { + return "", err + } + return strconv.FormatInt(updatedAt, 10), nil +} + +func (p *BoltProvider) setUpdatedAt(username string) { + p.dbHandle.Update(func(tx *bolt.Tx) error { //nolint:errcheck + bucket, err := 
p.getUsersBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to update updated at", username)) + } + var user User + err = json.Unmarshal(u, &user) + if err != nil { + return err + } + user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(user) + if err != nil { + return err + } + err = bucket.Put([]byte(username), buf) + if err == nil { + providerLog(logger.LevelDebug, "updated at set for user %q", username) + setLastUserUpdate() + } else { + providerLog(logger.LevelWarn, "error setting updated_at for user %q: %v", username, err) + } + return err + }) +} + +func (p *BoltProvider) updateLastLogin(username string) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to update last login", username)) + } + var user User + err = json.Unmarshal(u, &user) + if err != nil { + return err + } + user.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(user) + if err != nil { + return err + } + err = bucket.Put([]byte(username), buf) + if err != nil { + providerLog(logger.LevelWarn, "error updating last login for user %q: %v", username, err) + } else { + providerLog(logger.LevelDebug, "last login updated for user %q", username) + } + return err + }) +} + +func (p *BoltProvider) updateAdminLastLogin(username string) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + var a []byte + if a = bucket.Get([]byte(username)); a == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("admin %q does not exist, unable to update last login", username)) + } + var admin Admin + err 
= json.Unmarshal(a, &admin) + if err != nil { + return err + } + admin.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(admin) + if err != nil { + return err + } + err = bucket.Put([]byte(username), buf) + if err == nil { + providerLog(logger.LevelDebug, "last login updated for admin %q", username) + return err + } + providerLog(logger.LevelWarn, "error updating last login for admin %q: %v", username, err) + return err + }) +} + +func (p *BoltProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to update transfer quota", + username)) + } + var user User + err = json.Unmarshal(u, &user) + if err != nil { + return err + } + if !reset { + user.UsedUploadDataTransfer += uploadSize + user.UsedDownloadDataTransfer += downloadSize + } else { + user.UsedUploadDataTransfer = uploadSize + user.UsedDownloadDataTransfer = downloadSize + } + user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(user) + if err != nil { + return err + } + err = bucket.Put([]byte(username), buf) + providerLog(logger.LevelDebug, "transfer quota updated for user %q, ul increment: %v dl increment: %v is reset? 
%v", + username, uploadSize, downloadSize, reset) + return err + }) +} + +func (p *BoltProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to update quota", username)) + } + var user User + err = json.Unmarshal(u, &user) + if err != nil { + return err + } + if reset { + user.UsedQuotaSize = sizeAdd + user.UsedQuotaFiles = filesAdd + } else { + user.UsedQuotaSize += sizeAdd + user.UsedQuotaFiles += filesAdd + } + user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(user) + if err != nil { + return err + } + err = bucket.Put([]byte(username), buf) + providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %v size increment: %v is reset? 
%v", + username, filesAdd, sizeAdd, reset) + return err + }) +} + +func (p *BoltProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { + user, err := p.userExists(username, "") + if err != nil { + providerLog(logger.LevelError, "unable to get quota for user %v error: %v", username, err) + return 0, 0, 0, 0, err + } + return user.UsedQuotaFiles, user.UsedQuotaSize, user.UsedUploadDataTransfer, user.UsedDownloadDataTransfer, err +} + +func (p *BoltProvider) adminExists(username string) (Admin, error) { + var admin Admin + + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + a := bucket.Get([]byte(username)) + if a == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("admin %v does not exist", username)) + } + return json.Unmarshal(a, &admin) + }) + + return admin, err +} + +func (p *BoltProvider) addAdmin(admin *Admin) error { + err := admin.validate() + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + groupBucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + rolesBucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + if a := bucket.Get([]byte(admin.Username)); a != nil { + return util.NewI18nError( + fmt.Errorf("%w: admin %q already exists", ErrDuplicatedKey, admin.Username), + util.I18nErrorDuplicatedUsername, + ) + } + id, err := bucket.NextSequence() + if err != nil { + return err + } + admin.ID = int64(id) + admin.LastLogin = 0 + admin.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + for idx := range admin.Groups { + err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name, groupBucket) + if err != nil { + return err + } + } + if err = p.addAdminToRole(admin.Username, admin.Role, rolesBucket); err != nil { + return err + } + + buf, err := 
json.Marshal(admin) + if err != nil { + return err + } + return bucket.Put([]byte(admin.Username), buf) + }) +} + +func (p *BoltProvider) updateAdmin(admin *Admin) error { + err := admin.validate() + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + groupBucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + rolesBucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + var a []byte + if a = bucket.Get([]byte(admin.Username)); a == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("admin %v does not exist", admin.Username)) + } + var oldAdmin Admin + err = json.Unmarshal(a, &oldAdmin) + if err != nil { + return err + } + + if err = p.removeAdminFromRole(oldAdmin.Username, oldAdmin.Role, rolesBucket); err != nil { + return err + } + for idx := range oldAdmin.Groups { + err = p.removeAdminFromGroupMapping(oldAdmin.Username, oldAdmin.Groups[idx].Name, groupBucket) + if err != nil { + return err + } + } + if err = p.addAdminToRole(admin.Username, admin.Role, rolesBucket); err != nil { + return err + } + for idx := range admin.Groups { + err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name, groupBucket) + if err != nil { + return err + } + } + admin.ID = oldAdmin.ID + admin.CreatedAt = oldAdmin.CreatedAt + admin.LastLogin = oldAdmin.LastLogin + admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(admin) + if err != nil { + return err + } + return bucket.Put([]byte(admin.Username), buf) + }) +} + +func (p *BoltProvider) deleteAdmin(admin Admin) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + + var a []byte + if a = bucket.Get([]byte(admin.Username)); a == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("admin %v does not exist", admin.Username)) + } + var oldAdmin Admin + err = 
json.Unmarshal(a, &oldAdmin) + if err != nil { + return err + } + if len(oldAdmin.Groups) > 0 { + groupBucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + for idx := range oldAdmin.Groups { + err = p.removeAdminFromGroupMapping(oldAdmin.Username, oldAdmin.Groups[idx].Name, groupBucket) + if err != nil { + return err + } + } + } + if oldAdmin.Role != "" { + rolesBucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + if err = p.removeAdminFromRole(oldAdmin.Username, oldAdmin.Role, rolesBucket); err != nil { + return err + } + } + + if err := p.deleteRelatedAPIKey(tx, admin.Username, APIKeyScopeAdmin); err != nil { + return err + } + + return bucket.Delete([]byte(admin.Username)) + }) +} + +func (p *BoltProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) { + admins := make([]Admin, 0, limit) + + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + itNum := 0 + if order == OrderASC { + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + itNum++ + if itNum <= offset { + continue + } + var admin Admin + err = json.Unmarshal(v, &admin) + if err != nil { + return err + } + admin.HideConfidentialData() + admins = append(admins, admin) + if len(admins) >= limit { + break + } + } + } else { + for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { + itNum++ + if itNum <= offset { + continue + } + var admin Admin + err = json.Unmarshal(v, &admin) + if err != nil { + return err + } + admin.HideConfidentialData() + admins = append(admins, admin) + if len(admins) >= limit { + break + } + } + } + return err + }) + + return admins, err +} + +func (p *BoltProvider) dumpAdmins() ([]Admin, error) { + admins := make([]Admin, 0, 30) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + + cursor := bucket.Cursor() + for k, v := 
cursor.First(); k != nil; k, v = cursor.Next() { + var admin Admin + err = json.Unmarshal(v, &admin) + if err != nil { + return err + } + admins = append(admins, admin) + } + return err + }) + + return admins, err +} + +func (p *BoltProvider) userExists(username, role string) (User, error) { + var user User + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + u := bucket.Get([]byte(username)) + if u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + user, err = p.joinUserAndFolders(u, foldersBucket) + if err != nil { + return err + } + if !user.hasRole(role) { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) + } + return nil + }) + return user, err +} + +func (p *BoltProvider) addUser(user *User) error { + err := ValidateUser(user) + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + groupBucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + rolesBucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + if u := bucket.Get([]byte(user.Username)); u != nil { + return util.NewI18nError( + fmt.Errorf("%w: username %v already exists", ErrDuplicatedKey, user.Username), + util.I18nErrorDuplicatedUsername, + ) + } + id, err := bucket.NextSequence() + if err != nil { + return err + } + user.ID = int64(id) + user.LastQuotaUpdate = 0 + user.UsedQuotaSize = 0 + user.UsedQuotaFiles = 0 + user.UsedUploadDataTransfer = 0 + user.UsedDownloadDataTransfer = 0 + user.LastLogin = 0 + user.FirstDownload = 0 + user.FirstUpload = 0 + user.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + user.UpdatedAt = 
util.GetTimeAsMsSinceEpoch(time.Now()) + if err := p.addUserToRole(user.Username, user.Role, rolesBucket); err != nil { + return err + } + for idx := range user.VirtualFolders { + err = p.addRelationToFolderMapping(user.VirtualFolders[idx].Name, user, nil, foldersBucket) + if err != nil { + return err + } + } + for idx := range user.Groups { + err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name, groupBucket) + if err != nil { + return err + } + } + buf, err := json.Marshal(user) + if err != nil { + return err + } + return bucket.Put([]byte(user.Username), buf) + }) +} + +func (p *BoltProvider) updateUser(user *User) error { + err := ValidateUser(user) + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(user.Username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", user.Username)) + } + var oldUser User + err = json.Unmarshal(u, &oldUser) + if err != nil { + return err + } + if err = p.updateUserRelations(tx, user, oldUser); err != nil { + return err + } + user.ID = oldUser.ID + user.LastQuotaUpdate = oldUser.LastQuotaUpdate + user.UsedQuotaSize = oldUser.UsedQuotaSize + user.UsedQuotaFiles = oldUser.UsedQuotaFiles + user.UsedUploadDataTransfer = oldUser.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = oldUser.UsedDownloadDataTransfer + user.LastLogin = oldUser.LastLogin + user.FirstDownload = oldUser.FirstDownload + user.FirstUpload = oldUser.FirstUpload + user.CreatedAt = oldUser.CreatedAt + user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(user) + if err != nil { + return err + } + + err = bucket.Put([]byte(user.Username), buf) + if err == nil { + setLastUserUpdate() + } + return err + }) +} + +func (p *BoltProvider) deleteUser(user User, _ bool) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, 
err := p.getUsersBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + groupBucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + rolesBucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(user.Username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", user.Username)) + } + var oldUser User + err = json.Unmarshal(u, &oldUser) + if err != nil { + return err + } + if err := p.removeUserFromRole(oldUser.Username, oldUser.Role, rolesBucket); err != nil { + return err + } + for idx := range oldUser.VirtualFolders { + err = p.removeRelationFromFolderMapping(oldUser.VirtualFolders[idx], oldUser.Username, "", foldersBucket) + if err != nil { + return err + } + } + for idx := range oldUser.Groups { + err = p.removeUserFromGroupMapping(oldUser.Username, oldUser.Groups[idx].Name, groupBucket) + if err != nil { + return err + } + } + if err := p.deleteRelatedAPIKey(tx, user.Username, APIKeyScopeUser); err != nil { + return err + } + if err := p.deleteRelatedShares(tx, user.Username); err != nil { + return err + } + return bucket.Delete([]byte(user.Username)) + }) +} + +func (p *BoltProvider) updateUserPassword(username, password string) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) + } + var user User + err = json.Unmarshal(u, &user) + if err != nil { + return err + } + user.Password = password + user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(user) + if err != nil { + return err + } + return bucket.Put([]byte(username), buf) + }) +} + +func (p *BoltProvider) dumpUsers() ([]User, error) { + users := 
make([]User, 0, 100) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + user, err := p.joinUserAndFolders(v, foldersBucket) + if err != nil { + return err + } + users = append(users, user) + } + return err + }) + return users, err +} + +func (p *BoltProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) { + if getLastUserUpdate() < after { + return nil, nil + } + users := make([]User, 0, 10) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + groupsBucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var user User + err := json.Unmarshal(v, &user) + if err != nil { + return err + } + if user.UpdatedAt < after { + continue + } + if len(user.VirtualFolders) > 0 { + var folders []vfs.VirtualFolder + for idx := range user.VirtualFolders { + folder := &user.VirtualFolders[idx] + baseFolder, err := p.folderExistsInternal(folder.Name, foldersBucket) + if err != nil { + continue + } + folder.BaseVirtualFolder = baseFolder + folders = append(folders, *folder) + } + user.VirtualFolders = folders + } + if len(user.Groups) > 0 { + groupMapping := make(map[string]Group) + for idx := range user.Groups { + group, err := p.groupExistsInternal(user.Groups[idx].Name, groupsBucket) + if err != nil { + continue + } + groupMapping[group.Name] = group + } + user.applyGroupSettings(groupMapping) + } + user.SetEmptySecretsIfNil() + users = append(users, user) + } + return err + }) + return users, err +} + +func (p *BoltProvider) getUsersForQuotaCheck(toFetch map[string]bool) 
([]User, error) { + users := make([]User, 0, 10) + + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + groupsBucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var user User + err := json.Unmarshal(v, &user) + if err != nil { + return err + } + if needFolders, ok := toFetch[user.Username]; ok { + if needFolders && len(user.VirtualFolders) > 0 { + var folders []vfs.VirtualFolder + for idx := range user.VirtualFolders { + folder := &user.VirtualFolders[idx] + baseFolder, err := p.folderExistsInternal(folder.Name, foldersBucket) + if err != nil { + continue + } + folder.BaseVirtualFolder = baseFolder + folders = append(folders, *folder) + } + user.VirtualFolders = folders + } + if len(user.Groups) > 0 { + groupMapping := make(map[string]Group) + for idx := range user.Groups { + group, err := p.groupExistsInternal(user.Groups[idx].Name, groupsBucket) + if err != nil { + continue + } + groupMapping[group.Name] = group + } + user.applyGroupSettings(groupMapping) + } + + user.SetEmptySecretsIfNil() + user.PrepareForRendering() + users = append(users, user) + } + } + return nil + }) + + return users, err +} + +func (p *BoltProvider) getUsers(limit int, offset int, order, role string) ([]User, error) { + users := make([]User, 0, limit) + var err error + if limit <= 0 { + return users, err + } + err = p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + itNum := 0 + if order == OrderASC { + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + itNum++ + if itNum <= offset { + continue + } + user, err := p.joinUserAndFolders(v, 
foldersBucket) + if err != nil { + return err + } + if !user.hasRole(role) { + continue + } + user.PrepareForRendering() + users = append(users, user) + if len(users) >= limit { + break + } + } + } else { + for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { + itNum++ + if itNum <= offset { + continue + } + user, err := p.joinUserAndFolders(v, foldersBucket) + if err != nil { + return err + } + if !user.hasRole(role) { + continue + } + user.PrepareForRendering() + users = append(users, user) + if len(users) >= limit { + break + } + } + } + return err + }) + return users, err +} + +func (p *BoltProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { + folders := make([]vfs.BaseVirtualFolder, 0, 50) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var folder vfs.BaseVirtualFolder + err = json.Unmarshal(v, &folder) + if err != nil { + return err + } + folders = append(folders, folder) + } + return err + }) + return folders, err +} + +func (p *BoltProvider) getFolders(limit, offset int, order string, _ bool) ([]vfs.BaseVirtualFolder, error) { + folders := make([]vfs.BaseVirtualFolder, 0, limit) + var err error + if limit <= 0 { + return folders, err + } + err = p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + itNum := 0 + if order == OrderASC { + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + itNum++ + if itNum <= offset { + continue + } + var folder vfs.BaseVirtualFolder + err = json.Unmarshal(v, &folder) + if err != nil { + return err + } + folder.PrepareForRendering() + folders = append(folders, folder) + if len(folders) >= limit { + break + } + } + } else { + for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { + itNum++ + if itNum <= offset { + continue + } + var 
folder vfs.BaseVirtualFolder + err = json.Unmarshal(v, &folder) + if err != nil { + return err + } + folder.PrepareForRendering() + folders = append(folders, folder) + if len(folders) >= limit { + break + } + } + } + return err + }) + return folders, err +} + +func (p *BoltProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) { + var folder vfs.BaseVirtualFolder + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + folder, err = p.folderExistsInternal(name, bucket) + return err + }) + return folder, err +} + +func (p *BoltProvider) addFolder(folder *vfs.BaseVirtualFolder) error { + err := ValidateFolder(folder) + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + if f := bucket.Get([]byte(folder.Name)); f != nil { + return util.NewI18nError( + fmt.Errorf("%w: folder %q already exists", ErrDuplicatedKey, folder.Name), + util.I18nErrorDuplicatedUsername, + ) + } + folder.Users = nil + folder.Groups = nil + return p.addFolderInternal(*folder, bucket) + }) +} + +func (p *BoltProvider) updateFolder(folder *vfs.BaseVirtualFolder) error { + err := ValidateFolder(folder) + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + var f []byte + + if f = bucket.Get([]byte(folder.Name)); f == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("folder %v does not exist", folder.Name)) + } + var oldFolder vfs.BaseVirtualFolder + err = json.Unmarshal(f, &oldFolder) + if err != nil { + return err + } + + folder.ID = oldFolder.ID + folder.LastQuotaUpdate = oldFolder.LastQuotaUpdate + folder.UsedQuotaFiles = oldFolder.UsedQuotaFiles + folder.UsedQuotaSize = oldFolder.UsedQuotaSize + folder.Users = oldFolder.Users + folder.Groups = oldFolder.Groups + buf, err := 
json.Marshal(folder) + if err != nil { + return err + } + return bucket.Put([]byte(folder.Name), buf) + }) +} + +func (p *BoltProvider) deleteFolderMappings(folder vfs.BaseVirtualFolder, usersBucket, groupsBucket *bolt.Bucket) error { + for _, username := range folder.Users { + var u []byte + if u = usersBucket.Get([]byte(username)); u == nil { + continue + } + var user User + err := json.Unmarshal(u, &user) + if err != nil { + return err + } + var folders []vfs.VirtualFolder + for _, userFolder := range user.VirtualFolders { + if folder.Name != userFolder.Name { + folders = append(folders, userFolder) + } + } + user.VirtualFolders = folders + buf, err := json.Marshal(user) + if err != nil { + return err + } + err = usersBucket.Put([]byte(user.Username), buf) + if err != nil { + return err + } + } + for _, groupname := range folder.Groups { + var u []byte + if u = groupsBucket.Get([]byte(groupname)); u == nil { + continue + } + var group Group + err := json.Unmarshal(u, &group) + if err != nil { + return err + } + var folders []vfs.VirtualFolder + for _, groupFolder := range group.VirtualFolders { + if folder.Name != groupFolder.Name { + folders = append(folders, groupFolder) + } + } + group.VirtualFolders = folders + buf, err := json.Marshal(group) + if err != nil { + return err + } + err = groupsBucket.Put([]byte(group.Name), buf) + if err != nil { + return err + } + } + return nil +} + +func (p *BoltProvider) deleteFolder(baseFolder vfs.BaseVirtualFolder) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + usersBucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + groupsBucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + + var f []byte + if f = bucket.Get([]byte(baseFolder.Name)); f == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("folder %v does not exist", baseFolder.Name)) + } + var folder vfs.BaseVirtualFolder + err = 
json.Unmarshal(f, &folder) + if err != nil { + return err + } + if err = p.deleteFolderMappings(folder, usersBucket, groupsBucket); err != nil { + return err + } + + return bucket.Delete([]byte(folder.Name)) + }) +} + +func (p *BoltProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + var f []byte + if f = bucket.Get([]byte(name)); f == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("folder %q does not exist, unable to update quota", name)) + } + var folder vfs.BaseVirtualFolder + err = json.Unmarshal(f, &folder) + if err != nil { + return err + } + if reset { + folder.UsedQuotaSize = sizeAdd + folder.UsedQuotaFiles = filesAdd + } else { + folder.UsedQuotaSize += sizeAdd + folder.UsedQuotaFiles += filesAdd + } + folder.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(folder) + if err != nil { + return err + } + return bucket.Put([]byte(folder.Name), buf) + }) +} + +func (p *BoltProvider) getUsedFolderQuota(name string) (int, int64, error) { + folder, err := p.getFolderByName(name) + if err != nil { + providerLog(logger.LevelError, "unable to get quota for folder %q error: %v", name, err) + return 0, 0, err + } + return folder.UsedQuotaFiles, folder.UsedQuotaSize, err +} + +func (p *BoltProvider) getGroups(limit, offset int, order string, _ bool) ([]Group, error) { + groups := make([]Group, 0, limit) + var err error + if limit <= 0 { + return groups, err + } + err = p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + itNum := 0 + if order == OrderASC { + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + itNum++ + if itNum <= offset { + continue + } + var group Group + group, 
err = p.joinGroupAndFolders(v, foldersBucket) + if err != nil { + return err + } + group.PrepareForRendering() + groups = append(groups, group) + if len(groups) >= limit { + break + } + } + } else { + for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { + itNum++ + if itNum <= offset { + continue + } + var group Group + group, err = p.joinGroupAndFolders(v, foldersBucket) + if err != nil { + return err + } + group.PrepareForRendering() + groups = append(groups, group) + if len(groups) >= limit { + break + } + } + } + return err + }) + return groups, err +} + +func (p *BoltProvider) getGroupsWithNames(names []string) ([]Group, error) { + var groups []Group + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + for _, name := range names { + g := bucket.Get([]byte(name)) + if g == nil { + continue + } + group, err := p.joinGroupAndFolders(g, foldersBucket) + if err != nil { + return err + } + groups = append(groups, group) + } + return nil + }) + return groups, err +} + +func (p *BoltProvider) getUsersInGroups(names []string) ([]string, error) { + var usernames []string + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + for _, name := range names { + g := bucket.Get([]byte(name)) + if g == nil { + continue + } + var group Group + err := json.Unmarshal(g, &group) + if err != nil { + return err + } + usernames = append(usernames, group.Users...) 
+ } + return nil + }) + return usernames, err +} + +func (p *BoltProvider) groupExists(name string) (Group, error) { + var group Group + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + g := bucket.Get([]byte(name)) + if g == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", name)) + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + group, err = p.joinGroupAndFolders(g, foldersBucket) + return err + }) + return group, err +} + +func (p *BoltProvider) addGroup(group *Group) error { + if err := group.validate(); err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + if u := bucket.Get([]byte(group.Name)); u != nil { + return util.NewI18nError( + fmt.Errorf("%w: group %q already exists", ErrDuplicatedKey, group.Name), + util.I18nErrorDuplicatedUsername, + ) + } + id, err := bucket.NextSequence() + if err != nil { + return err + } + group.ID = int64(id) + group.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + group.Users = nil + group.Admins = nil + for idx := range group.VirtualFolders { + err = p.addRelationToFolderMapping(group.VirtualFolders[idx].Name, nil, group, foldersBucket) + if err != nil { + return err + } + } + buf, err := json.Marshal(group) + if err != nil { + return err + } + return bucket.Put([]byte(group.Name), buf) + }) +} + +func (p *BoltProvider) updateGroup(group *Group) error { + if err := group.validate(); err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + var g []byte + if 
g = bucket.Get([]byte(group.Name)); g == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", group.Name)) + } + var oldGroup Group + err = json.Unmarshal(g, &oldGroup) + if err != nil { + return err + } + for idx := range oldGroup.VirtualFolders { + err = p.removeRelationFromFolderMapping(oldGroup.VirtualFolders[idx], "", oldGroup.Name, foldersBucket) + if err != nil { + return err + } + } + for idx := range group.VirtualFolders { + err = p.addRelationToFolderMapping(group.VirtualFolders[idx].Name, nil, group, foldersBucket) + if err != nil { + return err + } + } + group.ID = oldGroup.ID + group.CreatedAt = oldGroup.CreatedAt + group.Users = oldGroup.Users + group.Admins = oldGroup.Admins + group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(group) + if err != nil { + return err + } + return bucket.Put([]byte(group.Name), buf) + }) +} + +func (p *BoltProvider) deleteGroup(group Group) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + var g []byte + if g = bucket.Get([]byte(group.Name)); g == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", group.Name)) + } + var oldGroup Group + err = json.Unmarshal(g, &oldGroup) + if err != nil { + return err + } + if len(oldGroup.Users) > 0 { + return util.NewValidationError(fmt.Sprintf("the group %q is referenced, it cannot be removed", oldGroup.Name)) + } + if len(oldGroup.VirtualFolders) > 0 { + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + for idx := range oldGroup.VirtualFolders { + err = p.removeRelationFromFolderMapping(oldGroup.VirtualFolders[idx], "", oldGroup.Name, foldersBucket) + if err != nil { + return err + } + } + } + if len(oldGroup.Admins) > 0 { + adminsBucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + for idx := range oldGroup.Admins { + err = 
p.removeGroupFromAdminMapping(oldGroup.Name, oldGroup.Admins[idx], adminsBucket) + if err != nil { + return err + } + } + } + + return bucket.Delete([]byte(group.Name)) + }) +} + +func (p *BoltProvider) dumpGroups() ([]Group, error) { + groups := make([]Group, 0, 50) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + group, err := p.joinGroupAndFolders(v, foldersBucket) + if err != nil { + return err + } + groups = append(groups, group) + } + return err + }) + return groups, err +} + +func (p *BoltProvider) apiKeyExists(keyID string) (APIKey, error) { + var apiKey APIKey + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getAPIKeysBucket(tx) + if err != nil { + return err + } + + k := bucket.Get([]byte(keyID)) + if k == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("API key %v does not exist", keyID)) + } + return json.Unmarshal(k, &apiKey) + }) + return apiKey, err +} + +func (p *BoltProvider) addAPIKey(apiKey *APIKey) error { + err := apiKey.validate() + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getAPIKeysBucket(tx) + if err != nil { + return err + } + if a := bucket.Get([]byte(apiKey.KeyID)); a != nil { + return fmt.Errorf("API key %v already exists", apiKey.KeyID) + } + id, err := bucket.NextSequence() + if err != nil { + return err + } + apiKey.ID = int64(id) + apiKey.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + apiKey.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + apiKey.LastUseAt = 0 + if apiKey.User != "" { + if err := p.userExistsInternal(tx, apiKey.User); err != nil { + return fmt.Errorf("%w: related user %q does not exists", ErrForeignKeyViolated, apiKey.User) + } + } + if apiKey.Admin != "" { + if err := 
p.adminExistsInternal(tx, apiKey.Admin); err != nil { + return fmt.Errorf("%w: related admin %q does not exists", ErrForeignKeyViolated, apiKey.Admin) + } + } + buf, err := json.Marshal(apiKey) + if err != nil { + return err + } + return bucket.Put([]byte(apiKey.KeyID), buf) + }) +} + +func (p *BoltProvider) updateAPIKey(apiKey *APIKey) error { + err := apiKey.validate() + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getAPIKeysBucket(tx) + if err != nil { + return err + } + var a []byte + + if a = bucket.Get([]byte(apiKey.KeyID)); a == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("API key %v does not exist", apiKey.KeyID)) + } + var oldAPIKey APIKey + err = json.Unmarshal(a, &oldAPIKey) + if err != nil { + return err + } + + apiKey.ID = oldAPIKey.ID + apiKey.KeyID = oldAPIKey.KeyID + apiKey.Key = oldAPIKey.Key + apiKey.CreatedAt = oldAPIKey.CreatedAt + apiKey.LastUseAt = oldAPIKey.LastUseAt + apiKey.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + if apiKey.User != "" { + if err := p.userExistsInternal(tx, apiKey.User); err != nil { + return fmt.Errorf("%w: related user %q does not exists", ErrForeignKeyViolated, apiKey.User) + } + } + if apiKey.Admin != "" { + if err := p.adminExistsInternal(tx, apiKey.Admin); err != nil { + return fmt.Errorf("%w: related admin %q does not exists", ErrForeignKeyViolated, apiKey.Admin) + } + } + buf, err := json.Marshal(apiKey) + if err != nil { + return err + } + return bucket.Put([]byte(apiKey.KeyID), buf) + }) +} + +func (p *BoltProvider) deleteAPIKey(apiKey APIKey) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getAPIKeysBucket(tx) + if err != nil { + return err + } + + if bucket.Get([]byte(apiKey.KeyID)) == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("API key %v does not exist", apiKey.KeyID)) + } + + return bucket.Delete([]byte(apiKey.KeyID)) + }) +} + +func (p *BoltProvider) getAPIKeys(limit int, offset int, 
order string) ([]APIKey, error) { + apiKeys := make([]APIKey, 0, limit) + + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getAPIKeysBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + itNum := 0 + if order == OrderASC { + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + itNum++ + if itNum <= offset { + continue + } + var apiKey APIKey + err = json.Unmarshal(v, &apiKey) + if err != nil { + return err + } + apiKey.HideConfidentialData() + apiKeys = append(apiKeys, apiKey) + if len(apiKeys) >= limit { + break + } + } + return nil + } + for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { + itNum++ + if itNum <= offset { + continue + } + var apiKey APIKey + err = json.Unmarshal(v, &apiKey) + if err != nil { + return err + } + apiKey.HideConfidentialData() + apiKeys = append(apiKeys, apiKey) + if len(apiKeys) >= limit { + break + } + } + return nil + }) + + return apiKeys, err +} + +func (p *BoltProvider) dumpAPIKeys() ([]APIKey, error) { + apiKeys := make([]APIKey, 0, 30) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getAPIKeysBucket(tx) + if err != nil { + return err + } + + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var apiKey APIKey + err = json.Unmarshal(v, &apiKey) + if err != nil { + return err + } + apiKeys = append(apiKeys, apiKey) + } + return err + }) + + return apiKeys, err +} + +func (p *BoltProvider) shareExists(shareID, username string) (Share, error) { + var share Share + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getSharesBucket(tx) + if err != nil { + return err + } + + s := bucket.Get([]byte(shareID)) + if s == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", shareID)) + } + if err := json.Unmarshal(s, &share); err != nil { + return err + } + if username != "" && share.Username != username { + return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not 
exist", shareID)) + } + return nil + }) + return share, err +} + +func (p *BoltProvider) addShare(share *Share) error { + err := share.validate() + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getSharesBucket(tx) + if err != nil { + return err + } + if a := bucket.Get([]byte(share.ShareID)); a != nil { + return fmt.Errorf("share %q already exists", share.ShareID) + } + id, err := bucket.NextSequence() + if err != nil { + return err + } + share.ID = int64(id) + if !share.IsRestore { + share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + share.UpdatedAt = share.CreatedAt + share.LastUseAt = 0 + share.UsedTokens = 0 + } + if share.CreatedAt == 0 { + share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + } + if share.UpdatedAt == 0 { + share.UpdatedAt = share.CreatedAt + } + if err := p.userExistsInternal(tx, share.Username); err != nil { + return util.NewValidationError(fmt.Sprintf("related user %q does not exists", share.Username)) + } + buf, err := json.Marshal(share) + if err != nil { + return err + } + return bucket.Put([]byte(share.ShareID), buf) + }) +} + +func (p *BoltProvider) updateShare(share *Share) error { + if err := share.validate(); err != nil { + return err + } + + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getSharesBucket(tx) + if err != nil { + return err + } + var s []byte + + if s = bucket.Get([]byte(share.ShareID)); s == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", share.ShareID)) + } + var oldObject Share + if err = json.Unmarshal(s, &oldObject); err != nil { + return err + } + if oldObject.Username != share.Username { + return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", share.ShareID)) + } + + share.ID = oldObject.ID + share.ShareID = oldObject.ShareID + if !share.IsRestore { + share.UsedTokens = oldObject.UsedTokens + share.CreatedAt = oldObject.CreatedAt + share.LastUseAt = 
oldObject.LastUseAt + share.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + } + if share.CreatedAt == 0 { + share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + } + if share.UpdatedAt == 0 { + share.UpdatedAt = share.CreatedAt + } + if err := p.userExistsInternal(tx, share.Username); err != nil { + return util.NewValidationError(fmt.Sprintf("related user %q does not exists", share.Username)) + } + buf, err := json.Marshal(share) + if err != nil { + return err + } + return bucket.Put([]byte(share.ShareID), buf) + }) +} + +func (p *BoltProvider) deleteShare(share Share) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getSharesBucket(tx) + if err != nil { + return err + } + + var s []byte + + if s = bucket.Get([]byte(share.ShareID)); s == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", share.ShareID)) + } + var oldObject Share + if err = json.Unmarshal(s, &oldObject); err != nil { + return err + } + if oldObject.Username != share.Username { + return util.NewRecordNotFoundError(fmt.Sprintf("Share %v does not exist", share.ShareID)) + } + + return bucket.Delete([]byte(share.ShareID)) + }) +} + +func (p *BoltProvider) getShares(limit int, offset int, order, username string) ([]Share, error) { + shares := make([]Share, 0, limit) + + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getSharesBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + itNum := 0 + if order == OrderASC { + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var share Share + if err := json.Unmarshal(v, &share); err != nil { + return err + } + if share.Username != username { + continue + } + itNum++ + if itNum <= offset { + continue + } + share.HideConfidentialData() + shares = append(shares, share) + if len(shares) >= limit { + break + } + } + return nil + } + for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { + var share Share + err = json.Unmarshal(v, &share) + 
if err != nil { + return err + } + if share.Username != username { + continue + } + itNum++ + if itNum <= offset { + continue + } + share.HideConfidentialData() + shares = append(shares, share) + if len(shares) >= limit { + break + } + } + return nil + }) + + return shares, err +} + +func (p *BoltProvider) dumpShares() ([]Share, error) { + shares := make([]Share, 0, 30) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getSharesBucket(tx) + if err != nil { + return err + } + + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var share Share + err = json.Unmarshal(v, &share) + if err != nil { + return err + } + shares = append(shares, share) + } + return err + }) + + return shares, err +} + +func (p *BoltProvider) updateShareLastUse(shareID string, numTokens int) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getSharesBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(shareID)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("share %q does not exist, unable to update last use", shareID)) + } + var share Share + err = json.Unmarshal(u, &share) + if err != nil { + return err + } + share.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now()) + share.UsedTokens += numTokens + buf, err := json.Marshal(share) + if err != nil { + return err + } + err = bucket.Put([]byte(shareID), buf) + if err != nil { + providerLog(logger.LevelWarn, "error updating last use for share %q: %v", shareID, err) + return err + } + providerLog(logger.LevelDebug, "last use updated for share %q", shareID) + return nil + }) +} + +func (p *BoltProvider) getDefenderHosts(_ int64, _ int) ([]DefenderEntry, error) { + return nil, ErrNotImplemented +} + +func (p *BoltProvider) getDefenderHostByIP(_ string, _ int64) (DefenderEntry, error) { + return DefenderEntry{}, ErrNotImplemented +} + +func (p *BoltProvider) isDefenderHostBanned(_ string) (DefenderEntry, error) { 
+ return DefenderEntry{}, ErrNotImplemented +} + +func (p *BoltProvider) updateDefenderBanTime(_ string, _ int) error { + return ErrNotImplemented +} + +func (p *BoltProvider) deleteDefenderHost(_ string) error { + return ErrNotImplemented +} + +func (p *BoltProvider) addDefenderEvent(_ string, _ int) error { + return ErrNotImplemented +} + +func (p *BoltProvider) setDefenderBanTime(_ string, _ int64) error { + return ErrNotImplemented +} + +func (p *BoltProvider) cleanupDefender(_ int64) error { + return ErrNotImplemented +} + +func (p *BoltProvider) addActiveTransfer(_ ActiveTransfer) error { + return ErrNotImplemented +} + +func (p *BoltProvider) updateActiveTransferSizes(_, _, _ int64, _ string) error { + return ErrNotImplemented +} + +func (p *BoltProvider) removeActiveTransfer(_ int64, _ string) error { + return ErrNotImplemented +} + +func (p *BoltProvider) cleanupActiveTransfers(_ time.Time) error { + return ErrNotImplemented +} + +func (p *BoltProvider) getActiveTransfers(_ time.Time) ([]ActiveTransfer, error) { + return nil, ErrNotImplemented +} + +func (p *BoltProvider) addSharedSession(_ Session) error { + return ErrNotImplemented +} + +func (p *BoltProvider) deleteSharedSession(_ string, _ SessionType) error { + return ErrNotImplemented +} + +func (p *BoltProvider) getSharedSession(_ string, _ SessionType) (Session, error) { + return Session{}, ErrNotImplemented +} + +func (p *BoltProvider) cleanupSharedSessions(_ SessionType, _ int64) error { + return ErrNotImplemented +} + +func (p *BoltProvider) getEventActions(limit, offset int, order string, _ bool) ([]BaseEventAction, error) { + if limit <= 0 { + return nil, nil + } + actions := make([]BaseEventAction, 0, limit) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + itNum := 0 + cursor := bucket.Cursor() + if order == OrderASC { + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + itNum++ + if itNum <= offset { + 
continue + } + var action BaseEventAction + err = json.Unmarshal(v, &action) + if err != nil { + return err + } + action.PrepareForRendering() + actions = append(actions, action) + if len(actions) >= limit { + break + } + } + } else { + for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { + itNum++ + if itNum <= offset { + continue + } + var action BaseEventAction + err = json.Unmarshal(v, &action) + if err != nil { + return err + } + action.PrepareForRendering() + actions = append(actions, action) + if len(actions) >= limit { + break + } + } + } + return nil + }) + return actions, err +} + +func (p *BoltProvider) dumpEventActions() ([]BaseEventAction, error) { + actions := make([]BaseEventAction, 0, 50) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var action BaseEventAction + err = json.Unmarshal(v, &action) + if err != nil { + return err + } + actions = append(actions, action) + } + return nil + }) + return actions, err +} + +func (p *BoltProvider) eventActionExists(name string) (BaseEventAction, error) { + var action BaseEventAction + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + k := bucket.Get([]byte(name)) + if k == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("action %q does not exist", name)) + } + return json.Unmarshal(k, &action) + }) + return action, err +} + +func (p *BoltProvider) addEventAction(action *BaseEventAction) error { + err := action.validate() + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + if a := bucket.Get([]byte(action.Name)); a != nil { + return util.NewI18nError( + fmt.Errorf("%w: event action %q already exists", ErrDuplicatedKey, action.Name), + 
util.I18nErrorDuplicatedName, + ) + } + id, err := bucket.NextSequence() + if err != nil { + return err + } + action.ID = int64(id) + action.Rules = nil + buf, err := json.Marshal(action) + if err != nil { + return err + } + return bucket.Put([]byte(action.Name), buf) + }) +} + +func (p *BoltProvider) updateEventAction(action *BaseEventAction) error { + err := action.validate() + if err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + var a []byte + + if a = bucket.Get([]byte(action.Name)); a == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("event action %s does not exist", action.Name)) + } + var oldAction BaseEventAction + err = json.Unmarshal(a, &oldAction) + if err != nil { + return err + } + action.ID = oldAction.ID + action.Name = oldAction.Name + action.Rules = nil + if len(oldAction.Rules) > 0 { + rulesBucket, err := p.getRulesBucket(tx) + if err != nil { + return err + } + var relatedRules []string + for _, ruleName := range oldAction.Rules { + r := rulesBucket.Get([]byte(ruleName)) + if r != nil { + relatedRules = append(relatedRules, ruleName) + var rule EventRule + err := json.Unmarshal(r, &rule) + if err != nil { + return err + } + rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(rule) + if err != nil { + return err + } + if err = rulesBucket.Put([]byte(rule.Name), buf); err != nil { + return err + } + setLastRuleUpdate() + } + } + action.Rules = relatedRules + } + buf, err := json.Marshal(action) + if err != nil { + return err + } + return bucket.Put([]byte(action.Name), buf) + }) +} + +func (p *BoltProvider) deleteEventAction(action BaseEventAction) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + var a []byte + + if a = bucket.Get([]byte(action.Name)); a == nil { + return 
util.NewRecordNotFoundError(fmt.Sprintf("action %s does not exist", action.Name)) + } + var oldAction BaseEventAction + err = json.Unmarshal(a, &oldAction) + if err != nil { + return err + } + if len(oldAction.Rules) > 0 { + return util.NewValidationError(fmt.Sprintf("action %s is referenced, it cannot be removed", oldAction.Name)) + } + return bucket.Delete([]byte(action.Name)) + }) +} + +func (p *BoltProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) { + if limit <= 0 { + return nil, nil + } + rules := make([]EventRule, 0, limit) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getRulesBucket(tx) + if err != nil { + return err + } + actionsBucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + itNum := 0 + cursor := bucket.Cursor() + if order == OrderASC { + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + itNum++ + if itNum <= offset { + continue + } + var rule EventRule + rule, err = p.joinRuleAndActions(v, actionsBucket) + if err != nil { + return err + } + rule.PrepareForRendering() + rules = append(rules, rule) + if len(rules) >= limit { + break + } + } + } else { + for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { + itNum++ + if itNum <= offset { + continue + } + var rule EventRule + rule, err = p.joinRuleAndActions(v, actionsBucket) + if err != nil { + return err + } + rule.PrepareForRendering() + rules = append(rules, rule) + if len(rules) >= limit { + break + } + } + } + return err + }) + return rules, err +} + +func (p *BoltProvider) dumpEventRules() ([]EventRule, error) { + rules := make([]EventRule, 0, 50) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getRulesBucket(tx) + if err != nil { + return err + } + actionsBucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + rule, err := p.joinRuleAndActions(v, actionsBucket) + if err 
!= nil { + return err + } + rules = append(rules, rule) + } + return nil + }) + return rules, err +} + +func (p *BoltProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) { + if getLastRuleUpdate() < after { + return nil, nil + } + rules := make([]EventRule, 0, 10) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getRulesBucket(tx) + if err != nil { + return err + } + actionsBucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var rule EventRule + err := json.Unmarshal(v, &rule) + if err != nil { + return err + } + if rule.UpdatedAt < after { + continue + } + var actions []EventAction + for idx := range rule.Actions { + action := &rule.Actions[idx] + var baseAction BaseEventAction + k := actionsBucket.Get([]byte(action.Name)) + if k == nil { + continue + } + err = json.Unmarshal(k, &baseAction) + if err != nil { + continue + } + baseAction.Options.SetEmptySecretsIfNil() + action.BaseEventAction = baseAction + actions = append(actions, *action) + } + rule.Actions = actions + rules = append(rules, rule) + } + return nil + }) + return rules, err +} + +func (p *BoltProvider) eventRuleExists(name string) (EventRule, error) { + var rule EventRule + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getRulesBucket(tx) + if err != nil { + return err + } + r := bucket.Get([]byte(name)) + if r == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", name)) + } + actionsBucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + rule, err = p.joinRuleAndActions(r, actionsBucket) + return err + }) + return rule, err +} + +func (p *BoltProvider) addEventRule(rule *EventRule) error { + if err := rule.validate(); err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getRulesBucket(tx) + if err != nil { + return err + } + 
actionsBucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + if r := bucket.Get([]byte(rule.Name)); r != nil { + return util.NewI18nError( + fmt.Errorf("%w: event rule %q already exists", ErrDuplicatedKey, rule.Name), + util.I18nErrorDuplicatedName, + ) + } + id, err := bucket.NextSequence() + if err != nil { + return err + } + rule.ID = int64(id) + rule.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + rule.UpdatedAt = rule.CreatedAt + for idx := range rule.Actions { + if err = p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name, actionsBucket); err != nil { + return err + } + } + sort.Slice(rule.Actions, func(i, j int) bool { + return rule.Actions[i].Order < rule.Actions[j].Order + }) + buf, err := json.Marshal(rule) + if err != nil { + return err + } + err = bucket.Put([]byte(rule.Name), buf) + if err == nil { + setLastRuleUpdate() + } + return err + }) +} + +func (p *BoltProvider) updateEventRule(rule *EventRule) error { + if err := rule.validate(); err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getRulesBucket(tx) + if err != nil { + return err + } + actionsBucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + var r []byte + if r = bucket.Get([]byte(rule.Name)); r == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", rule.Name)) + } + var oldRule EventRule + if err = json.Unmarshal(r, &oldRule); err != nil { + return err + } + for idx := range oldRule.Actions { + if err = p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name, actionsBucket); err != nil { + return err + } + } + for idx := range rule.Actions { + if err = p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name, actionsBucket); err != nil { + return err + } + } + rule.ID = oldRule.ID + rule.CreatedAt = oldRule.CreatedAt + rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(rule) + if err != nil { + return err + 
} + sort.Slice(rule.Actions, func(i, j int) bool { + return rule.Actions[i].Order < rule.Actions[j].Order + }) + err = bucket.Put([]byte(rule.Name), buf) + if err == nil { + setLastRuleUpdate() + } + return err + }) +} + +func (p *BoltProvider) deleteEventRule(rule EventRule, _ bool) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getRulesBucket(tx) + if err != nil { + return err + } + var r []byte + if r = bucket.Get([]byte(rule.Name)); r == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", rule.Name)) + } + var oldRule EventRule + if err = json.Unmarshal(r, &oldRule); err != nil { + return err + } + if len(oldRule.Actions) > 0 { + actionsBucket, err := p.getActionsBucket(tx) + if err != nil { + return err + } + for idx := range oldRule.Actions { + if err = p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name, actionsBucket); err != nil { + return err + } + } + } + return bucket.Delete([]byte(rule.Name)) + }) +} + +func (*BoltProvider) getTaskByName(_ string) (Task, error) { + return Task{}, ErrNotImplemented +} + +func (*BoltProvider) addTask(_ string) error { + return ErrNotImplemented +} + +func (*BoltProvider) updateTask(_ string, _ int64) error { + return ErrNotImplemented +} + +func (*BoltProvider) updateTaskTimestamp(_ string) error { + return ErrNotImplemented +} + +func (*BoltProvider) addNode() error { + return ErrNotImplemented +} + +func (*BoltProvider) getNodeByName(_ string) (Node, error) { + return Node{}, ErrNotImplemented +} + +func (*BoltProvider) getNodes() ([]Node, error) { + return nil, ErrNotImplemented +} + +func (*BoltProvider) updateNodeTimestamp() error { + return ErrNotImplemented +} + +func (*BoltProvider) cleanupNodes() error { + return ErrNotImplemented +} + +func (p *BoltProvider) roleExists(name string) (Role, error) { + var role Role + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getRolesBucket(tx) + if err != nil { + return 
err + } + r := bucket.Get([]byte(name)) + if r == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("role %q does not exist", name)) + } + return json.Unmarshal(r, &role) + }) + return role, err +} + +func (p *BoltProvider) addRole(role *Role) error { + if err := role.validate(); err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + if r := bucket.Get([]byte(role.Name)); r != nil { + return util.NewI18nError( + fmt.Errorf("%w: role %q already exists", ErrDuplicatedKey, role.Name), + util.I18nErrorDuplicatedName, + ) + } + id, err := bucket.NextSequence() + if err != nil { + return err + } + role.ID = int64(id) + role.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + role.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + role.Users = nil + role.Admins = nil + buf, err := json.Marshal(role) + if err != nil { + return err + } + return bucket.Put([]byte(role.Name), buf) + }) +} + +func (p *BoltProvider) updateRole(role *Role) error { + if err := role.validate(); err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + var r []byte + if r = bucket.Get([]byte(role.Name)); r == nil { + return fmt.Errorf("role %q does not exist", role.Name) + } + var oldRole Role + err = json.Unmarshal(r, &oldRole) + if err != nil { + return err + } + role.ID = oldRole.ID + role.CreatedAt = oldRole.CreatedAt + role.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + role.Users = oldRole.Users + role.Admins = oldRole.Admins + buf, err := json.Marshal(role) + if err != nil { + return err + } + return bucket.Put([]byte(role.Name), buf) + }) +} + +func (p *BoltProvider) deleteRole(role Role) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + var r []byte + if r = bucket.Get([]byte(role.Name)); r 
== nil { + return fmt.Errorf("role %q does not exist", role.Name) + } + var oldRole Role + err = json.Unmarshal(r, &oldRole) + if err != nil { + return err + } + if len(oldRole.Admins) > 0 { + return util.NewValidationError(fmt.Sprintf("the role %q is referenced, it cannot be removed", oldRole.Name)) + } + if len(oldRole.Users) > 0 { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + for _, username := range oldRole.Users { + if err := p.removeRoleFromUser(username, oldRole.Name, bucket); err != nil { + return err + } + } + } + + return bucket.Delete([]byte(role.Name)) + }) +} + +func (p *BoltProvider) getRoles(limit int, offset int, order string, _ bool) ([]Role, error) { + roles := make([]Role, 0, limit) + if limit <= 0 { + return roles, nil + } + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + itNum := 0 + if order == OrderASC { + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + itNum++ + if itNum <= offset { + continue + } + var role Role + err = json.Unmarshal(v, &role) + if err != nil { + return err + } + roles = append(roles, role) + if len(roles) >= limit { + break + } + } + } else { + for k, v := cursor.Last(); k != nil; k, v = cursor.Prev() { + itNum++ + if itNum <= offset { + continue + } + var role Role + err = json.Unmarshal(v, &role) + if err != nil { + return err + } + roles = append(roles, role) + if len(roles) >= limit { + break + } + } + } + return nil + }) + return roles, err +} + +func (p *BoltProvider) dumpRoles() ([]Role, error) { + roles := make([]Role, 0, 10) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var role Role + err = json.Unmarshal(v, &role) + if err != nil { + return err + } + roles = append(roles, role) + } + return err + 
}) + return roles, err +} + +func (p *BoltProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { + entry := IPListEntry{ + IPOrNet: ipOrNet, + Type: listType, + } + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getIPListsBucket(tx) + if err != nil { + return err + } + e := bucket.Get([]byte(entry.getKey())) + if e == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("entry %q does not exist", entry.IPOrNet)) + } + err = json.Unmarshal(e, &entry) + if err == nil { + entry.PrepareForRendering() + } + return err + }) + return entry, err +} + +func (p *BoltProvider) addIPListEntry(entry *IPListEntry) error { + if err := entry.validate(); err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getIPListsBucket(tx) + if err != nil { + return err + } + if e := bucket.Get([]byte(entry.getKey())); e != nil { + return util.NewI18nError( + fmt.Errorf("%w: entry %q already exists", ErrDuplicatedKey, entry.IPOrNet), + util.I18nErrorDuplicatedIPNet, + ) + } + entry.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + entry.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(entry) + if err != nil { + return err + } + return bucket.Put([]byte(entry.getKey()), buf) + }) +} + +func (p *BoltProvider) updateIPListEntry(entry *IPListEntry) error { + if err := entry.validate(); err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getIPListsBucket(tx) + if err != nil { + return err + } + var e []byte + if e = bucket.Get([]byte(entry.getKey())); e == nil { + return fmt.Errorf("entry %q does not exist", entry.IPOrNet) + } + var oldEntry IPListEntry + err = json.Unmarshal(e, &oldEntry) + if err != nil { + return err + } + entry.CreatedAt = oldEntry.CreatedAt + entry.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(entry) + if err != nil { + return err + } + return 
bucket.Put([]byte(entry.getKey()), buf) + }) +} + +func (p *BoltProvider) deleteIPListEntry(entry IPListEntry, _ bool) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getIPListsBucket(tx) + if err != nil { + return err + } + if e := bucket.Get([]byte(entry.getKey())); e == nil { + return fmt.Errorf("entry %q does not exist", entry.IPOrNet) + } + return bucket.Delete([]byte(entry.getKey())) + }) +} + +func (p *BoltProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { + entries := make([]IPListEntry, 0, 15) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getIPListsBucket(tx) + if err != nil { + return err + } + prefix := []byte(fmt.Sprintf("%d_", listType)) + acceptKey := func(k []byte) bool { + return k != nil && bytes.HasPrefix(k, prefix) + } + cursor := bucket.Cursor() + if order == OrderASC { + for k, v := cursor.Seek(prefix); acceptKey(k); k, v = cursor.Next() { + var entry IPListEntry + err = json.Unmarshal(v, &entry) + if err != nil { + return err + } + if entry.satisfySearchConstraints(filter, from, order) { + entry.PrepareForRendering() + entries = append(entries, entry) + if limit > 0 && len(entries) >= limit { + break + } + } + } + } else { + for k, v := cursor.Last(); acceptKey(k); k, v = cursor.Prev() { + var entry IPListEntry + err = json.Unmarshal(v, &entry) + if err != nil { + return err + } + if entry.satisfySearchConstraints(filter, from, order) { + entry.PrepareForRendering() + entries = append(entries, entry) + if limit > 0 && len(entries) >= limit { + break + } + } + } + } + return nil + }) + return entries, err +} + +func (p *BoltProvider) getRecentlyUpdatedIPListEntries(_ int64) ([]IPListEntry, error) { + return nil, ErrNotImplemented +} + +func (p *BoltProvider) dumpIPListEntries() ([]IPListEntry, error) { + entries := make([]IPListEntry, 0, 10) + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := 
p.getIPListsBucket(tx) + if err != nil { + return err + } + if count := bucket.Stats().KeyN; count > ipListMemoryLimit { + providerLog(logger.LevelInfo, "IP lists excluded from dump, too many entries: %d", count) + return nil + } + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var entry IPListEntry + err = json.Unmarshal(v, &entry) + if err != nil { + return err + } + entry.PrepareForRendering() + entries = append(entries, entry) + } + return nil + }) + return entries, err +} + +func (p *BoltProvider) countIPListEntries(listType IPListType) (int64, error) { + var count int64 + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getIPListsBucket(tx) + if err != nil { + return err + } + if listType == 0 { + count = int64(bucket.Stats().KeyN) + return nil + } + prefix := []byte(fmt.Sprintf("%d_", listType)) + cursor := bucket.Cursor() + for k, _ := cursor.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = cursor.Next() { + count++ + } + return nil + }) + return count, err +} + +func (p *BoltProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) { + entries := make([]IPListEntry, 0, 3) + ipAddr, err := netip.ParseAddr(ip) + if err != nil { + return entries, fmt.Errorf("invalid ip address %s", ip) + } + var netType int + var ipBytes []byte + if ipAddr.Is4() || ipAddr.Is4In6() { + netType = ipTypeV4 + as4 := ipAddr.As4() + ipBytes = as4[:] + } else { + netType = ipTypeV6 + as16 := ipAddr.As16() + ipBytes = as16[:] + } + err = p.dbHandle.View(func(tx *bolt.Tx) error { + bucket, err := p.getIPListsBucket(tx) + if err != nil { + return err + } + prefix := []byte(fmt.Sprintf("%d_", listType)) + cursor := bucket.Cursor() + for k, v := cursor.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = cursor.Next() { + var entry IPListEntry + err = json.Unmarshal(v, &entry) + if err != nil { + return err + } + if entry.IPType == netType && bytes.Compare(ipBytes, entry.First) 
>= 0 && bytes.Compare(ipBytes, entry.Last) <= 0 { + entry.PrepareForRendering() + entries = append(entries, entry) + } + } + return nil + }) + return entries, err +} + +func (p *BoltProvider) getConfigs() (Configs, error) { + var configs Configs + err := p.dbHandle.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket(configsBucket) + if bucket == nil { + return fmt.Errorf("unable to find configs bucket") + } + data := bucket.Get(configsKey) + if data != nil { + return json.Unmarshal(data, &configs) + } + return nil + }) + return configs, err +} + +func (p *BoltProvider) setConfigs(configs *Configs) error { + if err := configs.validate(); err != nil { + return err + } + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket := tx.Bucket(configsBucket) + if bucket == nil { + return fmt.Errorf("unable to find configs bucket") + } + buf, err := json.Marshal(configs) + if err != nil { + return err + } + return bucket.Put(configsKey, buf) + }) +} + +func (p *BoltProvider) setFirstDownloadTimestamp(username string) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + var u []byte + if u = bucket.Get([]byte(username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to set download timestamp", + username)) + } + var user User + err = json.Unmarshal(u, &user) + if err != nil { + return err + } + if user.FirstDownload > 0 { + return util.NewGenericError(fmt.Sprintf("first download already set to %v", + util.GetTimeFromMsecSinceEpoch(user.FirstDownload))) + } + user.FirstDownload = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(user) + if err != nil { + return err + } + return bucket.Put([]byte(username), buf) + }) +} + +func (p *BoltProvider) setFirstUploadTimestamp(username string) error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + var u 
[]byte + if u = bucket.Get([]byte(username)); u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist, unable to set upload timestamp", + username)) + } + var user User + if err = json.Unmarshal(u, &user); err != nil { + return err + } + if user.FirstUpload > 0 { + return util.NewGenericError(fmt.Sprintf("first upload already set to %v", + util.GetTimeFromMsecSinceEpoch(user.FirstUpload))) + } + user.FirstUpload = util.GetTimeAsMsSinceEpoch(time.Now()) + buf, err := json.Marshal(user) + if err != nil { + return err + } + return bucket.Put([]byte(username), buf) + }) +} + +func (p *BoltProvider) close() error { + return p.dbHandle.Close() +} + +func (p *BoltProvider) reloadConfig() error { + return nil +} + +// initializeDatabase does nothing, no initilization is needed for bolt provider +func (p *BoltProvider) initializeDatabase() error { + return ErrNoInitRequired +} + +func (p *BoltProvider) migrateDatabase() error { + dbVersion, err := getBoltDatabaseVersion(p.dbHandle) + if err != nil { + return err + } + switch version := dbVersion.Version; { + case version == boltDatabaseVersion: + providerLog(logger.LevelDebug, "bolt database is up to date, current version: %d", version) + return ErrNoInitRequired + case version < 33: + err = errSchemaVersionTooOld(version) + providerLog(logger.LevelError, "%v", err) + logger.ErrorToConsole("%v", err) + return err + case version == 33: + logger.InfoToConsole("updating database schema version: %d -> 34", version) + providerLog(logger.LevelInfo, "updating database schema version: %d -> 34", version) + return updateBoltDatabaseVersion(p.dbHandle, 34) + + default: + if version > boltDatabaseVersion { + providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, + boltDatabaseVersion) + logger.WarnToConsole("database schema version %d is newer than the supported one: %d", version, + boltDatabaseVersion) + return nil + } + return fmt.Errorf("database 
schema version not handled: %d", version) + } +} + +func (p *BoltProvider) revertDatabase(targetVersion int) error { + dbVersion, err := getBoltDatabaseVersion(p.dbHandle) + if err != nil { + return err + } + if dbVersion.Version == targetVersion { + return errors.New("current version match target version, nothing to do") + } + switch dbVersion.Version { + case 34: + logger.InfoToConsole("downgrading database schema version: %d -> 33", dbVersion.Version) + providerLog(logger.LevelInfo, "downgrading database schema version: %d -> 33", dbVersion.Version) + return updateBoltDatabaseVersion(p.dbHandle, 33) + + default: + return fmt.Errorf("database schema version not handled: %v", dbVersion.Version) + } +} + +func (p *BoltProvider) resetDatabase() error { + return p.dbHandle.Update(func(tx *bolt.Tx) error { + for _, bucketName := range boltBuckets { + err := tx.DeleteBucket(bucketName) + if err != nil && !errors.Is(err, bolterrors.ErrBucketNotFound) { + return fmt.Errorf("unable to remove bucket %v: %w", bucketName, err) + } + } + return nil + }) +} + +func (p *BoltProvider) joinRuleAndActions(r []byte, actionsBucket *bolt.Bucket) (EventRule, error) { + var rule EventRule + err := json.Unmarshal(r, &rule) + if err != nil { + return rule, err + } + var actions []EventAction + for idx := range rule.Actions { + action := &rule.Actions[idx] + var baseAction BaseEventAction + k := actionsBucket.Get([]byte(action.Name)) + if k == nil { + continue + } + err = json.Unmarshal(k, &baseAction) + if err != nil { + continue + } + baseAction.Options.SetEmptySecretsIfNil() + action.BaseEventAction = baseAction + actions = append(actions, *action) + } + rule.Actions = actions + return rule, nil +} + +func (p *BoltProvider) joinGroupAndFolders(g []byte, foldersBucket *bolt.Bucket) (Group, error) { + var group Group + err := json.Unmarshal(g, &group) + if err != nil { + return group, err + } + if len(group.VirtualFolders) > 0 { + var folders []vfs.VirtualFolder + for idx := range 
group.VirtualFolders { + folder := &group.VirtualFolders[idx] + baseFolder, err := p.folderExistsInternal(folder.Name, foldersBucket) + if err != nil { + continue + } + folder.BaseVirtualFolder = baseFolder + folders = append(folders, *folder) + } + group.VirtualFolders = folders + } + group.SetEmptySecretsIfNil() + return group, err +} + +func (p *BoltProvider) joinUserAndFolders(u []byte, foldersBucket *bolt.Bucket) (User, error) { + var user User + err := json.Unmarshal(u, &user) + if err != nil { + return user, err + } + if len(user.VirtualFolders) > 0 { + var folders []vfs.VirtualFolder + for idx := range user.VirtualFolders { + folder := &user.VirtualFolders[idx] + baseFolder, err := p.folderExistsInternal(folder.Name, foldersBucket) + if err != nil { + continue + } + folder.BaseVirtualFolder = baseFolder + folders = append(folders, *folder) + } + user.VirtualFolders = folders + } + user.SetEmptySecretsIfNil() + return user, err +} + +func (p *BoltProvider) groupExistsInternal(name string, bucket *bolt.Bucket) (Group, error) { + var group Group + g := bucket.Get([]byte(name)) + if g == nil { + err := util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", name)) + return group, err + } + err := json.Unmarshal(g, &group) + return group, err +} + +func (p *BoltProvider) folderExistsInternal(name string, bucket *bolt.Bucket) (vfs.BaseVirtualFolder, error) { + var folder vfs.BaseVirtualFolder + f := bucket.Get([]byte(name)) + if f == nil { + err := util.NewRecordNotFoundError(fmt.Sprintf("folder %q does not exist", name)) + return folder, err + } + err := json.Unmarshal(f, &folder) + return folder, err +} + +func (p *BoltProvider) addFolderInternal(folder vfs.BaseVirtualFolder, bucket *bolt.Bucket) error { + id, err := bucket.NextSequence() + if err != nil { + return err + } + folder.ID = int64(id) + buf, err := json.Marshal(folder) + if err != nil { + return err + } + return bucket.Put([]byte(folder.Name), buf) +} + +func (p *BoltProvider) 
removeRoleFromUser(username, role string, bucket *bolt.Bucket) error { + u := bucket.Get([]byte(username)) + if u == nil { + providerLog(logger.LevelWarn, "user %q does not exist, cannot remove role %q", username, role) + return nil + } + var user User + err := json.Unmarshal(u, &user) + if err != nil { + return err + } + if user.Role == role { + user.Role = "" + buf, err := json.Marshal(user) + if err != nil { + return err + } + return bucket.Put([]byte(user.Username), buf) + } + providerLog(logger.LevelError, "user %q does not have the expected role %q, actual %q", username, role, user.Role) + return nil +} + +func (p *BoltProvider) addAdminToRole(username, roleName string, bucket *bolt.Bucket) error { + if roleName == "" { + return nil + } + r := bucket.Get([]byte(roleName)) + if r == nil { + return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, roleName) + } + var role Role + err := json.Unmarshal(r, &role) + if err != nil { + return err + } + if !slices.Contains(role.Admins, username) { + role.Admins = append(role.Admins, username) + buf, err := json.Marshal(role) + if err != nil { + return err + } + return bucket.Put([]byte(role.Name), buf) + } + return nil +} + +func (p *BoltProvider) removeAdminFromRole(username, roleName string, bucket *bolt.Bucket) error { + if roleName == "" { + return nil + } + r := bucket.Get([]byte(roleName)) + if r == nil { + providerLog(logger.LevelWarn, "role %q does not exist, cannot remove admin %q", roleName, username) + return nil + } + var role Role + err := json.Unmarshal(r, &role) + if err != nil { + return err + } + if slices.Contains(role.Admins, username) { + var admins []string + for _, admin := range role.Admins { + if admin != username { + admins = append(admins, admin) + } + } + role.Admins = util.RemoveDuplicates(admins, false) + buf, err := json.Marshal(role) + if err != nil { + return err + } + return bucket.Put([]byte(role.Name), buf) + } + return nil +} + +func (p *BoltProvider) 
addUserToRole(username, roleName string, bucket *bolt.Bucket) error { + if roleName == "" { + return nil + } + r := bucket.Get([]byte(roleName)) + if r == nil { + return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, roleName) + } + var role Role + err := json.Unmarshal(r, &role) + if err != nil { + return err + } + if !slices.Contains(role.Users, username) { + role.Users = append(role.Users, username) + buf, err := json.Marshal(role) + if err != nil { + return err + } + return bucket.Put([]byte(role.Name), buf) + } + return nil +} + +func (p *BoltProvider) removeUserFromRole(username, roleName string, bucket *bolt.Bucket) error { + if roleName == "" { + return nil + } + r := bucket.Get([]byte(roleName)) + if r == nil { + providerLog(logger.LevelWarn, "role %q does not exist, cannot remove user %q", roleName, username) /* fix: message said "admin" (copy-paste from removeAdminFromRole); this path removes a user */ + return nil + } + var role Role + err := json.Unmarshal(r, &role) + if err != nil { + return err + } + if slices.Contains(role.Users, username) { + var users []string + for _, user := range role.Users { + if user != username { + users = append(users, user) + } + } + users = util.RemoveDuplicates(users, false) + role.Users = users + buf, err := json.Marshal(role) + if err != nil { + return err + } + return bucket.Put([]byte(role.Name), buf) + } + return nil +} + +func (p *BoltProvider) addRuleToActionMapping(ruleName, actionName string, bucket *bolt.Bucket) error { + a := bucket.Get([]byte(actionName)) + if a == nil { + return util.NewGenericError(fmt.Sprintf("action %q does not exist", actionName)) + } + var action BaseEventAction + err := json.Unmarshal(a, &action) + if err != nil { + return err + } + if !slices.Contains(action.Rules, ruleName) { + action.Rules = append(action.Rules, ruleName) + buf, err := json.Marshal(action) + if err != nil { + return err + } + return bucket.Put([]byte(action.Name), buf) + } + return nil +} + +func (p *BoltProvider) removeRuleFromActionMapping(ruleName, actionName string, bucket *bolt.Bucket)
error { + a := bucket.Get([]byte(actionName)) + if a == nil { + providerLog(logger.LevelWarn, "action %q does not exist, cannot remove from mapping", actionName) + return nil + } + var action BaseEventAction + err := json.Unmarshal(a, &action) + if err != nil { + return err + } + if slices.Contains(action.Rules, ruleName) { + var rules []string + for _, r := range action.Rules { + if r != ruleName { + rules = append(rules, r) + } + } + action.Rules = util.RemoveDuplicates(rules, false) + buf, err := json.Marshal(action) + if err != nil { + return err + } + return bucket.Put([]byte(action.Name), buf) + } + return nil +} + +func (p *BoltProvider) addUserToGroupMapping(username, groupname string, bucket *bolt.Bucket) error { + g := bucket.Get([]byte(groupname)) + if g == nil { + return util.NewGenericError(fmt.Sprintf("group %q does not exist", groupname)) + } + var group Group + err := json.Unmarshal(g, &group) + if err != nil { + return err + } + if !slices.Contains(group.Users, username) { + group.Users = append(group.Users, username) + buf, err := json.Marshal(group) + if err != nil { + return err + } + return bucket.Put([]byte(group.Name), buf) + } + return nil +} + +func (p *BoltProvider) removeUserFromGroupMapping(username, groupname string, bucket *bolt.Bucket) error { + g := bucket.Get([]byte(groupname)) + if g == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", groupname)) + } + var group Group + err := json.Unmarshal(g, &group) + if err != nil { + return err + } + var users []string + for _, u := range group.Users { + if u != username { + users = append(users, u) + } + } + group.Users = util.RemoveDuplicates(users, false) + buf, err := json.Marshal(group) + if err != nil { + return err + } + return bucket.Put([]byte(group.Name), buf) +} + +func (p *BoltProvider) addAdminToGroupMapping(username, groupname string, bucket *bolt.Bucket) error { + g := bucket.Get([]byte(groupname)) + if g == nil { + return 
util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", groupname)) + } + var group Group + err := json.Unmarshal(g, &group) + if err != nil { + return err + } + if !slices.Contains(group.Admins, username) { + group.Admins = append(group.Admins, username) + buf, err := json.Marshal(group) + if err != nil { + return err + } + return bucket.Put([]byte(group.Name), buf) + } + return nil +} + +func (p *BoltProvider) removeAdminFromGroupMapping(username, groupname string, bucket *bolt.Bucket) error { + g := bucket.Get([]byte(groupname)) + if g == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", groupname)) + } + var group Group + err := json.Unmarshal(g, &group) + if err != nil { + return err + } + var admins []string + for _, a := range group.Admins { + if a != username { + admins = append(admins, a) + } + } + group.Admins = util.RemoveDuplicates(admins, false) + buf, err := json.Marshal(group) + if err != nil { + return err + } + return bucket.Put([]byte(group.Name), buf) +} + +func (p *BoltProvider) removeGroupFromAdminMapping(groupName, adminName string, bucket *bolt.Bucket) error { + var a []byte + if a = bucket.Get([]byte(adminName)); a == nil { + // the admin does not exist so there is no associated group + return nil + } + var admin Admin + err := json.Unmarshal(a, &admin) + if err != nil { + return err + } + var newGroups []AdminGroupMapping + for _, g := range admin.Groups { + if g.Name != groupName { + newGroups = append(newGroups, g) + } + } + admin.Groups = newGroups + buf, err := json.Marshal(admin) + if err != nil { + return err + } + return bucket.Put([]byte(adminName), buf) +} + +func (p *BoltProvider) addRelationToFolderMapping(folderName string, user *User, group *Group, bucket *bolt.Bucket) error { + f := bucket.Get([]byte(folderName)) + if f == nil { + return util.NewGenericError(fmt.Sprintf("folder %q does not exist", folderName)) + } + var folder vfs.BaseVirtualFolder + err := json.Unmarshal(f, &folder) 
+ if err != nil { + return err + } + updated := false + if user != nil && !slices.Contains(folder.Users, user.Username) { + folder.Users = append(folder.Users, user.Username) + updated = true + } + if group != nil && !slices.Contains(folder.Groups, group.Name) { + folder.Groups = append(folder.Groups, group.Name) + updated = true + } + if !updated { + return nil + } + buf, err := json.Marshal(folder) + if err != nil { + return err + } + return bucket.Put([]byte(folder.Name), buf) +} + +func (p *BoltProvider) removeRelationFromFolderMapping(folder vfs.VirtualFolder, username, groupname string, + bucket *bolt.Bucket, +) error { + var f []byte + if f = bucket.Get([]byte(folder.Name)); f == nil { + // the folder does not exist so there is no associated user/group + return nil + } + var baseFolder vfs.BaseVirtualFolder + err := json.Unmarshal(f, &baseFolder) + if err != nil { + return err + } + found := false + if username != "" { + found = true + var newUserMapping []string + for _, u := range baseFolder.Users { + if u != username { + newUserMapping = append(newUserMapping, u) + } + } + baseFolder.Users = newUserMapping + } + if groupname != "" { + found = true + var newGroupMapping []string + for _, g := range baseFolder.Groups { + if g != groupname { + newGroupMapping = append(newGroupMapping, g) + } + } + baseFolder.Groups = newGroupMapping + } + if !found { + return nil + } + buf, err := json.Marshal(baseFolder) + if err != nil { + return err + } + return bucket.Put([]byte(folder.Name), buf) +} + +func (p *BoltProvider) updateUserRelations(tx *bolt.Tx, user *User, oldUser User) error { + foldersBucket, err := p.getFoldersBucket(tx) + if err != nil { + return err + } + groupsBucket, err := p.getGroupsBucket(tx) + if err != nil { + return err + } + rolesBucket, err := p.getRolesBucket(tx) + if err != nil { + return err + } + for idx := range oldUser.VirtualFolders { + err = p.removeRelationFromFolderMapping(oldUser.VirtualFolders[idx], oldUser.Username, "", 
foldersBucket) + if err != nil { + return err + } + } + for idx := range oldUser.Groups { + err = p.removeUserFromGroupMapping(user.Username, oldUser.Groups[idx].Name, groupsBucket) + if err != nil { + return err + } + } + if err = p.removeUserFromRole(oldUser.Username, oldUser.Role, rolesBucket); err != nil { + return err + } + for idx := range user.VirtualFolders { + err = p.addRelationToFolderMapping(user.VirtualFolders[idx].Name, user, nil, foldersBucket) + if err != nil { + return err + } + } + for idx := range user.Groups { + err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name, groupsBucket) + if err != nil { + return err + } + } + return p.addUserToRole(user.Username, user.Role, rolesBucket) +} + +func (p *BoltProvider) adminExistsInternal(tx *bolt.Tx, username string) error { + bucket, err := p.getAdminsBucket(tx) + if err != nil { + return err + } + a := bucket.Get([]byte(username)) + if a == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("admin %v does not exist", username)) + } + return nil +} + +func (p *BoltProvider) userExistsInternal(tx *bolt.Tx, username string) error { + bucket, err := p.getUsersBucket(tx) + if err != nil { + return err + } + u := bucket.Get([]byte(username)) + if u == nil { + return util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) + } + return nil +} + +func (p *BoltProvider) deleteRelatedShares(tx *bolt.Tx, username string) error { + bucket, err := p.getSharesBucket(tx) + if err != nil { + return err + } + var toRemove []string + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var share Share + err = json.Unmarshal(v, &share) + if err != nil { + return err + } + if share.Username == username { + toRemove = append(toRemove, share.ShareID) + } + } + + for _, k := range toRemove { + if err := bucket.Delete([]byte(k)); err != nil { + return err + } + } + + return nil +} + +func (p *BoltProvider) deleteRelatedAPIKey(tx *bolt.Tx, username 
string, scope APIKeyScope) error { + bucket, err := p.getAPIKeysBucket(tx) + if err != nil { + return err + } + var toRemove []string + cursor := bucket.Cursor() + for k, v := cursor.First(); k != nil; k, v = cursor.Next() { + var apiKey APIKey + err = json.Unmarshal(v, &apiKey) + if err != nil { + return err + } + if scope == APIKeyScopeUser { + if apiKey.User == username { + toRemove = append(toRemove, apiKey.KeyID) + } + } else { + if apiKey.Admin == username { + toRemove = append(toRemove, apiKey.KeyID) + } + } + } + + for _, k := range toRemove { + if err := bucket.Delete([]byte(k)); err != nil { + return err + } + } + + return nil +} + +func (p *BoltProvider) getSharesBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + + bucket := tx.Bucket(sharesBucket) + if bucket == nil { + err = errors.New("unable to find shares bucket, bolt database structure not correcly defined") + } + return bucket, err +} + +func (p *BoltProvider) getAPIKeysBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + + bucket := tx.Bucket(apiKeysBucket) + if bucket == nil { + err = errors.New("unable to find api keys bucket, bolt database structure not correcly defined") + } + return bucket, err +} + +func (p *BoltProvider) getAdminsBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + + bucket := tx.Bucket(adminsBucket) + if bucket == nil { + err = errors.New("unable to find admins bucket, bolt database structure not correcly defined") + } + return bucket, err +} + +func (p *BoltProvider) getUsersBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + bucket := tx.Bucket(usersBucket) + if bucket == nil { + err = errors.New("unable to find users bucket, bolt database structure not correcly defined") + } + return bucket, err +} + +func (p *BoltProvider) getGroupsBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + bucket := tx.Bucket(groupsBucket) + if bucket == nil { + err = fmt.Errorf("unable to find groups bucket, bolt database structure not 
correcly defined") + } + return bucket, err +} + +func (p *BoltProvider) getRolesBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + bucket := tx.Bucket(rolesBucket) + if bucket == nil { + err = fmt.Errorf("unable to find roles bucket, bolt database structure not correcly defined") + } + return bucket, err +} + +func (p *BoltProvider) getIPListsBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + bucket := tx.Bucket(ipListsBucket) /* fix: was tx.Bucket(rolesBucket), a copy-paste error from getRolesBucket; IP list entries live in their own bucket, as the error message below already states */ + if bucket == nil { + err = fmt.Errorf("unable to find IP lists bucket, bolt database structure not correcly defined") + } + return bucket, err +} + +func (p *BoltProvider) getFoldersBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + bucket := tx.Bucket(foldersBucket) + if bucket == nil { + err = fmt.Errorf("unable to find folders bucket, bolt database structure not correcly defined") + } + return bucket, err +} + +func (p *BoltProvider) getActionsBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + bucket := tx.Bucket(actionsBucket) + if bucket == nil { + err = fmt.Errorf("unable to find event actions bucket, bolt database structure not correcly defined") + } + return bucket, err +} + +func (p *BoltProvider) getRulesBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + var err error + bucket := tx.Bucket(rulesBucket) + if bucket == nil { + err = fmt.Errorf("unable to find event rules bucket, bolt database structure not correcly defined") + } + return bucket, err +} + +func getBoltDatabaseVersion(dbHandle *bolt.DB) (schemaVersion, error) { + var dbVersion schemaVersion + err := dbHandle.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket(dbVersionBucket) + if bucket == nil { + return fmt.Errorf("unable to find database schema version bucket") + } + v := bucket.Get(dbVersionKey) + if v == nil { + dbVersion = schemaVersion{ + Version: 33, + } + return nil + } + return json.Unmarshal(v, &dbVersion) + }) + return dbVersion, err +} + +func updateBoltDatabaseVersion(dbHandle *bolt.DB, version int) error { + err :=
dbHandle.Update(func(tx *bolt.Tx) error { + bucket := tx.Bucket(dbVersionBucket) + if bucket == nil { + return fmt.Errorf("unable to find database schema version bucket") + } + newDbVersion := schemaVersion{ + Version: version, + } + buf, err := json.Marshal(newDbVersion) + if err != nil { + return err + } + return bucket.Put(dbVersionKey, buf) + }) + return err +} diff --git a/internal/dataprovider/bolt_disabled.go b/internal/dataprovider/bolt_disabled.go new file mode 100644 index 00000000..0ec5030a --- /dev/null +++ b/internal/dataprovider/bolt_disabled.go @@ -0,0 +1,31 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build nobolt + +package dataprovider + +import ( + "errors" + + "github.com/drakkan/sftpgo/v2/internal/version" +) + +func init() { + version.AddFeature("-bolt") +} + +func initializeBoltProvider(_ string) error { + return errors.New("bolt disabled at build time") +} diff --git a/internal/dataprovider/cachedpassword.go b/internal/dataprovider/cachedpassword.go new file mode 100644 index 00000000..7ba6c051 --- /dev/null +++ b/internal/dataprovider/cachedpassword.go @@ -0,0 +1,169 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package dataprovider + +import ( + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + cachedUserPasswords credentialsCache + cachedAdminPasswords credentialsCache + cachedAPIKeys credentialsCache +) + +func init() { + cachedUserPasswords = credentialsCache{ + name: "users", + sizeLimit: 500, + cache: make(map[string]credentialObject), + } + cachedAdminPasswords = credentialsCache{ + name: "admins", + sizeLimit: 100, + cache: make(map[string]credentialObject), + } + cachedAPIKeys = credentialsCache{ + name: "API keys", + sizeLimit: 500, + cache: make(map[string]credentialObject), + } +} + +// CheckCachedUserPassword is an utility method used only in test cases +func CheckCachedUserPassword(username, password, hash string) (bool, bool) { + return cachedUserPasswords.Check(username, password, hash) +} + +type credentialObject struct { + key string + hash string + password string + usedAt *atomic.Int64 +} + +type credentialsCache struct { + name string + sizeLimit int + sync.RWMutex + cache map[string]credentialObject +} + +func (c *credentialsCache) Add(username, password, hash string) { + if !config.PasswordCaching || username == "" || password == "" || hash == "" { + return + } + + c.Lock() + defer c.Unlock() + + obj := credentialObject{ + key: username, + hash: hash, + password: password, + usedAt: &atomic.Int64{}, + } + obj.usedAt.Store(util.GetTimeAsMsSinceEpoch(time.Now())) + + c.cache[username] = obj +} + +func (c *credentialsCache) Remove(username string) { + if 
!config.PasswordCaching { + return + } + + c.Lock() + defer c.Unlock() + + delete(c.cache, username) +} + +// Check returns if the username is found and if the password match +func (c *credentialsCache) Check(username, password, hash string) (bool, bool) { + if username == "" || password == "" || hash == "" { + return false, false + } + + c.RLock() + defer c.RUnlock() + + creds, ok := c.cache[username] + if !ok { + return false, false + } + if creds.hash != hash { + creds.usedAt.Store(0) + return false, false + } + match := creds.password == password + if match { + creds.usedAt.Store(util.GetTimeAsMsSinceEpoch(time.Now())) + } + return true, match +} + +func (c *credentialsCache) count() int { + c.RLock() + defer c.RUnlock() + + return len(c.cache) +} + +func (c *credentialsCache) cleanup() { + if !config.PasswordCaching { + return + } + if c.count() <= c.sizeLimit { + return + } + + c.Lock() + defer c.Unlock() + + for k, v := range c.cache { + if v.usedAt.Load() < util.GetTimeAsMsSinceEpoch(time.Now().Add(-60*time.Minute)) { + delete(c.cache, k) + } + } + providerLog(logger.LevelDebug, "size for credentials %q after cleanup: %d", c.name, len(c.cache)) + + if len(c.cache) < c.sizeLimit*5 { + return + } + numToRemove := len(c.cache) - c.sizeLimit + providerLog(logger.LevelDebug, "additional item to remove from credentials %q: %d", c.name, numToRemove) + credentials := make([]credentialObject, 0, len(c.cache)) + for _, v := range c.cache { + credentials = append(credentials, v) + } + sort.Slice(credentials, func(i, j int) bool { + return credentials[i].usedAt.Load() < credentials[j].usedAt.Load() + }) + + for idx := range credentials { + if idx >= numToRemove { + break + } + delete(c.cache, credentials[idx].key) + } + providerLog(logger.LevelDebug, "size for credentials %q after additional cleanup: %d", c.name, len(c.cache)) +} diff --git a/internal/dataprovider/cacheduser.go b/internal/dataprovider/cacheduser.go new file mode 100644 index 00000000..9c5ff15e --- 
/dev/null +++ b/internal/dataprovider/cacheduser.go @@ -0,0 +1,180 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package dataprovider + +import ( + "sync" + "time" + + "github.com/drakkan/webdav" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + webDAVUsersCache *usersCache +) + +func init() { + webDAVUsersCache = &usersCache{ + users: map[string]CachedUser{}, + } +} + +// InitializeWebDAVUserCache initializes the cache for webdav users +func InitializeWebDAVUserCache(maxSize int) { + webDAVUsersCache = &usersCache{ + users: map[string]CachedUser{}, + maxSize: maxSize, + } +} + +// CachedUser adds fields useful for caching to a SFTPGo user +type CachedUser struct { + User User + Expiration time.Time + Password string + LockSystem webdav.LockSystem +} + +// IsExpired returns true if the cached user is expired +func (c *CachedUser) IsExpired() bool { + if c.Expiration.IsZero() { + return false + } + return c.Expiration.Before(time.Now()) +} + +type usersCache struct { + sync.RWMutex + users map[string]CachedUser + maxSize int +} + +func (cache *usersCache) updateLastLogin(username string) { + cache.Lock() + defer cache.Unlock() + + if cachedUser, ok := cache.users[username]; ok { + cachedUser.User.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) + cache.users[username] = cachedUser + } +} + +// swapWebDAVUser updates an existing 
cached user with the specified one +// preserving the lock fs if possible +// FIXME: this could be racy in rare cases +func (cache *usersCache) swap(userRef *User, plainPassword string) { + user := userRef.getACopy() + err := user.LoadAndApplyGroupSettings() + + cache.Lock() + defer cache.Unlock() + + if cachedUser, ok := cache.users[user.Username]; ok { + if err != nil { + providerLog(logger.LevelDebug, "unable to load group settings, for user %q, removing from cache, err :%v", + user.Username, err) + delete(cache.users, user.Username) + return + } + if plainPassword != "" { + cachedUser.Password = plainPassword + } else { + if cachedUser.User.Password != user.Password { + providerLog(logger.LevelDebug, "current password different from the cached one for user %q, removing from cache", + user.Username) + // the password changed, the cached user is no longer valid + delete(cache.users, user.Username) + return + } + } + if cachedUser.User.isFsEqual(&user) { + // the updated user has the same fs as the cached one, we can preserve the lock filesystem + providerLog(logger.LevelDebug, "current password and fs unchanged for for user %q, swap cached one", + user.Username) + cachedUser.User = user + cache.users[user.Username] = cachedUser + } else { + // filesystem changed, the cached user is no longer valid + providerLog(logger.LevelDebug, "current fs different from the cached one for user %q, removing from cache", + user.Username) + delete(cache.users, user.Username) + } + } +} + +func (cache *usersCache) add(cachedUser *CachedUser) { + cache.Lock() + defer cache.Unlock() + + if cache.maxSize > 0 && len(cache.users) >= cache.maxSize { + var userToRemove string + var expirationTime time.Time + + for k, v := range cache.users { + if userToRemove == "" { + userToRemove = k + expirationTime = v.Expiration + continue + } + expireTime := v.Expiration + if !expireTime.IsZero() && expireTime.Before(expirationTime) { + userToRemove = k + expirationTime = expireTime + } + } + + 
delete(cache.users, userToRemove) + } + + if cachedUser.User.Username != "" { + cache.users[cachedUser.User.Username] = *cachedUser + } +} + +func (cache *usersCache) remove(username string) { + cache.Lock() + defer cache.Unlock() + + delete(cache.users, username) +} + +func (cache *usersCache) get(username string) (*CachedUser, bool) { + cache.RLock() + defer cache.RUnlock() + + cachedUser, ok := cache.users[username] + if !ok { + return nil, false + } + return &cachedUser, true +} + +// CacheWebDAVUser add a user to the WebDAV cache +func CacheWebDAVUser(cachedUser *CachedUser) { + webDAVUsersCache.add(cachedUser) +} + +// GetCachedWebDAVUser returns a previously cached WebDAV user +func GetCachedWebDAVUser(username string) (*CachedUser, bool) { + return webDAVUsersCache.get(username) +} + +// RemoveCachedWebDAVUser removes a cached WebDAV user +func RemoveCachedWebDAVUser(username string) { + webDAVUsersCache.remove(username) +} diff --git a/internal/dataprovider/configs.go b/internal/dataprovider/configs.go new file mode 100644 index 00000000..22a9c5a4 --- /dev/null +++ b/internal/dataprovider/configs.go @@ -0,0 +1,652 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package dataprovider + +import ( + "bytes" + "encoding/json" + "fmt" + "image/png" + "net/url" + "slices" + + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +// Supported values for host keys, KEXs, ciphers, MACs +var ( + supportedHostKeyAlgos = []string{ssh.KeyAlgoRSA} + supportedPublicKeyAlgos = []string{ssh.KeyAlgoRSA, ssh.InsecureKeyAlgoDSA} //nolint:staticcheck + supportedKexAlgos = []string{ + ssh.KeyExchangeDH16SHA512, ssh.InsecureKeyExchangeDH14SHA1, ssh.InsecureKeyExchangeDH1SHA1, + ssh.InsecureKeyExchangeDHGEXSHA1, + } + supportedCiphers = []string{ + ssh.InsecureCipherAES128CBC, + ssh.InsecureCipherTripleDESCBC, + } + supportedMACs = []string{ + ssh.HMACSHA512ETM, ssh.HMACSHA512, + ssh.HMACSHA1, ssh.InsecureHMACSHA196, + } +) + +// SFTPDConfigs defines configurations for SFTPD +type SFTPDConfigs struct { + HostKeyAlgos []string `json:"host_key_algos,omitempty"` + PublicKeyAlgos []string `json:"public_key_algos,omitempty"` + KexAlgorithms []string `json:"kex_algorithms,omitempty"` + Ciphers []string `json:"ciphers,omitempty"` + MACs []string `json:"macs,omitempty"` +} + +func (c *SFTPDConfigs) isEmpty() bool { + if len(c.HostKeyAlgos) > 0 { + return false + } + if len(c.PublicKeyAlgos) > 0 { + return false + } + if len(c.KexAlgorithms) > 0 { + return false + } + if len(c.Ciphers) > 0 { + return false + } + if len(c.MACs) > 0 { + return false + } + return true +} + +// GetSupportedHostKeyAlgos returns the supported legacy host key algos +func (*SFTPDConfigs) GetSupportedHostKeyAlgos() []string { + return supportedHostKeyAlgos +} + +// GetSupportedPublicKeyAlgos returns the supported legacy public key algos +func (*SFTPDConfigs) GetSupportedPublicKeyAlgos() []string { + return supportedPublicKeyAlgos +} + +// GetSupportedKEXAlgos returns the supported KEX algos +func (*SFTPDConfigs) GetSupportedKEXAlgos() []string { + return 
supportedKexAlgos +} + +// GetSupportedCiphers returns the supported ciphers +func (*SFTPDConfigs) GetSupportedCiphers() []string { + return supportedCiphers +} + +// GetSupportedMACs returns the supported MACs algos +func (*SFTPDConfigs) GetSupportedMACs() []string { + return supportedMACs +} + +func (c *SFTPDConfigs) validate() error { + var hostKeyAlgos []string + for _, algo := range c.HostKeyAlgos { + if algo == ssh.CertAlgoRSAv01 { + continue + } + if !slices.Contains(supportedHostKeyAlgos, algo) { + return util.NewValidationError(fmt.Sprintf("unsupported host key algorithm %q", algo)) + } + hostKeyAlgos = append(hostKeyAlgos, algo) + } + c.HostKeyAlgos = hostKeyAlgos + var kexAlgos []string + for _, algo := range c.KexAlgorithms { + if algo == "diffie-hellman-group18-sha512" || algo == ssh.KeyExchangeDHGEXSHA256 { + continue + } + if !slices.Contains(supportedKexAlgos, algo) { + return util.NewValidationError(fmt.Sprintf("unsupported KEX algorithm %q", algo)) + } + kexAlgos = append(kexAlgos, algo) + } + c.KexAlgorithms = kexAlgos + for _, cipher := range c.Ciphers { + if slices.Contains([]string{"aes192-cbc", "aes256-cbc"}, cipher) { + continue + } + if !slices.Contains(supportedCiphers, cipher) { + return util.NewValidationError(fmt.Sprintf("unsupported cipher %q", cipher)) + } + } + for _, mac := range c.MACs { + if !slices.Contains(supportedMACs, mac) { + return util.NewValidationError(fmt.Sprintf("unsupported MAC algorithm %q", mac)) + } + } + for _, algo := range c.PublicKeyAlgos { + if !slices.Contains(supportedPublicKeyAlgos, algo) { + return util.NewValidationError(fmt.Sprintf("unsupported public key algorithm %q", algo)) + } + } + return nil +} + +func (c *SFTPDConfigs) getACopy() *SFTPDConfigs { + hostKeys := make([]string, len(c.HostKeyAlgos)) + copy(hostKeys, c.HostKeyAlgos) + publicKeys := make([]string, len(c.PublicKeyAlgos)) + copy(publicKeys, c.PublicKeyAlgos) + kexs := make([]string, len(c.KexAlgorithms)) + copy(kexs, c.KexAlgorithms) + 
ciphers := make([]string, len(c.Ciphers)) + copy(ciphers, c.Ciphers) + macs := make([]string, len(c.MACs)) + copy(macs, c.MACs) + + return &SFTPDConfigs{ + HostKeyAlgos: hostKeys, + PublicKeyAlgos: publicKeys, + KexAlgorithms: kexs, + Ciphers: ciphers, + MACs: macs, + } +} + +func validateSMTPSecret(secret *kms.Secret, name string) error { + if secret.IsRedacted() { + return util.NewValidationError(fmt.Sprintf("cannot save a redacted smtp %s", name)) + } + if secret.IsEncrypted() && !secret.IsValid() { + return util.NewValidationError(fmt.Sprintf("invalid encrypted smtp %s", name)) + } + if !secret.IsEmpty() && !secret.IsValidInput() { + return util.NewValidationError(fmt.Sprintf("invalid smtp %s", name)) + } + if secret.IsPlain() { + secret.SetAdditionalData("smtp") + if err := secret.Encrypt(); err != nil { + return util.NewValidationError(fmt.Sprintf("could not encrypt smtp %s: %v", name, err)) + } + } + return nil +} + +// SMTPOAuth2 defines the SMTP related OAuth2 configurations +type SMTPOAuth2 struct { + Provider int `json:"provider,omitempty"` + Tenant string `json:"tenant,omitempty"` + ClientID string `json:"client_id,omitempty"` + ClientSecret *kms.Secret `json:"client_secret,omitempty"` + RefreshToken *kms.Secret `json:"refresh_token,omitempty"` +} + +func (c *SMTPOAuth2) validate() error { + if c.Provider < 0 || c.Provider > 1 { + return util.NewValidationError("smtp oauth2: unsupported provider") + } + if c.ClientID == "" { + return util.NewI18nError( + util.NewValidationError("smtp oauth2: client id is required"), + util.I18nErrorClientIDRequired, + ) + } + if c.ClientSecret == nil || c.ClientSecret.IsEmpty() { + return util.NewI18nError( + util.NewValidationError("smtp oauth2: client secret is required"), + util.I18nErrorClientSecretRequired, + ) + } + if c.RefreshToken == nil || c.RefreshToken.IsEmpty() { + return util.NewI18nError( + util.NewValidationError("smtp oauth2: refresh token is required"), + util.I18nErrorRefreshTokenRequired, + ) + } + 
if err := validateSMTPSecret(c.ClientSecret, "oauth2 client secret"); err != nil { + return err + } + return validateSMTPSecret(c.RefreshToken, "oauth2 refresh token") +} + +func (c *SMTPOAuth2) getACopy() SMTPOAuth2 { + var clientSecret, refreshToken *kms.Secret + if c.ClientSecret != nil { + clientSecret = c.ClientSecret.Clone() + } + if c.RefreshToken != nil { + refreshToken = c.RefreshToken.Clone() + } + return SMTPOAuth2{ + Provider: c.Provider, + Tenant: c.Tenant, + ClientID: c.ClientID, + ClientSecret: clientSecret, + RefreshToken: refreshToken, + } +} + +// SMTPConfigs defines configuration for SMTP +type SMTPConfigs struct { + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + From string `json:"from,omitempty"` + User string `json:"user,omitempty"` + Password *kms.Secret `json:"password,omitempty"` + AuthType int `json:"auth_type,omitempty"` + Encryption int `json:"encryption,omitempty"` + Domain string `json:"domain,omitempty"` + Debug int `json:"debug,omitempty"` + OAuth2 SMTPOAuth2 `json:"oauth2"` +} + +// IsEmpty returns true if the configuration is empty +func (c *SMTPConfigs) IsEmpty() bool { + return c.Host == "" +} + +func (c *SMTPConfigs) validate() error { + if c.IsEmpty() { + return nil + } + if c.Port <= 0 || c.Port > 65535 { + return util.NewValidationError(fmt.Sprintf("smtp: invalid port %d", c.Port)) + } + if c.Password != nil && c.AuthType != 3 { + if err := validateSMTPSecret(c.Password, "password"); err != nil { + return err + } + } + if c.User == "" && c.From == "" { + return util.NewI18nError( + util.NewValidationError("smtp: from address and user cannot both be empty"), + util.I18nErrorSMTPRequiredFields, + ) + } + if c.AuthType < 0 || c.AuthType > 3 { + return util.NewValidationError(fmt.Sprintf("smtp: invalid auth type %d", c.AuthType)) + } + if c.Encryption < 0 || c.Encryption > 2 { + return util.NewValidationError(fmt.Sprintf("smtp: invalid encryption %d", c.Encryption)) + } + if c.AuthType == 3 { + 
c.Password = kms.NewEmptySecret() + return c.OAuth2.validate() + } + c.OAuth2 = SMTPOAuth2{} + return nil +} + +// TryDecrypt tries to decrypt the encrypted secrets +func (c *SMTPConfigs) TryDecrypt() error { + if c.Password == nil { + c.Password = kms.NewEmptySecret() + } + if c.OAuth2.ClientSecret == nil { + c.OAuth2.ClientSecret = kms.NewEmptySecret() + } + if c.OAuth2.RefreshToken == nil { + c.OAuth2.RefreshToken = kms.NewEmptySecret() + } + if err := c.Password.TryDecrypt(); err != nil { + return fmt.Errorf("unable to decrypt smtp password: %w", err) + } + if err := c.OAuth2.ClientSecret.TryDecrypt(); err != nil { + return fmt.Errorf("unable to decrypt smtp oauth2 client secret: %w", err) + } + if err := c.OAuth2.RefreshToken.TryDecrypt(); err != nil { + return fmt.Errorf("unable to decrypt smtp oauth2 refresh token: %w", err) + } + return nil +} + +func (c *SMTPConfigs) prepareForRendering() { + if c.Password != nil { + c.Password.Hide() + if c.Password.IsEmpty() { + c.Password = nil + } + } + if c.OAuth2.ClientSecret != nil { + c.OAuth2.ClientSecret.Hide() + if c.OAuth2.ClientSecret.IsEmpty() { + c.OAuth2.ClientSecret = nil + } + } + if c.OAuth2.RefreshToken != nil { + c.OAuth2.RefreshToken.Hide() + if c.OAuth2.RefreshToken.IsEmpty() { + c.OAuth2.RefreshToken = nil + } + } +} + +func (c *SMTPConfigs) getACopy() *SMTPConfigs { + var password *kms.Secret + if c.Password != nil { + password = c.Password.Clone() + } + return &SMTPConfigs{ + Host: c.Host, + Port: c.Port, + From: c.From, + User: c.User, + Password: password, + AuthType: c.AuthType, + Encryption: c.Encryption, + Domain: c.Domain, + Debug: c.Debug, + OAuth2: c.OAuth2.getACopy(), + } +} + +// ACMEHTTP01Challenge defines the configuration for HTTP-01 challenge type +type ACMEHTTP01Challenge struct { + Port int `json:"port"` +} + +// ACMEConfigs defines ACME related configuration +type ACMEConfigs struct { + Domain string `json:"domain"` + Email string `json:"email"` + HTTP01Challenge 
ACMEHTTP01Challenge `json:"http01_challenge"` + // apply the certificate for the specified protocols: + // + // 1 means HTTP + // 2 means FTP + // 4 means WebDAV + // + // Protocols can be combined + Protocols int `json:"protocols"` +} + +func (c *ACMEConfigs) isEmpty() bool { + return c.Domain == "" +} + +func (c *ACMEConfigs) validate() error { + if c.Domain == "" { + return nil + } + if c.Email == "" && !util.IsEmailValid(c.Email) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("acme: invalid email %q", c.Email)), + util.I18nErrorInvalidEmail, + ) + } + if c.HTTP01Challenge.Port <= 0 || c.HTTP01Challenge.Port > 65535 { + return util.NewValidationError(fmt.Sprintf("acme: invalid HTTP-01 challenge port %d", c.HTTP01Challenge.Port)) + } + return nil +} + +// HasProtocol returns true if the ACME certificate must be used for the specified protocol +func (c *ACMEConfigs) HasProtocol(protocol string) bool { + switch protocol { + case protocolHTTP: + return c.Protocols&1 != 0 + case protocolFTP: + return c.Protocols&2 != 0 + case protocolWebDAV: + return c.Protocols&4 != 0 + default: + return false + } +} + +func (c *ACMEConfigs) getACopy() *ACMEConfigs { + return &ACMEConfigs{ + Email: c.Email, + Domain: c.Domain, + HTTP01Challenge: ACMEHTTP01Challenge{Port: c.HTTP01Challenge.Port}, + Protocols: c.Protocols, + } +} + +// BrandingConfig defines the branding configuration +type BrandingConfig struct { + Name string `json:"name"` + ShortName string `json:"short_name"` + Logo []byte `json:"logo"` + Favicon []byte `json:"favicon"` + DisclaimerName string `json:"disclaimer_name"` + DisclaimerURL string `json:"disclaimer_url"` +} + +func (c *BrandingConfig) isEmpty() bool { + if c.Name != "" { + return false + } + if c.ShortName != "" { + return false + } + if len(c.Logo) > 0 { + return false + } + if len(c.Favicon) > 0 { + return false + } + if c.DisclaimerName != "" && c.DisclaimerURL != "" { + return false + } + return true +} + +func (*BrandingConfig) 
validatePNG(b []byte, maxWidth, maxHeight int) error { + if len(b) == 0 { + return nil + } + // DecodeConfig is more efficient, but I'm not sure if this would lead to + // accepting invalid images in some edge cases and performance does not + // matter here. + img, err := png.Decode(bytes.NewBuffer(b)) + if err != nil { + return util.NewI18nError( + util.NewValidationError("invalid PNG image"), + util.I18nErrorInvalidPNG, + ) + } + bounds := img.Bounds() + if bounds.Dx() > maxWidth || bounds.Dy() > maxHeight { + return util.NewI18nError( + util.NewValidationError("invalid PNG image size"), + util.I18nErrorInvalidPNGSize, + ) + } + return nil +} + +func (c *BrandingConfig) validateDisclaimerURL() error { + if c.DisclaimerURL == "" { + return nil + } + u, err := url.Parse(c.DisclaimerURL) + if err != nil { + return util.NewI18nError( + util.NewValidationError("invalid disclaimer URL"), + util.I18nErrorInvalidDisclaimerURL, + ) + } + if u.Scheme != "http" && u.Scheme != "https" { + return util.NewI18nError( + util.NewValidationError("invalid disclaimer URL scheme"), + util.I18nErrorInvalidDisclaimerURL, + ) + } + return nil +} + +func (c *BrandingConfig) validate() error { + if err := c.validateDisclaimerURL(); err != nil { + return err + } + if err := c.validatePNG(c.Logo, 512, 512); err != nil { + return err + } + return c.validatePNG(c.Favicon, 256, 256) +} + +func (c *BrandingConfig) getACopy() BrandingConfig { + logo := make([]byte, len(c.Logo)) + copy(logo, c.Logo) + favicon := make([]byte, len(c.Favicon)) + copy(favicon, c.Favicon) + + return BrandingConfig{ + Name: c.Name, + ShortName: c.ShortName, + Logo: logo, + Favicon: favicon, + DisclaimerName: c.DisclaimerName, + DisclaimerURL: c.DisclaimerURL, + } +} + +// BrandingConfigs defines the branding configuration for WebAdmin and WebClient UI +type BrandingConfigs struct { + WebAdmin BrandingConfig + WebClient BrandingConfig +} + +func (c *BrandingConfigs) isEmpty() bool { + return c.WebAdmin.isEmpty() && 
c.WebClient.isEmpty() +} + +func (c *BrandingConfigs) validate() error { + if err := c.WebAdmin.validate(); err != nil { + return err + } + return c.WebClient.validate() +} + +func (c *BrandingConfigs) getACopy() *BrandingConfigs { + return &BrandingConfigs{ + WebAdmin: c.WebAdmin.getACopy(), + WebClient: c.WebClient.getACopy(), + } +} + +// Configs allows to set configuration keys disabled by default without +// modifying the config file or setting env vars +type Configs struct { + SFTPD *SFTPDConfigs `json:"sftpd,omitempty"` + SMTP *SMTPConfigs `json:"smtp,omitempty"` + ACME *ACMEConfigs `json:"acme,omitempty"` + Branding *BrandingConfigs `json:"branding,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` +} + +func (c *Configs) validate() error { + if c.SFTPD != nil { + if err := c.SFTPD.validate(); err != nil { + return err + } + } + if c.SMTP != nil { + if err := c.SMTP.validate(); err != nil { + return err + } + } + if c.ACME != nil { + if err := c.ACME.validate(); err != nil { + return err + } + } + if c.Branding != nil { + if err := c.Branding.validate(); err != nil { + return err + } + } + return nil +} + +// PrepareForRendering prepares configs for rendering. 
+// It hides confidential data and set to nil the empty structs/secrets +// so they are not serialized +func (c *Configs) PrepareForRendering() { + if c.SFTPD != nil && c.SFTPD.isEmpty() { + c.SFTPD = nil + } + if c.SMTP != nil && c.SMTP.IsEmpty() { + c.SMTP = nil + } + if c.ACME != nil && c.ACME.isEmpty() { + c.ACME = nil + } + if c.Branding != nil && c.Branding.isEmpty() { + c.Branding = nil + } + if c.SMTP != nil { + c.SMTP.prepareForRendering() + } +} + +// SetNilsToEmpty sets nil fields to empty +func (c *Configs) SetNilsToEmpty() { + if c.SFTPD == nil { + c.SFTPD = &SFTPDConfigs{} + } + if c.SMTP == nil { + c.SMTP = &SMTPConfigs{} + } + if c.SMTP.Password == nil { + c.SMTP.Password = kms.NewEmptySecret() + } + if c.SMTP.OAuth2.ClientSecret == nil { + c.SMTP.OAuth2.ClientSecret = kms.NewEmptySecret() + } + if c.SMTP.OAuth2.RefreshToken == nil { + c.SMTP.OAuth2.RefreshToken = kms.NewEmptySecret() + } + if c.ACME == nil { + c.ACME = &ACMEConfigs{} + } + if c.Branding == nil { + c.Branding = &BrandingConfigs{} + } +} + +// RenderAsJSON implements the renderer interface used within plugins +func (c *Configs) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + config, err := provider.getConfigs() + if err != nil { + providerLog(logger.LevelError, "unable to reload config overrides before rendering as json: %v", err) + return nil, err + } + config.PrepareForRendering() + return json.Marshal(config) + } + c.PrepareForRendering() + return json.Marshal(c) +} + +func (c *Configs) getACopy() Configs { + var result Configs + if c.SFTPD != nil { + result.SFTPD = c.SFTPD.getACopy() + } + if c.SMTP != nil { + result.SMTP = c.SMTP.getACopy() + } + if c.ACME != nil { + result.ACME = c.ACME.getACopy() + } + if c.Branding != nil { + result.Branding = c.Branding.getACopy() + } + result.UpdatedAt = c.UpdatedAt + return result +} diff --git a/internal/dataprovider/dataprovider.go b/internal/dataprovider/dataprovider.go new file mode 100644 index 00000000..7ecbbe2d --- 
/dev/null +++ b/internal/dataprovider/dataprovider.go @@ -0,0 +1,4798 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package dataprovider provides data access. +// It abstracts different data providers using a common API. +package dataprovider + +import ( + "bufio" + "bytes" + "context" + "crypto/md5" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "crypto/subtle" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "hash" + "io" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path" + "path/filepath" + "regexp" + "runtime" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/GehirnInc/crypt" + "github.com/GehirnInc/crypt/apr1_crypt" + "github.com/GehirnInc/crypt/md5_crypt" + "github.com/GehirnInc/crypt/sha256_crypt" + "github.com/GehirnInc/crypt/sha512_crypt" + "github.com/alexedwards/argon2id" + "github.com/go-chi/render" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + passwordvalidator "github.com/wagslane/go-password-validator" + "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/pbkdf2" + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/v2/internal/command" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + 
"github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + // SQLiteDataProviderName defines the name for SQLite database provider + SQLiteDataProviderName = "sqlite" + // PGSQLDataProviderName defines the name for PostgreSQL database provider + PGSQLDataProviderName = "postgresql" + // MySQLDataProviderName defines the name for MySQL database provider + MySQLDataProviderName = "mysql" + // BoltDataProviderName defines the name for bbolt key/value store provider + BoltDataProviderName = "bolt" + // MemoryDataProviderName defines the name for memory provider + MemoryDataProviderName = "memory" + // CockroachDataProviderName defines the for CockroachDB provider + CockroachDataProviderName = "cockroachdb" + // DumpVersion defines the version for the dump. + // For restore/load we support the current version and the previous one + DumpVersion = 17 + + argonPwdPrefix = "$argon2id$" + bcryptPwdPrefix = "$2a$" + pbkdf2SHA1Prefix = "$pbkdf2-sha1$" + pbkdf2SHA256Prefix = "$pbkdf2-sha256$" + pbkdf2SHA512Prefix = "$pbkdf2-sha512$" + pbkdf2SHA256B64SaltPrefix = "$pbkdf2-b64salt-sha256$" + md5cryptPwdPrefix = "$1$" + md5cryptApr1PwdPrefix = "$apr1$" + sha256cryptPwdPrefix = "$5$" + sha512cryptPwdPrefix = "$6$" + yescryptPwdPrefix = "$y$" + md5DigestPwdPrefix = "{MD5}" + sha256DigestPwdPrefix = "{SHA256}" + sha512DigestPwdPrefix = "{SHA512}" + trackQuotaDisabledError = "please enable track_quota in your configuration to use this method" + operationAdd = "add" + operationUpdate = "update" + operationDelete = "delete" + sqlPrefixValidChars = "abcdefghijklmnopqrstuvwxyz_0123456789" + maxHookResponseSize = 1048576 // 1MB +) + +// Supported algorithms for hashing passwords. 
+// These algorithms can be used when SFTPGo hashes a plain text password +const ( + HashingAlgoBcrypt = "bcrypt" + HashingAlgoArgon2ID = "argon2id" +) + +// ordering constants +const ( + OrderASC = "ASC" + OrderDESC = "DESC" +) + +const ( + protocolSSH = "SSH" + protocolFTP = "FTP" + protocolWebDAV = "DAV" + protocolHTTP = "HTTP" +) + +// Dump scopes +const ( + DumpScopeUsers = "users" + DumpScopeFolders = "folders" + DumpScopeGroups = "groups" + DumpScopeAdmins = "admins" + DumpScopeAPIKeys = "api_keys" + DumpScopeShares = "shares" + DumpScopeActions = "actions" + DumpScopeRules = "rules" + DumpScopeRoles = "roles" + DumpScopeIPLists = "ip_lists" + DumpScopeConfigs = "configs" +) + +const ( + fieldUsername = 1 + fieldName = 2 + fieldIPNet = 3 +) + +var ( + // SupportedProviders defines the supported data providers + SupportedProviders = []string{SQLiteDataProviderName, PGSQLDataProviderName, MySQLDataProviderName, + BoltDataProviderName, MemoryDataProviderName, CockroachDataProviderName} + // ValidPerms defines all the valid permissions for a user + ValidPerms = []string{PermAny, PermListItems, PermDownload, PermUpload, PermOverwrite, PermCreateDirs, PermRename, + PermRenameFiles, PermRenameDirs, PermDelete, PermDeleteFiles, PermDeleteDirs, PermCopy, PermCreateSymlinks, + PermChmod, PermChown, PermChtimes} + // ValidLoginMethods defines all the valid login methods + ValidLoginMethods = []string{SSHLoginMethodPublicKey, LoginMethodPassword, SSHLoginMethodPassword, + SSHLoginMethodKeyboardInteractive, SSHLoginMethodKeyAndPassword, SSHLoginMethodKeyAndKeyboardInt, + LoginMethodTLSCertificate, LoginMethodTLSCertificateAndPwd} + // SSHMultiStepsLoginMethods defines the supported Multi-Step Authentications + SSHMultiStepsLoginMethods = []string{SSHLoginMethodKeyAndPassword, SSHLoginMethodKeyAndKeyboardInt} + // ErrNoAuthTried defines the error for connection closed before authentication + ErrNoAuthTried = errors.New("no auth tried") + // ErrNotImplemented defines the 
error for features not supported for a particular data provider + ErrNotImplemented = errors.New("feature not supported with the configured data provider") + // ValidProtocols defines all the valid protcols + ValidProtocols = []string{protocolSSH, protocolFTP, protocolWebDAV, protocolHTTP} + // MFAProtocols defines the supported protocols for multi-factor authentication + MFAProtocols = []string{protocolHTTP, protocolSSH, protocolFTP} + // ErrNoInitRequired defines the error returned by InitProvider if no inizialization/update is required + ErrNoInitRequired = errors.New("the data provider is up to date") + // ErrInvalidCredentials defines the error to return if the supplied credentials are invalid + ErrInvalidCredentials = errors.New("invalid credentials") + // ErrLoginNotAllowedFromIP defines the error to return if login is denied from the current IP + ErrLoginNotAllowedFromIP = errors.New("login is not allowed from this IP") + // ErrDuplicatedKey occurs when there is a unique key constraint violation + ErrDuplicatedKey = errors.New("duplicated key not allowed") + // ErrForeignKeyViolated occurs when there is a foreign key constraint violation + ErrForeignKeyViolated = errors.New("violates foreign key constraint") + errInvalidInput = util.NewValidationError("Invalid input. 
Slashes (/ ), colons (:), control characters, and reserved system names are not allowed") + tz = "" + isAdminCreated atomic.Bool + validTLSUsernames = []string{string(sdk.TLSUsernameNone), string(sdk.TLSUsernameCN)} + config Config + provider Provider + sqlPlaceholders []string + internalHashPwdPrefixes = []string{argonPwdPrefix, bcryptPwdPrefix} + hashPwdPrefixes = []string{argonPwdPrefix, bcryptPwdPrefix, pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, + pbkdf2SHA512Prefix, pbkdf2SHA256B64SaltPrefix, md5cryptPwdPrefix, md5cryptApr1PwdPrefix, md5DigestPwdPrefix, + sha256DigestPwdPrefix, sha512DigestPwdPrefix, sha256cryptPwdPrefix, sha512cryptPwdPrefix, yescryptPwdPrefix} + pbkdfPwdPrefixes = []string{pbkdf2SHA1Prefix, pbkdf2SHA256Prefix, pbkdf2SHA512Prefix, pbkdf2SHA256B64SaltPrefix} + pbkdfPwdB64SaltPrefixes = []string{pbkdf2SHA256B64SaltPrefix} + unixPwdPrefixes = []string{md5cryptPwdPrefix, md5cryptApr1PwdPrefix, sha256cryptPwdPrefix, sha512cryptPwdPrefix, + yescryptPwdPrefix} + digestPwdPrefixes = []string{md5DigestPwdPrefix, sha256DigestPwdPrefix, sha512DigestPwdPrefix} + sharedProviders = []string{PGSQLDataProviderName, MySQLDataProviderName, CockroachDataProviderName} + logSender = "dataprovider" + sqlTableUsers string + sqlTableFolders string + sqlTableUsersFoldersMapping string + sqlTableAdmins string + sqlTableAPIKeys string + sqlTableShares string + sqlTableSharesGroupsMapping string + sqlTableDefenderHosts string + sqlTableDefenderEvents string + sqlTableActiveTransfers string + sqlTableGroups string + sqlTableUsersGroupsMapping string + sqlTableAdminsGroupsMapping string + sqlTableGroupsFoldersMapping string + sqlTableSharedSessions string + sqlTableEventsActions string + sqlTableEventsRules string + sqlTableRulesActionsMapping string + sqlTableTasks string + sqlTableNodes string + sqlTableRoles string + sqlTableIPLists string + sqlTableConfigs string + sqlTableSchemaVersion string + argon2Params *argon2id.Params + lastLoginMinDelay = 10 * time.Minute + 
usernameRegex = regexp.MustCompile("^[a-zA-Z0-9-_.~]+$") + tempPath string + allowSelfConnections int + fnReloadRules FnReloadRules + fnRemoveRule FnRemoveRule + fnHandleRuleForProviderEvent FnHandleRuleForProviderEvent +) + +func initSQLTables() { + sqlTableUsers = "users" + sqlTableFolders = "folders" + sqlTableUsersFoldersMapping = "users_folders_mapping" + sqlTableAdmins = "admins" + sqlTableAPIKeys = "api_keys" + sqlTableShares = "shares" + sqlTableSharesGroupsMapping = "shares_groups_mapping" + sqlTableDefenderHosts = "defender_hosts" + sqlTableDefenderEvents = "defender_events" + sqlTableActiveTransfers = "active_transfers" + sqlTableGroups = "groups" + sqlTableUsersGroupsMapping = "users_groups_mapping" + sqlTableGroupsFoldersMapping = "groups_folders_mapping" + sqlTableAdminsGroupsMapping = "admins_groups_mapping" + sqlTableSharedSessions = "shared_sessions" + sqlTableEventsActions = "events_actions" + sqlTableEventsRules = "events_rules" + sqlTableRulesActionsMapping = "rules_actions_mapping" + sqlTableTasks = "tasks" + sqlTableNodes = "nodes" + sqlTableRoles = "roles" + sqlTableIPLists = "ip_lists" + sqlTableConfigs = "configurations" + sqlTableSchemaVersion = "schema_version" +} + +// FnReloadRules defined the callback to reload event rules +type FnReloadRules func() + +// FnRemoveRule defines the callback to remove an event rule +type FnRemoveRule func(name string) + +// FnHandleRuleForProviderEvent define the callback to handle event rules for provider events +type FnHandleRuleForProviderEvent func(operation, executor, ip, objectType, objectName, role string, object plugin.Renderer) + +// SetEventRulesCallbacks sets the event rules callbacks +func SetEventRulesCallbacks(reload FnReloadRules, remove FnRemoveRule, handle FnHandleRuleForProviderEvent) { + fnReloadRules = reload + fnRemoveRule = remove + fnHandleRuleForProviderEvent = handle +} + +type schemaVersion struct { + Version int +} + +// BcryptOptions defines the options for bcrypt password 
hashing +type BcryptOptions struct { + Cost int `json:"cost" mapstructure:"cost"` +} + +// Argon2Options defines the options for argon2 password hashing +type Argon2Options struct { + Memory uint32 `json:"memory" mapstructure:"memory"` + Iterations uint32 `json:"iterations" mapstructure:"iterations"` + Parallelism uint8 `json:"parallelism" mapstructure:"parallelism"` +} + +// PasswordHashing defines the configuration for password hashing +type PasswordHashing struct { + BcryptOptions BcryptOptions `json:"bcrypt_options" mapstructure:"bcrypt_options"` + Argon2Options Argon2Options `json:"argon2_options" mapstructure:"argon2_options"` + // Algorithm to use for hashing passwords. Available algorithms: argon2id, bcrypt. Default: bcrypt + Algo string `json:"algo" mapstructure:"algo"` +} + +// PasswordValidationRules defines the password validation rules +type PasswordValidationRules struct { + // MinEntropy defines the minimum password entropy. + // 0 means disabled, any password will be accepted. 
+ // Take a look at the following link for more details + // https://github.com/wagslane/go-password-validator#what-entropy-value-should-i-use + MinEntropy float64 `json:"min_entropy" mapstructure:"min_entropy"` +} + +// PasswordValidation defines the password validation rules for admins and protocol users +type PasswordValidation struct { + // Password validation rules for SFTPGo admin users + Admins PasswordValidationRules `json:"admins" mapstructure:"admins"` + // Password validation rules for SFTPGo protocol users + Users PasswordValidationRules `json:"users" mapstructure:"users"` +} + +type wrappedFolder struct { + Folder vfs.BaseVirtualFolder +} + +func (w *wrappedFolder) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + folder, err := provider.getFolderByName(w.Folder.Name) + if err != nil { + providerLog(logger.LevelError, "unable to reload folder before rendering as json: %v", err) + return nil, err + } + folder.PrepareForRendering() + return json.Marshal(folder) + } + w.Folder.PrepareForRendering() + return json.Marshal(w.Folder) +} + +// ObjectsActions defines the action to execute on user create, update, delete for the specified objects +type ObjectsActions struct { + // Valid values are add, update, delete. Empty slice to disable + ExecuteOn []string `json:"execute_on" mapstructure:"execute_on"` + // Valid values are user, admin, api_key + ExecuteFor []string `json:"execute_for" mapstructure:"execute_for"` + // Absolute path to an external program or an HTTP URL + Hook string `json:"hook" mapstructure:"hook"` +} + +// ProviderStatus defines the provider status +type ProviderStatus struct { + Driver string `json:"driver"` + IsActive bool `json:"is_active"` + Error string `json:"error"` +} + +// Config defines the provider configuration +type Config struct { + // Driver name, must be one of the SupportedProviders + Driver string `json:"driver" mapstructure:"driver"` + // Database name. 
For driver sqlite this can be the database name relative to the config dir + // or the absolute path to the SQLite database. + Name string `json:"name" mapstructure:"name"` + // Database host. For postgresql and cockroachdb driver you can specify multiple hosts separated by commas + Host string `json:"host" mapstructure:"host"` + // Database port + Port int `json:"port" mapstructure:"port"` + // Database username + Username string `json:"username" mapstructure:"username"` + // Database password + Password string `json:"password" mapstructure:"password"` + // Used for drivers mysql and postgresql. + // 0 disable SSL/TLS connections. + // 1 require ssl. + // 2 set ssl mode to verify-ca for driver postgresql and skip-verify for driver mysql. + // 3 set ssl mode to verify-full for driver postgresql and preferred for driver mysql. + SSLMode int `json:"sslmode" mapstructure:"sslmode"` + // Used for drivers mysql, postgresql and cockroachdb. Set to true to disable SNI + DisableSNI bool `json:"disable_sni" mapstructure:"disable_sni"` + // TargetSessionAttrs is a postgresql and cockroachdb specific option. + // It determines whether the session must have certain properties to be acceptable. + // It's typically used in combination with multiple host names to select the first + // acceptable alternative among several hosts + TargetSessionAttrs string `json:"target_session_attrs" mapstructure:"target_session_attrs"` + // Path to the root certificate authority used to verify that the server certificate was signed by a trusted CA + RootCert string `json:"root_cert" mapstructure:"root_cert"` + // Path to the client certificate for two-way TLS authentication + ClientCert string `json:"client_cert" mapstructure:"client_cert"` + // Path to the client key for two-way TLS authentication + ClientKey string `json:"client_key" mapstructure:"client_key"` + // Custom database connection string. 
+ // If not empty this connection string will be used instead of build one using the previous parameters + ConnectionString string `json:"connection_string" mapstructure:"connection_string"` + // prefix for SQL tables + SQLTablesPrefix string `json:"sql_tables_prefix" mapstructure:"sql_tables_prefix"` + // Set the preferred way to track users quota between the following choices: + // 0, disable quota tracking. REST API to scan user dir and update quota will do nothing + // 1, quota is updated each time a user upload or delete a file even if the user has no quota restrictions + // 2, quota is updated each time a user upload or delete a file but only for users with quota restrictions + // and for virtual folders. + // With this configuration the "quota scan" REST API can still be used to periodically update space usage + // for users without quota restrictions + TrackQuota int `json:"track_quota" mapstructure:"track_quota"` + // Sets the maximum number of open connections for mysql and postgresql driver. + // Default 0 (unlimited) + PoolSize int `json:"pool_size" mapstructure:"pool_size"` + // Users default base directory. + // If no home dir is defined while adding a new user, and this value is + // a valid absolute path, then the user home dir will be automatically + // defined as the path obtained joining the base dir and the username + UsersBaseDir string `json:"users_base_dir" mapstructure:"users_base_dir"` + // Actions to execute on objects add, update, delete. + // The supported objects are user, admin, api_key. + // Update action will not be fired for internal updates such as the last login or the user quota fields. + Actions ObjectsActions `json:"actions" mapstructure:"actions"` + // Absolute path to an external program or an HTTP URL to invoke for users authentication. + // Leave empty to use builtin authentication. + // If the authentication succeed the user will be automatically added/updated inside the defined data provider. 
+ // Actions defined for user added/updated will not be executed in this case. + // This method is slower than built-in authentication methods, but it's very flexible as anyone can + // easily write his own authentication hooks. + ExternalAuthHook string `json:"external_auth_hook" mapstructure:"external_auth_hook"` + // ExternalAuthScope defines the scope for the external authentication hook. + // - 0 means all supported authentication scopes, the external hook will be executed for password, + // public key, keyboard interactive authentication and TLS certificates + // - 1 means passwords only + // - 2 means public keys only + // - 4 means keyboard interactive only + // - 8 means TLS certificates only + // you can combine the scopes, for example 3 means password and public key, 5 password and keyboard + // interactive and so on + ExternalAuthScope int `json:"external_auth_scope" mapstructure:"external_auth_scope"` + // Absolute path to an external program or an HTTP URL to invoke just before the user login. + // This program/URL allows to modify or create the user trying to login. + // It is useful if you have users with dynamic fields to update just before the login. + // Please note that if you want to create a new user, the pre-login hook response must + // include all the mandatory user fields. + // + // The pre-login hook must finish within 30 seconds. + // + // If an error happens while executing the "PreLoginHook" then login will be denied. + // PreLoginHook and ExternalAuthHook are mutally exclusive. + // Leave empty to disable. + PreLoginHook string `json:"pre_login_hook" mapstructure:"pre_login_hook"` + // Absolute path to an external program or an HTTP URL to invoke after the user login. + // Based on the configured scope you can choose if notify failed or successful logins + // or both + PostLoginHook string `json:"post_login_hook" mapstructure:"post_login_hook"` + // PostLoginScope defines the scope for the post-login hook. 
+ // - 0 means notify both failed and successful logins + // - 1 means notify failed logins + // - 2 means notify successful logins + PostLoginScope int `json:"post_login_scope" mapstructure:"post_login_scope"` + // Absolute path to an external program or an HTTP URL to invoke just before password + // authentication. This hook allows you to externally check the provided password, + // its main use case is to allow to easily support things like password+OTP for protocols + // without keyboard interactive support such as FTP and WebDAV. You can ask your users + // to login using a string consisting of a fixed password and a One Time Token, you + // can verify the token inside the hook and ask to SFTPGo to verify the fixed part. + CheckPasswordHook string `json:"check_password_hook" mapstructure:"check_password_hook"` + // CheckPasswordScope defines the scope for the check password hook. + // - 0 means all protocols + // - 1 means SSH + // - 2 means FTP + // - 4 means WebDAV + // you can combine the scopes, for example 6 means FTP and WebDAV + CheckPasswordScope int `json:"check_password_scope" mapstructure:"check_password_scope"` + // Defines how the database will be initialized/updated: + // - 0 means automatically + // - 1 means manually using the initprovider sub-command + UpdateMode int `json:"update_mode" mapstructure:"update_mode"` + // PasswordHashing defines the configuration for password hashing + PasswordHashing PasswordHashing `json:"password_hashing" mapstructure:"password_hashing"` + // PasswordValidation defines the password validation rules + PasswordValidation PasswordValidation `json:"password_validation" mapstructure:"password_validation"` + // Verifying argon2 passwords has a high memory and computational cost, + // by enabling, in memory, password caching you reduce this cost. + PasswordCaching bool `json:"password_caching" mapstructure:"password_caching"` + // DelayedQuotaUpdate defines the number of seconds to accumulate quota updates. 
+ // If there are a lot of close uploads, accumulating quota updates can save you many + // queries to the data provider. + // If you want to track quotas, a scheduled quota update is recommended in any case, the stored + // quota size may be incorrect for several reasons, such as an unexpected shutdown, temporary provider + // failures, file copied outside of SFTPGo, and so on. + // 0 means immediate quota update. + DelayedQuotaUpdate int `json:"delayed_quota_update" mapstructure:"delayed_quota_update"` + // If enabled, a default admin user with username "admin" and password "password" will be created + // on first start. + // You can also create the first admin user by using the web interface or by loading initial data. + CreateDefaultAdmin bool `json:"create_default_admin" mapstructure:"create_default_admin"` + // Rules for usernames and folder names: + // - 0 means no rules + // - 1 means you can use any UTF-8 character. The names are used in URIs for REST API and Web admin. + // By default only unreserved URI characters are allowed: ALPHA / DIGIT / "-" / "." / "_" / "~". + // - 2 means names are converted to lowercase before saving/matching and so case + // insensitive matching is possible + // - 4 means trimming trailing and leading white spaces before saving/matching + // Rules can be combined, for example 3 means both converting to lowercase and allowing any UTF-8 character. + // Enabling these options for existing installations could be backward incompatible, some users + // could be unable to login, for example existing users with mixed cases in their usernames. + // You have to ensure that all existing users respect the defined rules. + NamingRules int `json:"naming_rules" mapstructure:"naming_rules"` + // If the data provider is shared across multiple SFTPGo instances, set this parameter to 1. + // MySQL, PostgreSQL and CockroachDB can be shared, this setting is ignored for other data + // providers. 
For shared data providers, SFTPGo periodically reloads the latest updated users, + // based on the "updated_at" field, and updates its internal caches if users are updated from + // a different instance. This check, if enabled, is executed every 10 minutes. + // For shared data providers, active transfers are persisted in the database and thus + // quota checks between ongoing transfers will work cross multiple instances + IsShared int `json:"is_shared" mapstructure:"is_shared"` + // Node defines the configuration for this cluster node. + // Ignored if the provider is not shared/shareable + Node NodeConfig `json:"node" mapstructure:"node"` + // Path to the backup directory. This can be an absolute path or a path relative to the config dir + BackupsPath string `json:"backups_path" mapstructure:"backups_path"` +} + +// GetShared returns the provider share mode. +// This method is called before the provider is initialized +func (c *Config) GetShared() int { + if !slices.Contains(sharedProviders, c.Driver) { + return 0 + } + return c.IsShared +} + +func (c *Config) convertName(name string) string { + if c.NamingRules <= 1 { + return name + } + if c.NamingRules&2 != 0 { + name = strings.ToLower(name) + } + if c.NamingRules&4 != 0 { + name = strings.TrimSpace(name) + } + + return name +} + +// IsDefenderSupported returns true if the configured provider supports the defender +func (c *Config) IsDefenderSupported() bool { + switch c.Driver { + case MySQLDataProviderName, PGSQLDataProviderName, CockroachDataProviderName: + return true + default: + return false + } +} + +func (c *Config) requireCustomTLSForMySQL() bool { + if config.DisableSNI { + return config.SSLMode != 0 + } + if config.RootCert != "" && util.IsFileInputValid(config.RootCert) { + return config.SSLMode != 0 + } + if config.ClientCert != "" && config.ClientKey != "" && util.IsFileInputValid(config.ClientCert) && + util.IsFileInputValid(config.ClientKey) { + return config.SSLMode != 0 + } + return false +} + 
+func (c *Config) doBackup() (string, error) { + now := time.Now().UTC() + outputFile := filepath.Join(c.BackupsPath, fmt.Sprintf("backup_%s_%d.json", now.Weekday(), now.Hour())) + providerLog(logger.LevelDebug, "starting backup to file %q", outputFile) + err := os.MkdirAll(filepath.Dir(outputFile), 0700) + if err != nil { + providerLog(logger.LevelError, "unable to create backup dir %q: %v", outputFile, err) + return outputFile, fmt.Errorf("unable to create backup dir: %w", err) + } + backup, err := DumpData(nil) + if err != nil { + providerLog(logger.LevelError, "unable to execute backup: %v", err) + return outputFile, fmt.Errorf("unable to dump backup data: %w", err) + } + dump, err := json.Marshal(backup) + if err != nil { + providerLog(logger.LevelError, "unable to marshal backup as JSON: %v", err) + return outputFile, fmt.Errorf("unable to marshal backup data as JSON: %w", err) + } + err = os.WriteFile(outputFile, dump, 0600) + if err != nil { + providerLog(logger.LevelError, "unable to save backup: %v", err) + return outputFile, fmt.Errorf("unable to save backup: %w", err) + } + providerLog(logger.LevelDebug, "backup saved to %q", outputFile) + return outputFile, nil +} + +// SetTZ sets the configured timezone. +func SetTZ(val string) { + tz = val +} + +// UseLocalTime returns true if local time should be used instead of UTC. 
+func UseLocalTime() bool { + return tz == "local" +} + +// ExecuteBackup executes a backup +func ExecuteBackup() (string, error) { + return config.doBackup() +} + +// ConvertName converts the given name based on the configured rules +func ConvertName(name string) string { + return config.convertName(name) +} + +// ActiveTransfer defines an active protocol transfer +type ActiveTransfer struct { + ID int64 + Type int + ConnID string + Username string + FolderName string + IP string + TruncatedSize int64 + CurrentULSize int64 + CurrentDLSize int64 + CreatedAt int64 + UpdatedAt int64 +} + +// TransferQuota stores the allowed transfer quota fields +type TransferQuota struct { + ULSize int64 + DLSize int64 + TotalSize int64 + AllowedULSize int64 + AllowedDLSize int64 + AllowedTotalSize int64 +} + +// HasSizeLimits returns true if any size limit is set +func (q *TransferQuota) HasSizeLimits() bool { + return q.AllowedDLSize > 0 || q.AllowedULSize > 0 || q.AllowedTotalSize > 0 +} + +// HasUploadSpace returns true if there is transfer upload space available +func (q *TransferQuota) HasUploadSpace() bool { + if q.TotalSize <= 0 && q.ULSize <= 0 { + return true + } + if q.TotalSize > 0 { + return q.AllowedTotalSize > 0 + } + return q.AllowedULSize > 0 +} + +// HasDownloadSpace returns true if there is transfer download space available +func (q *TransferQuota) HasDownloadSpace() bool { + if q.TotalSize <= 0 && q.DLSize <= 0 { + return true + } + if q.TotalSize > 0 { + return q.AllowedTotalSize > 0 + } + return q.AllowedDLSize > 0 +} + +// DefenderEntry defines a defender entry +type DefenderEntry struct { + ID int64 `json:"-"` + IP string `json:"ip"` + Score int `json:"score,omitempty"` + BanTime time.Time `json:"ban_time,omitempty"` +} + +// GetID returns an unique ID for a defender entry +func (d *DefenderEntry) GetID() string { + return hex.EncodeToString([]byte(d.IP)) +} + +// GetBanTime returns the ban time for a defender entry as string +func (d *DefenderEntry) 
GetBanTime() string { + if d.BanTime.IsZero() { + return "" + } + return d.BanTime.UTC().Format(time.RFC3339) +} + +// MarshalJSON returns the JSON encoding of a DefenderEntry. +func (d *DefenderEntry) MarshalJSON() ([]byte, error) { + return json.Marshal(&struct { + ID string `json:"id"` + IP string `json:"ip"` + Score int `json:"score,omitempty"` + BanTime string `json:"ban_time,omitempty"` + }{ + ID: d.GetID(), + IP: d.IP, + Score: d.Score, + BanTime: d.GetBanTime(), + }) +} + +// BackupData defines the structure for the backup/restore files +type BackupData struct { + Users []User `json:"users"` + Groups []Group `json:"groups"` + Folders []vfs.BaseVirtualFolder `json:"folders"` + Admins []Admin `json:"admins"` + APIKeys []APIKey `json:"api_keys"` + Shares []Share `json:"shares"` + EventActions []BaseEventAction `json:"event_actions"` + EventRules []EventRule `json:"event_rules"` + Roles []Role `json:"roles"` + IPLists []IPListEntry `json:"ip_lists"` + Configs *Configs `json:"configs"` + Version int `json:"version"` +} + +// HasFolder returns true if the folder with the given name is included +func (d *BackupData) HasFolder(name string) bool { + for _, folder := range d.Folders { + if folder.Name == name { + return true + } + } + return false +} + +type checkPasswordRequest struct { + Username string `json:"username"` + IP string `json:"ip"` + Password string `json:"password"` + Protocol string `json:"protocol"` +} + +type checkPasswordResponse struct { + // 0 KO, 1 OK, 2 partial success, -1 not executed + Status int `json:"status"` + // for status = 2 this is the password to check against the one stored + // inside the SFTPGo data provider + ToVerify string `json:"to_verify"` +} + +// GetQuotaTracking returns the configured mode for user's quota tracking +func GetQuotaTracking() int { + return config.TrackQuota +} + +// HasUsersBaseDir returns true if users base dir is set +func HasUsersBaseDir() bool { + return config.UsersBaseDir != "" +} + +// Provider 
defines the interface that data providers must implement. +type Provider interface { + validateUserAndPass(username, password, ip, protocol string) (User, error) + validateUserAndPubKey(username string, pubKey []byte, isSSHCert bool) (User, string, error) + validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) + updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error + updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error + getUsedQuota(username string) (int, int64, int64, int64, error) + userExists(username, role string) (User, error) + addUser(user *User) error + updateUser(user *User) error + deleteUser(user User, softDelete bool) error + updateUserPassword(username, password string) error // used internally when converting passwords from other hash + getUsers(limit int, offset int, order, role string) ([]User, error) + dumpUsers() ([]User, error) + getRecentlyUpdatedUsers(after int64) ([]User, error) + getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) + updateLastLogin(username string) error + updateAdminLastLogin(username string) error + setUpdatedAt(username string) + getAdminSignature(username string) (string, error) + getUserSignature(username string) (string, error) + getFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) + getFolderByName(name string) (vfs.BaseVirtualFolder, error) + addFolder(folder *vfs.BaseVirtualFolder) error + updateFolder(folder *vfs.BaseVirtualFolder) error + deleteFolder(folder vfs.BaseVirtualFolder) error + updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error + getUsedFolderQuota(name string) (int, int64, error) + dumpFolders() ([]vfs.BaseVirtualFolder, error) + getGroups(limit, offset int, order string, minimal bool) ([]Group, error) + getGroupsWithNames(names []string) ([]Group, error) + getUsersInGroups(names []string) ([]string, error) + groupExists(name string) 
(Group, error) + addGroup(group *Group) error + updateGroup(group *Group) error + deleteGroup(group Group) error + dumpGroups() ([]Group, error) + adminExists(username string) (Admin, error) + addAdmin(admin *Admin) error + updateAdmin(admin *Admin) error + deleteAdmin(admin Admin) error + getAdmins(limit int, offset int, order string) ([]Admin, error) + dumpAdmins() ([]Admin, error) + validateAdminAndPass(username, password, ip string) (Admin, error) + apiKeyExists(keyID string) (APIKey, error) + addAPIKey(apiKey *APIKey) error + updateAPIKey(apiKey *APIKey) error + deleteAPIKey(apiKey APIKey) error + getAPIKeys(limit int, offset int, order string) ([]APIKey, error) + dumpAPIKeys() ([]APIKey, error) + updateAPIKeyLastUse(keyID string) error + shareExists(shareID, username string) (Share, error) + addShare(share *Share) error + updateShare(share *Share) error + deleteShare(share Share) error + getShares(limit int, offset int, order, username string) ([]Share, error) + dumpShares() ([]Share, error) + updateShareLastUse(shareID string, numTokens int) error + getDefenderHosts(from int64, limit int) ([]DefenderEntry, error) + getDefenderHostByIP(ip string, from int64) (DefenderEntry, error) + isDefenderHostBanned(ip string) (DefenderEntry, error) + updateDefenderBanTime(ip string, minutes int) error + deleteDefenderHost(ip string) error + addDefenderEvent(ip string, score int) error + setDefenderBanTime(ip string, banTime int64) error + cleanupDefender(from int64) error + addActiveTransfer(transfer ActiveTransfer) error + updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error + removeActiveTransfer(transferID int64, connectionID string) error + cleanupActiveTransfers(before time.Time) error + getActiveTransfers(from time.Time) ([]ActiveTransfer, error) + addSharedSession(session Session) error + deleteSharedSession(key string, sessionType SessionType) error + getSharedSession(key string, sessionType SessionType) (Session, error) + 
cleanupSharedSessions(sessionType SessionType, before int64) error + getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) + dumpEventActions() ([]BaseEventAction, error) + eventActionExists(name string) (BaseEventAction, error) + addEventAction(action *BaseEventAction) error + updateEventAction(action *BaseEventAction) error + deleteEventAction(action BaseEventAction) error + getEventRules(limit, offset int, order string) ([]EventRule, error) + dumpEventRules() ([]EventRule, error) + getRecentlyUpdatedRules(after int64) ([]EventRule, error) + eventRuleExists(name string) (EventRule, error) + addEventRule(rule *EventRule) error + updateEventRule(rule *EventRule) error + deleteEventRule(rule EventRule, softDelete bool) error + getTaskByName(name string) (Task, error) + addTask(name string) error + updateTask(name string, version int64) error + updateTaskTimestamp(name string) error + setFirstDownloadTimestamp(username string) error + setFirstUploadTimestamp(username string) error + addNode() error + getNodeByName(name string) (Node, error) + getNodes() ([]Node, error) + updateNodeTimestamp() error + cleanupNodes() error + roleExists(name string) (Role, error) + addRole(role *Role) error + updateRole(role *Role) error + deleteRole(role Role) error + getRoles(limit int, offset int, order string, minimal bool) ([]Role, error) + dumpRoles() ([]Role, error) + ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) + addIPListEntry(entry *IPListEntry) error + updateIPListEntry(entry *IPListEntry) error + deleteIPListEntry(entry IPListEntry, softDelete bool) error + getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) + getRecentlyUpdatedIPListEntries(after int64) ([]IPListEntry, error) + dumpIPListEntries() ([]IPListEntry, error) + countIPListEntries(listType IPListType) (int64, error) + getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) + 
getConfigs() (Configs, error) + setConfigs(configs *Configs) error + checkAvailability() error + close() error + reloadConfig() error + initializeDatabase() error + migrateDatabase() error + revertDatabase(targetVersion int) error + resetDatabase() error +} + +// SetAllowSelfConnections sets the desired behaviour for self connections +func SetAllowSelfConnections(value int) { + allowSelfConnections = value +} + +// SetTempPath sets the path for temporary files +func SetTempPath(fsPath string) { + tempPath = fsPath +} + +func checkSharedMode() { + if !slices.Contains(sharedProviders, config.Driver) { + config.IsShared = 0 + } +} + +// Initialize the data provider. +// An error is returned if the configured driver is invalid or if the data provider cannot be initialized +func Initialize(cnf Config, basePath string, checkAdmins bool) error { + config = cnf + checkSharedMode() + config.Actions.ExecuteOn = util.RemoveDuplicates(config.Actions.ExecuteOn, true) + config.Actions.ExecuteFor = util.RemoveDuplicates(config.Actions.ExecuteFor, true) + + cnf.BackupsPath = getConfigPath(cnf.BackupsPath, basePath) + if cnf.BackupsPath == "" { + return fmt.Errorf("required directory is invalid, backup path %q", cnf.BackupsPath) + } + absoluteBackupPath, err := util.GetAbsolutePath(cnf.BackupsPath) + if err != nil { + return fmt.Errorf("unable to get absolute backup path: %w", err) + } + config.BackupsPath = absoluteBackupPath + + if err := initializeHashingAlgo(&cnf); err != nil { + return err + } + if err := validateHooks(); err != nil { + return err + } + if err := createProvider(basePath); err != nil { + return err + } + if err := checkDatabase(checkAdmins); err != nil { + return err + } + admins, err := provider.getAdmins(1, 0, OrderASC) + if err != nil { + return err + } + isAdminCreated.Store(len(admins) > 0) + if err := config.Node.validate(); err != nil { + return err + } + delayedQuotaUpdater.start() + if currentNode != nil { + config.BackupsPath = 
filepath.Join(config.BackupsPath, currentNode.Name) + } + providerLog(logger.LevelDebug, "absolute backup path %q", config.BackupsPath) + return startScheduler() +} + +func checkDatabase(checkAdmins bool) error { + if config.UpdateMode == 0 { + err := provider.initializeDatabase() + if err != nil && err != ErrNoInitRequired { + logger.WarnToConsole("unable to initialize data provider: %v", err) + providerLog(logger.LevelError, "unable to initialize data provider: %v", err) + return err + } + if err == nil { + logger.DebugToConsole("data provider successfully initialized") + providerLog(logger.LevelInfo, "data provider successfully initialized") + } + err = provider.migrateDatabase() + if err != nil && err != ErrNoInitRequired { + providerLog(logger.LevelError, "database migration error: %v", err) + return err + } + if checkAdmins && config.CreateDefaultAdmin { + err = checkDefaultAdmin() + if err != nil { + providerLog(logger.LevelError, "erro checking the default admin: %v", err) + return err + } + } + } else { + providerLog(logger.LevelInfo, "database initialization/migration skipped, manual mode is configured") + } + return nil +} + +func validateHooks() error { + var hooks []string + if config.PreLoginHook != "" && !strings.HasPrefix(config.PreLoginHook, "http") { + hooks = append(hooks, config.PreLoginHook) + } + if config.ExternalAuthHook != "" && !strings.HasPrefix(config.ExternalAuthHook, "http") { + hooks = append(hooks, config.ExternalAuthHook) + } + if config.PostLoginHook != "" && !strings.HasPrefix(config.PostLoginHook, "http") { + hooks = append(hooks, config.PostLoginHook) + } + if config.CheckPasswordHook != "" && !strings.HasPrefix(config.CheckPasswordHook, "http") { + hooks = append(hooks, config.CheckPasswordHook) + } + + for _, hook := range hooks { + if !filepath.IsAbs(hook) { + return fmt.Errorf("invalid hook: %q must be an absolute path", hook) + } + _, err := os.Stat(hook) + if err != nil { + providerLog(logger.LevelError, "invalid hook: 
%v", err) + return err + } + } + + return nil +} + +// GetBackupsPath returns the normalized backups path +func GetBackupsPath() string { + return config.BackupsPath +} + +// GetProviderFromValue returns the FilesystemProvider matching the specified value. +// If no match is found LocalFilesystemProvider is returned. +func GetProviderFromValue(value string) sdk.FilesystemProvider { + val, err := strconv.Atoi(value) + if err != nil { + return sdk.LocalFilesystemProvider + } + result := sdk.FilesystemProvider(val) + if sdk.IsProviderSupported(result) { + return result + } + return sdk.LocalFilesystemProvider +} + +func initializeHashingAlgo(cnf *Config) error { + parallelism := cnf.PasswordHashing.Argon2Options.Parallelism + if parallelism == 0 { + parallelism = uint8(runtime.NumCPU()) + } + argon2Params = &argon2id.Params{ + Memory: cnf.PasswordHashing.Argon2Options.Memory, + Iterations: cnf.PasswordHashing.Argon2Options.Iterations, + Parallelism: parallelism, + SaltLength: 16, + KeyLength: 32, + } + + if config.PasswordHashing.Algo == HashingAlgoBcrypt { + if config.PasswordHashing.BcryptOptions.Cost > bcrypt.MaxCost { + err := fmt.Errorf("invalid bcrypt cost %v, max allowed %v", config.PasswordHashing.BcryptOptions.Cost, bcrypt.MaxCost) + logger.WarnToConsole("Unable to initialize data provider: %v", err) + providerLog(logger.LevelError, "Unable to initialize data provider: %v", err) + return err + } + } + return nil +} + +func validateSQLTablesPrefix() error { + initSQLTables() + if config.SQLTablesPrefix != "" { + for _, char := range config.SQLTablesPrefix { + if !strings.Contains(sqlPrefixValidChars, strings.ToLower(string(char))) { + return errors.New("invalid sql_tables_prefix only chars in range 'a..z', 'A..Z', '0-9' and '_' are allowed") + } + } + sqlTableUsers = config.SQLTablesPrefix + sqlTableUsers + sqlTableFolders = config.SQLTablesPrefix + sqlTableFolders + sqlTableUsersFoldersMapping = config.SQLTablesPrefix + sqlTableUsersFoldersMapping + 
sqlTableAdmins = config.SQLTablesPrefix + sqlTableAdmins + sqlTableAPIKeys = config.SQLTablesPrefix + sqlTableAPIKeys + sqlTableShares = config.SQLTablesPrefix + sqlTableShares + sqlTableSharesGroupsMapping = config.SQLTablesPrefix + sqlTableSharesGroupsMapping + sqlTableDefenderEvents = config.SQLTablesPrefix + sqlTableDefenderEvents + sqlTableDefenderHosts = config.SQLTablesPrefix + sqlTableDefenderHosts + sqlTableActiveTransfers = config.SQLTablesPrefix + sqlTableActiveTransfers + sqlTableGroups = config.SQLTablesPrefix + sqlTableGroups + sqlTableUsersGroupsMapping = config.SQLTablesPrefix + sqlTableUsersGroupsMapping + sqlTableAdminsGroupsMapping = config.SQLTablesPrefix + sqlTableAdminsGroupsMapping + sqlTableGroupsFoldersMapping = config.SQLTablesPrefix + sqlTableGroupsFoldersMapping + sqlTableSharedSessions = config.SQLTablesPrefix + sqlTableSharedSessions + sqlTableEventsActions = config.SQLTablesPrefix + sqlTableEventsActions + sqlTableEventsRules = config.SQLTablesPrefix + sqlTableEventsRules + sqlTableRulesActionsMapping = config.SQLTablesPrefix + sqlTableRulesActionsMapping + sqlTableTasks = config.SQLTablesPrefix + sqlTableTasks + sqlTableNodes = config.SQLTablesPrefix + sqlTableNodes + sqlTableRoles = config.SQLTablesPrefix + sqlTableRoles + sqlTableIPLists = config.SQLTablesPrefix + sqlTableIPLists + sqlTableConfigs = config.SQLTablesPrefix + sqlTableConfigs + sqlTableSchemaVersion = config.SQLTablesPrefix + sqlTableSchemaVersion + providerLog(logger.LevelDebug, "sql table for users %q, folders %q users folders mapping %q admins %q "+ + "api keys %q shares %q defender hosts %q defender events %q transfers %q groups %q "+ + "users groups mapping %q admins groups mapping %q groups folders mapping %q shared sessions %q "+ + "schema version %q events actions %q events rules %q rules actions mapping %q tasks %q nodes %q roles %q"+ + "ip lists %q share groups mapping %q configs %q", + sqlTableUsers, sqlTableFolders, sqlTableUsersFoldersMapping, 
sqlTableAdmins, sqlTableAPIKeys, + sqlTableShares, sqlTableDefenderHosts, sqlTableDefenderEvents, sqlTableActiveTransfers, sqlTableGroups, + sqlTableUsersGroupsMapping, sqlTableAdminsGroupsMapping, sqlTableGroupsFoldersMapping, sqlTableSharedSessions, + sqlTableSchemaVersion, sqlTableEventsActions, sqlTableEventsRules, sqlTableRulesActionsMapping, + sqlTableTasks, sqlTableNodes, sqlTableRoles, sqlTableIPLists, sqlTableSharesGroupsMapping, sqlTableConfigs) + } + return nil +} + +func checkDefaultAdmin() error { + admins, err := provider.getAdmins(1, 0, OrderASC) + if err != nil { + return err + } + if len(admins) > 0 { + return nil + } + logger.Debug(logSender, "", "no admins found, try to create the default one") + // we need to create the default admin + admin := &Admin{} + if err := admin.setFromEnv(); err != nil { + return err + } + return provider.addAdmin(admin) +} + +// InitializeDatabase creates the initial database structure +func InitializeDatabase(cnf Config, basePath string) error { + config = cnf + + if err := initializeHashingAlgo(&cnf); err != nil { + return err + } + + err := createProvider(basePath) + if err != nil { + return err + } + err = provider.initializeDatabase() + if err != nil && err != ErrNoInitRequired { + return err + } + return provider.migrateDatabase() +} + +// RevertDatabase restores schema and/or data to a previous version +func RevertDatabase(cnf Config, basePath string, targetVersion int) error { + config = cnf + + err := createProvider(basePath) + if err != nil { + return err + } + err = provider.initializeDatabase() + if err != nil && err != ErrNoInitRequired { + return err + } + return provider.revertDatabase(targetVersion) +} + +// ResetDatabase restores schema and/or data to a previous version +func ResetDatabase(cnf Config, basePath string) error { + config = cnf + + if err := createProvider(basePath); err != nil { + return err + } + return provider.resetDatabase() +} + +// CheckAdminAndPass validates the given admin and 
password connecting from ip +func CheckAdminAndPass(username, password, ip string) (Admin, error) { + username = config.convertName(username) + return provider.validateAdminAndPass(username, password, ip) +} + +// CheckCachedUserCredentials checks the credentials for a cached user +func CheckCachedUserCredentials(user *CachedUser, password, ip, loginMethod, protocol string, tlsCert *x509.Certificate) (*CachedUser, *User, error) { + if !user.User.skipExternalAuth() && isExternalAuthConfigured(loginMethod) { + u, _, err := CheckCompositeCredentials(user.User.Username, password, ip, loginMethod, protocol, tlsCert) + if err != nil { + return nil, nil, err + } + webDAVUsersCache.swap(&u, password) + cu, _ := webDAVUsersCache.get(u.Username) + return cu, &u, nil + } + if err := user.User.CheckLoginConditions(); err != nil { + return user, nil, err + } + if loginMethod == LoginMethodPassword && user.User.Filters.IsAnonymous { + return user, nil, nil + } + if loginMethod != LoginMethodPassword { + _, err := checkUserAndTLSCertificate(&user.User, protocol, tlsCert) + if err != nil { + return user, nil, err + } + if loginMethod == LoginMethodTLSCertificate { + if !user.User.IsLoginMethodAllowed(LoginMethodTLSCertificate, protocol) { + return user, nil, fmt.Errorf("certificate login method is not allowed for user %q", user.User.Username) + } + return user, nil, nil + } + } + if password == "" { + return user, nil, ErrInvalidCredentials + } + if user.Password != "" { + if password == user.Password { + return user, nil, nil + } + } else { + if ok, _ := isPasswordOK(&user.User, password); ok { + return user, nil, nil + } + } + return user, nil, ErrInvalidCredentials +} + +// CheckCompositeCredentials checks multiple credentials. 
+// WebDAV users can send both a password and a TLS certificate within the same request +func CheckCompositeCredentials(username, password, ip, loginMethod, protocol string, tlsCert *x509.Certificate) (User, string, error) { + username = config.convertName(username) + if loginMethod == LoginMethodPassword { + user, err := CheckUserAndPass(username, password, ip, protocol) + return user, loginMethod, err + } + user, err := CheckUserBeforeTLSAuth(username, ip, protocol, tlsCert) + if err != nil { + return user, loginMethod, err + } + if !user.IsTLSVerificationEnabled() { + // for backward compatibility with 2.0.x we only check the password and change the login method here + // in future updates we have to return an error + user, err := CheckUserAndPass(username, password, ip, protocol) + return user, LoginMethodPassword, err + } + user, err = checkUserAndTLSCertificate(&user, protocol, tlsCert) + if err != nil { + return user, loginMethod, err + } + if loginMethod == LoginMethodTLSCertificate && !user.IsLoginMethodAllowed(LoginMethodTLSCertificate, protocol) { + return user, loginMethod, fmt.Errorf("certificate login method is not allowed for user %q", user.Username) + } + if loginMethod == LoginMethodTLSCertificateAndPwd { + if plugin.Handler.HasAuthScope(plugin.AuthScopePassword) { + user, err = doPluginAuth(username, password, nil, ip, protocol, nil, plugin.AuthScopePassword) + } else if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&1 != 0) { + user, err = doExternalAuth(username, password, nil, "", ip, protocol, nil) + } else if config.PreLoginHook != "" { + user, err = executePreLoginHook(username, LoginMethodPassword, ip, protocol, nil) + } + if err != nil { + return user, loginMethod, err + } + user, err = checkUserAndPass(&user, password, ip, protocol) + } + return user, loginMethod, err +} + +// CheckUserBeforeTLSAuth checks if a user exits before trying mutual TLS +func CheckUserBeforeTLSAuth(username, ip, 
protocol string, tlsCert *x509.Certificate) (User, error) { + username = config.convertName(username) + if plugin.Handler.HasAuthScope(plugin.AuthScopeTLSCertificate) { + user, err := doPluginAuth(username, "", nil, ip, protocol, tlsCert, plugin.AuthScopeTLSCertificate) + if err != nil { + return user, err + } + err = user.LoadAndApplyGroupSettings() + return user, err + } + if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&8 != 0) { + user, err := doExternalAuth(username, "", nil, "", ip, protocol, tlsCert) + if err != nil { + return user, err + } + err = user.LoadAndApplyGroupSettings() + return user, err + } + if config.PreLoginHook != "" { + user, err := executePreLoginHook(username, LoginMethodTLSCertificate, ip, protocol, nil) + if err != nil { + return user, err + } + err = user.LoadAndApplyGroupSettings() + return user, err + } + user, err := UserExists(username, "") + if err != nil { + return user, err + } + err = user.LoadAndApplyGroupSettings() + return user, err +} + +// CheckUserAndTLSCert returns the SFTPGo user with the given username and check if the +// given TLS certificate allow authentication without password +func CheckUserAndTLSCert(username, ip, protocol string, tlsCert *x509.Certificate) (User, error) { + username = config.convertName(username) + if plugin.Handler.HasAuthScope(plugin.AuthScopeTLSCertificate) { + user, err := doPluginAuth(username, "", nil, ip, protocol, tlsCert, plugin.AuthScopeTLSCertificate) + if err != nil { + return user, err + } + return checkUserAndTLSCertificate(&user, protocol, tlsCert) + } + if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&8 != 0) { + user, err := doExternalAuth(username, "", nil, "", ip, protocol, tlsCert) + if err != nil { + return user, err + } + return checkUserAndTLSCertificate(&user, protocol, tlsCert) + } + if config.PreLoginHook != "" { + user, err := executePreLoginHook(username, 
LoginMethodTLSCertificate, ip, protocol, nil) + if err != nil { + return user, err + } + return checkUserAndTLSCertificate(&user, protocol, tlsCert) + } + return provider.validateUserAndTLSCert(username, protocol, tlsCert) +} + +// CheckUserAndPass retrieves the SFTPGo user with the given username and password if a match is found or an error +func CheckUserAndPass(username, password, ip, protocol string) (User, error) { + username = config.convertName(username) + if plugin.Handler.HasAuthScope(plugin.AuthScopePassword) { + user, err := doPluginAuth(username, password, nil, ip, protocol, nil, plugin.AuthScopePassword) + if err != nil { + return user, err + } + return checkUserAndPass(&user, password, ip, protocol) + } + if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&1 != 0) { + user, err := doExternalAuth(username, password, nil, "", ip, protocol, nil) + if err != nil { + return user, err + } + return checkUserAndPass(&user, password, ip, protocol) + } + if config.PreLoginHook != "" { + user, err := executePreLoginHook(username, LoginMethodPassword, ip, protocol, nil) + if err != nil { + return user, err + } + return checkUserAndPass(&user, password, ip, protocol) + } + return provider.validateUserAndPass(username, password, ip, protocol) +} + +// CheckUserAndPubKey retrieves the SFTP user with the given username and public key if a match is found or an error +func CheckUserAndPubKey(username string, pubKey []byte, ip, protocol string, isSSHCert bool) (User, string, error) { + username = config.convertName(username) + if plugin.Handler.HasAuthScope(plugin.AuthScopePublicKey) { + user, err := doPluginAuth(username, "", pubKey, ip, protocol, nil, plugin.AuthScopePublicKey) + if err != nil { + return user, "", err + } + return checkUserAndPubKey(&user, pubKey, isSSHCert) + } + if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&2 != 0) { + user, err := doExternalAuth(username, 
"", pubKey, "", ip, protocol, nil) + if err != nil { + return user, "", err + } + return checkUserAndPubKey(&user, pubKey, isSSHCert) + } + if config.PreLoginHook != "" { + user, err := executePreLoginHook(username, SSHLoginMethodPublicKey, ip, protocol, nil) + if err != nil { + return user, "", err + } + return checkUserAndPubKey(&user, pubKey, isSSHCert) + } + return provider.validateUserAndPubKey(username, pubKey, isSSHCert) +} + +// CheckKeyboardInteractiveAuth checks the keyboard interactive authentication and returns +// the authenticated user or an error +func CheckKeyboardInteractiveAuth(username, authHook string, client ssh.KeyboardInteractiveChallenge, + ip, protocol string, isPartialAuth bool, +) (User, error) { + var user User + var err error + username = config.convertName(username) + if plugin.Handler.HasAuthScope(plugin.AuthScopeKeyboardInteractive) { + user, err = doPluginAuth(username, "", nil, ip, protocol, nil, plugin.AuthScopeKeyboardInteractive) + } else if config.ExternalAuthHook != "" && (config.ExternalAuthScope == 0 || config.ExternalAuthScope&4 != 0) { + user, err = doExternalAuth(username, "", nil, "1", ip, protocol, nil) + } else if config.PreLoginHook != "" { + user, err = executePreLoginHook(username, SSHLoginMethodKeyboardInteractive, ip, protocol, nil) + } else { + user, err = provider.userExists(username, "") + } + if err != nil { + return user, err + } + return doKeyboardInteractiveAuth(&user, authHook, client, ip, protocol, isPartialAuth) +} + +// GetFTPPreAuthUser returns the SFTPGo user with the specified username +// after receiving the FTP "USER" command. 
+// If a pre-login hook is defined it will be executed so the SFTPGo user +// can be created if it does not exist +func GetFTPPreAuthUser(username, ip string) (User, error) { + var user User + var err error + if config.PreLoginHook != "" { + user, err = executePreLoginHook(username, "", ip, protocolFTP, nil) + } else { + user, err = UserExists(username, "") + } + if err != nil { + return user, err + } + err = user.LoadAndApplyGroupSettings() + return user, err +} + +// GetUserAfterIDPAuth returns the SFTPGo user with the specified username +// after a successful authentication with an external identity provider. +// If a pre-login hook is defined it will be executed so the SFTPGo user +// can be created if it does not exist +func GetUserAfterIDPAuth(username, ip, protocol string, oidcTokenFields *map[string]any) (User, error) { + var user User + var err error + if config.PreLoginHook != "" { + user, err = executePreLoginHook(username, LoginMethodIDP, ip, protocol, oidcTokenFields) + user.Filters.RequirePasswordChange = false + } else { + user, err = UserExists(username, "") + } + if err != nil { + return user, err + } + err = user.LoadAndApplyGroupSettings() + return user, err +} + +// GetDefenderHosts returns hosts that are banned or for which some violations have been detected +func GetDefenderHosts(from int64, limit int) ([]DefenderEntry, error) { + return provider.getDefenderHosts(from, limit) +} + +// GetDefenderHostByIP returns a defender host by ip, if any +func GetDefenderHostByIP(ip string, from int64) (DefenderEntry, error) { + return provider.getDefenderHostByIP(ip, from) +} + +// IsDefenderHostBanned returns a defender entry and no error if the specified host is banned +func IsDefenderHostBanned(ip string) (DefenderEntry, error) { + return provider.isDefenderHostBanned(ip) +} + +// UpdateDefenderBanTime increments ban time for the specified ip +func UpdateDefenderBanTime(ip string, minutes int) error { + return provider.updateDefenderBanTime(ip, 
minutes) +} + +// DeleteDefenderHost removes the specified IP from the defender lists +func DeleteDefenderHost(ip string) error { + return provider.deleteDefenderHost(ip) +} + +// AddDefenderEvent adds an event for the given IP with the given score +// and returns the host with the updated score +func AddDefenderEvent(ip string, score int, from int64) (DefenderEntry, error) { + if err := provider.addDefenderEvent(ip, score); err != nil { + return DefenderEntry{}, err + } + return provider.getDefenderHostByIP(ip, from) +} + +// SetDefenderBanTime sets the ban time for the specified IP +func SetDefenderBanTime(ip string, banTime int64) error { + return provider.setDefenderBanTime(ip, banTime) +} + +// CleanupDefender removes events and hosts older than "from" from the data provider +func CleanupDefender(from int64) error { + return provider.cleanupDefender(from) +} + +// UpdateShareLastUse updates the LastUseAt and UsedTokens for the given share +func UpdateShareLastUse(share *Share, numTokens int) error { + return provider.updateShareLastUse(share.ShareID, numTokens) +} + +// UpdateAPIKeyLastUse updates the LastUseAt field for the given API key +func UpdateAPIKeyLastUse(apiKey *APIKey) error { + lastUse := util.GetTimeFromMsecSinceEpoch(apiKey.LastUseAt) + diff := -time.Until(lastUse) + if diff < 0 || diff > lastLoginMinDelay { + return provider.updateAPIKeyLastUse(apiKey.KeyID) + } + return nil +} + +// UpdateLastLogin updates the last login field for the given SFTPGo user +func UpdateLastLogin(user *User) { + delay := lastLoginMinDelay + if user.Filters.ExternalAuthCacheTime > 0 { + delay = time.Duration(user.Filters.ExternalAuthCacheTime) * time.Second + } + if user.LastLogin <= user.UpdatedAt || !isLastActivityRecent(user.LastLogin, delay) { + err := provider.updateLastLogin(user.Username) + if err == nil { + webDAVUsersCache.updateLastLogin(user.Username) + } + } +} + +// UpdateAdminLastLogin updates the last login field for the given SFTPGo admin +func 
UpdateAdminLastLogin(admin *Admin) { + if !isLastActivityRecent(admin.LastLogin, lastLoginMinDelay) { + provider.updateAdminLastLogin(admin.Username) //nolint:errcheck + } +} + +// UpdateUserQuota updates the quota for the given SFTPGo user adding filesAdd and sizeAdd. +// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. +func UpdateUserQuota(user *User, filesAdd int, sizeAdd int64, reset bool) error { + if config.TrackQuota == 0 { + return util.NewMethodDisabledError(trackQuotaDisabledError) + } else if config.TrackQuota == 2 && !reset && !user.HasQuotaRestrictions() { + return nil + } + if filesAdd == 0 && sizeAdd == 0 && !reset { + return nil + } + if config.DelayedQuotaUpdate == 0 || reset { + if reset { + delayedQuotaUpdater.resetUserQuota(user.Username) + } + return provider.updateQuota(user.Username, filesAdd, sizeAdd, reset) + } + delayedQuotaUpdater.updateUserQuota(user.Username, filesAdd, sizeAdd) + return nil +} + +// UpdateUserFolderQuota updates the quota for the given user and virtual folder. +func UpdateUserFolderQuota(folder *vfs.VirtualFolder, user *User, filesAdd int, sizeAdd int64, reset bool) { + if folder.IsIncludedInUserQuota() { + UpdateUserQuota(user, filesAdd, sizeAdd, reset) //nolint:errcheck + return + } + UpdateVirtualFolderQuota(&folder.BaseVirtualFolder, filesAdd, sizeAdd, reset) //nolint:errcheck +} + +// UpdateVirtualFolderQuota updates the quota for the given virtual folder adding filesAdd and sizeAdd. +// If reset is true filesAdd and sizeAdd indicates the total files and the total size instead of the difference. 
+func UpdateVirtualFolderQuota(vfolder *vfs.BaseVirtualFolder, filesAdd int, sizeAdd int64, reset bool) error { + if config.TrackQuota == 0 { + return util.NewMethodDisabledError(trackQuotaDisabledError) + } + if filesAdd == 0 && sizeAdd == 0 && !reset { + return nil + } + if config.DelayedQuotaUpdate == 0 || reset { + if reset { + delayedQuotaUpdater.resetFolderQuota(vfolder.Name) + } + return provider.updateFolderQuota(vfolder.Name, filesAdd, sizeAdd, reset) + } + delayedQuotaUpdater.updateFolderQuota(vfolder.Name, filesAdd, sizeAdd) + return nil +} + +// UpdateUserTransferQuota updates the transfer quota for the given SFTPGo user. +// If reset is true uploadSize and downloadSize indicates the actual sizes instead of the difference. +func UpdateUserTransferQuota(user *User, uploadSize, downloadSize int64, reset bool) error { + if config.TrackQuota == 0 { + return util.NewMethodDisabledError(trackQuotaDisabledError) + } else if config.TrackQuota == 2 && !reset && !user.HasTransferQuotaRestrictions() { + return nil + } + if downloadSize == 0 && uploadSize == 0 && !reset { + return nil + } + if config.DelayedQuotaUpdate == 0 || reset { + if reset { + delayedQuotaUpdater.resetUserTransferQuota(user.Username) + } + return provider.updateTransferQuota(user.Username, uploadSize, downloadSize, reset) + } + delayedQuotaUpdater.updateUserTransferQuota(user.Username, uploadSize, downloadSize) + return nil +} + +// UpdateUserTransferTimestamps updates the first download/upload fields if unset +func UpdateUserTransferTimestamps(username string, isUpload bool) error { + if isUpload { + err := provider.setFirstUploadTimestamp(username) + if err != nil { + providerLog(logger.LevelWarn, "unable to set first upload: %v", err) + } + return err + } + err := provider.setFirstDownloadTimestamp(username) + if err != nil { + providerLog(logger.LevelWarn, "unable to set first download: %v", err) + } + return err +} + +// GetUsedQuota returns the used quota for the given SFTPGo user. 
+func GetUsedQuota(username string) (int, int64, int64, int64, error) { + if config.TrackQuota == 0 { + return 0, 0, 0, 0, util.NewMethodDisabledError(trackQuotaDisabledError) + } + files, size, ulTransferSize, dlTransferSize, err := provider.getUsedQuota(username) + if err != nil { + return files, size, ulTransferSize, dlTransferSize, err + } + delayedFiles, delayedSize := delayedQuotaUpdater.getUserPendingQuota(username) + delayedUlTransferSize, delayedDLTransferSize := delayedQuotaUpdater.getUserPendingTransferQuota(username) + + return files + delayedFiles, size + delayedSize, ulTransferSize + delayedUlTransferSize, + dlTransferSize + delayedDLTransferSize, err +} + +// GetUsedVirtualFolderQuota returns the used quota for the given virtual folder. +func GetUsedVirtualFolderQuota(name string) (int, int64, error) { + if config.TrackQuota == 0 { + return 0, 0, util.NewMethodDisabledError(trackQuotaDisabledError) + } + files, size, err := provider.getUsedFolderQuota(name) + if err != nil { + return files, size, err + } + delayedFiles, delayedSize := delayedQuotaUpdater.getFolderPendingQuota(name) + return files + delayedFiles, size + delayedSize, err +} + +// GetConfigs returns the configurations +func GetConfigs() (Configs, error) { + return provider.getConfigs() +} + +// UpdateConfigs updates configurations +func UpdateConfigs(configs *Configs, executor, ipAddress, role string) error { + if configs == nil { + configs = &Configs{} + } else { + configs.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + } + err := provider.setConfigs(configs) + if err == nil { + executeAction(operationUpdate, executor, ipAddress, actionObjectConfigs, "configs", role, configs) + } + return err +} + +// AddShare adds a new share +func AddShare(share *Share, executor, ipAddress, role string) error { + err := provider.addShare(share) + if err == nil { + executeAction(operationAdd, executor, ipAddress, actionObjectShare, share.ShareID, role, share) + } + return err +} + +// UpdateShare 
updates an existing share +func UpdateShare(share *Share, executor, ipAddress, role string) error { + err := provider.updateShare(share) + if err == nil { + executeAction(operationUpdate, executor, ipAddress, actionObjectShare, share.ShareID, role, share) + } + return err +} + +// DeleteShare deletes an existing share +func DeleteShare(shareID string, executor, ipAddress, role string) error { + share, err := provider.shareExists(shareID, executor) + if err != nil { + return err + } + err = provider.deleteShare(share) + if err == nil { + executeAction(operationDelete, executor, ipAddress, actionObjectShare, shareID, role, &share) + } + return err +} + +// ShareExists returns the share with the given ID if it exists +func ShareExists(shareID, username string) (Share, error) { + if shareID == "" { + return Share{}, util.NewRecordNotFoundError(fmt.Sprintf("Share %q does not exist", shareID)) + } + return provider.shareExists(shareID, username) +} + +// AddIPListEntry adds a new IP list entry +func AddIPListEntry(entry *IPListEntry, executor, ipAddress, executorRole string) error { + err := provider.addIPListEntry(entry) + if err == nil { + executeAction(operationAdd, executor, ipAddress, actionObjectIPListEntry, entry.getName(), executorRole, entry) + for _, l := range inMemoryLists { + l.addEntry(entry) + } + } + return err +} + +// UpdateIPListEntry updates an existing IP list entry +func UpdateIPListEntry(entry *IPListEntry, executor, ipAddress, executorRole string) error { + err := provider.updateIPListEntry(entry) + if err == nil { + executeAction(operationUpdate, executor, ipAddress, actionObjectIPListEntry, entry.getName(), executorRole, entry) + for _, l := range inMemoryLists { + l.updateEntry(entry) + } + } + return err +} + +// DeleteIPListEntry deletes an existing IP list entry +func DeleteIPListEntry(ipOrNet string, listType IPListType, executor, ipAddress, executorRole string) error { + entry, err := provider.ipListEntryExists(ipOrNet, listType) + if err 
!= nil { + return err + } + err = provider.deleteIPListEntry(entry, config.IsShared == 1) + if err == nil { + executeAction(operationDelete, executor, ipAddress, actionObjectIPListEntry, entry.getName(), executorRole, &entry) + for _, l := range inMemoryLists { + l.removeEntry(&entry) + } + } + return err +} + +// IPListEntryExists returns the IP list entry with the given IP/net and type if it exists +func IPListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { + return provider.ipListEntryExists(ipOrNet, listType) +} + +// GetIPListEntries returns the IP list entries applying the specified criteria and search limit +func GetIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { + if !slices.Contains(supportedIPListType, listType) { + return nil, util.NewValidationError(fmt.Sprintf("invalid list type %d", listType)) + } + return provider.getIPListEntries(listType, filter, from, order, limit) +} + +// AddRole adds a new role +func AddRole(role *Role, executor, ipAddress, executorRole string) error { + role.Name = config.convertName(role.Name) + err := provider.addRole(role) + if err == nil { + executeAction(operationAdd, executor, ipAddress, actionObjectRole, role.Name, executorRole, role) + } + return err +} + +// UpdateRole updates an existing Role +func UpdateRole(role *Role, executor, ipAddress, executorRole string) error { + err := provider.updateRole(role) + if err == nil { + executeAction(operationUpdate, executor, ipAddress, actionObjectRole, role.Name, executorRole, role) + } + return err +} + +// DeleteRole deletes an existing Role +func DeleteRole(name string, executor, ipAddress, executorRole string) error { + name = config.convertName(name) + role, err := provider.roleExists(name) + if err != nil { + return err + } + if len(role.Admins) > 0 { + errorString := fmt.Sprintf("the role %q is referenced, it cannot be removed", role.Name) + return util.NewValidationError(errorString) + } + err = 
provider.deleteRole(role) + if err == nil { + executeAction(operationDelete, executor, ipAddress, actionObjectRole, role.Name, executorRole, &role) + for _, user := range role.Users { + provider.setUpdatedAt(user) + u, err := provider.userExists(user, "") + if err == nil { + webDAVUsersCache.swap(&u, "") + executeAction(operationUpdate, executor, ipAddress, actionObjectUser, u.Username, u.Role, &u) + } + } + } + return err +} + +// RoleExists returns the Role with the given name if it exists +func RoleExists(name string) (Role, error) { + name = config.convertName(name) + return provider.roleExists(name) +} + +// AddGroup adds a new group +func AddGroup(group *Group, executor, ipAddress, role string) error { + group.Name = config.convertName(group.Name) + err := provider.addGroup(group) + if err == nil { + executeAction(operationAdd, executor, ipAddress, actionObjectGroup, group.Name, role, group) + } + return err +} + +// UpdateGroup updates an existing Group +func UpdateGroup(group *Group, users []string, executor, ipAddress, role string) error { + err := provider.updateGroup(group) + if err == nil { + for _, user := range users { + provider.setUpdatedAt(user) + u, err := provider.userExists(user, "") + if err == nil { + webDAVUsersCache.swap(&u, "") + } else { + RemoveCachedWebDAVUser(user) + } + } + executeAction(operationUpdate, executor, ipAddress, actionObjectGroup, group.Name, role, group) + } + return err +} + +// DeleteGroup deletes an existing Group +func DeleteGroup(name string, executor, ipAddress, role string) error { + name = config.convertName(name) + group, err := provider.groupExists(name) + if err != nil { + return err + } + if len(group.Users) > 0 { + errorString := fmt.Sprintf("the group %q is referenced, it cannot be removed", group.Name) + return util.NewValidationError(errorString) + } + err = provider.deleteGroup(group) + if err == nil { + for _, user := range group.Users { + provider.setUpdatedAt(user) + u, err := provider.userExists(user, 
"") + if err == nil { + executeAction(operationUpdate, executor, ipAddress, actionObjectUser, u.Username, u.Role, &u) + } + RemoveCachedWebDAVUser(user) + } + executeAction(operationDelete, executor, ipAddress, actionObjectGroup, group.Name, role, &group) + } + return err +} + +// GroupExists returns the Group with the given name if it exists +func GroupExists(name string) (Group, error) { + name = config.convertName(name) + return provider.groupExists(name) +} + +// AddAPIKey adds a new API key +func AddAPIKey(apiKey *APIKey, executor, ipAddress, role string) error { + err := provider.addAPIKey(apiKey) + if err == nil { + executeAction(operationAdd, executor, ipAddress, actionObjectAPIKey, apiKey.KeyID, role, apiKey) + } + return err +} + +// UpdateAPIKey updates an existing API key +func UpdateAPIKey(apiKey *APIKey, executor, ipAddress, role string) error { + err := provider.updateAPIKey(apiKey) + if err == nil { + executeAction(operationUpdate, executor, ipAddress, actionObjectAPIKey, apiKey.KeyID, role, apiKey) + } + return err +} + +// DeleteAPIKey deletes an existing API key +func DeleteAPIKey(keyID string, executor, ipAddress, role string) error { + apiKey, err := provider.apiKeyExists(keyID) + if err != nil { + return err + } + err = provider.deleteAPIKey(apiKey) + if err == nil { + executeAction(operationDelete, executor, ipAddress, actionObjectAPIKey, apiKey.KeyID, role, &apiKey) + cachedAPIKeys.Remove(keyID) + } + return err +} + +// APIKeyExists returns the API key with the given ID if it exists +func APIKeyExists(keyID string) (APIKey, error) { + if keyID == "" { + return APIKey{}, util.NewRecordNotFoundError(fmt.Sprintf("API key %q does not exist", keyID)) + } + return provider.apiKeyExists(keyID) +} + +// GetEventActions returns an array of event actions respecting limit and offset +func GetEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) { + return provider.getEventActions(limit, offset, order, minimal) +} + +// 
EventActionExists returns the event action with the given name if it exists +func EventActionExists(name string) (BaseEventAction, error) { + name = config.convertName(name) + return provider.eventActionExists(name) +} + +// AddEventAction adds a new event action +func AddEventAction(action *BaseEventAction, executor, ipAddress, role string) error { + action.Name = config.convertName(action.Name) + err := provider.addEventAction(action) + if err == nil { + executeAction(operationAdd, executor, ipAddress, actionObjectEventAction, action.Name, role, action) + } + return err +} + +// UpdateEventAction updates an existing event action +func UpdateEventAction(action *BaseEventAction, executor, ipAddress, role string) error { + err := provider.updateEventAction(action) + if err == nil { + if fnReloadRules != nil { + fnReloadRules() + } + executeAction(operationUpdate, executor, ipAddress, actionObjectEventAction, action.Name, role, action) + } + return err +} + +// DeleteEventAction deletes an existing event action +func DeleteEventAction(name string, executor, ipAddress, role string) error { + name = config.convertName(name) + action, err := provider.eventActionExists(name) + if err != nil { + return err + } + if len(action.Rules) > 0 { + errorString := fmt.Sprintf("the event action %#q is referenced, it cannot be removed", action.Name) + return util.NewValidationError(errorString) + } + err = provider.deleteEventAction(action) + if err == nil { + executeAction(operationDelete, executor, ipAddress, actionObjectEventAction, action.Name, role, &action) + } + return err +} + +// GetEventRules returns an array of event rules respecting limit and offset +func GetEventRules(limit, offset int, order string) ([]EventRule, error) { + return provider.getEventRules(limit, offset, order) +} + +// GetRecentlyUpdatedRules returns the event rules updated after the specified time +func GetRecentlyUpdatedRules(after int64) ([]EventRule, error) { + return 
provider.getRecentlyUpdatedRules(after) +} + +// EventRuleExists returns the event rule with the given name if it exists +func EventRuleExists(name string) (EventRule, error) { + name = config.convertName(name) + return provider.eventRuleExists(name) +} + +// AddEventRule adds a new event rule +func AddEventRule(rule *EventRule, executor, ipAddress, role string) error { + rule.Name = config.convertName(rule.Name) + err := provider.addEventRule(rule) + if err == nil { + if fnReloadRules != nil { + fnReloadRules() + } + executeAction(operationAdd, executor, ipAddress, actionObjectEventRule, rule.Name, role, rule) + } + return err +} + +// UpdateEventRule updates an existing event rule +func UpdateEventRule(rule *EventRule, executor, ipAddress, role string) error { + err := provider.updateEventRule(rule) + if err == nil { + if fnReloadRules != nil { + fnReloadRules() + } + executeAction(operationUpdate, executor, ipAddress, actionObjectEventRule, rule.Name, role, rule) + } + return err +} + +// DeleteEventRule deletes an existing event rule +func DeleteEventRule(name string, executor, ipAddress, role string) error { + name = config.convertName(name) + rule, err := provider.eventRuleExists(name) + if err != nil { + return err + } + err = provider.deleteEventRule(rule, config.IsShared == 1) + if err == nil { + if fnRemoveRule != nil { + fnRemoveRule(rule.Name) + } + executeAction(operationDelete, executor, ipAddress, actionObjectEventRule, rule.Name, role, &rule) + } + return err +} + +// RemoveEventRule delets an existing event rule without marking it as deleted +func RemoveEventRule(rule EventRule) error { + return provider.deleteEventRule(rule, false) +} + +// GetTaskByName returns the task with the specified name +func GetTaskByName(name string) (Task, error) { + return provider.getTaskByName(name) +} + +// AddTask add a task with the specified name +func AddTask(name string) error { + return provider.addTask(name) +} + +// UpdateTask updates the task with the 
specified name and version +func UpdateTask(name string, version int64) error { + return provider.updateTask(name, version) +} + +// UpdateTaskTimestamp updates the timestamp for the task with the specified name +func UpdateTaskTimestamp(name string) error { + return provider.updateTaskTimestamp(name) +} + +// GetNodes returns the other cluster nodes +func GetNodes() ([]Node, error) { + if currentNode == nil { + return nil, nil + } + nodes, err := provider.getNodes() + if err != nil { + providerLog(logger.LevelError, "unable to get other cluster nodes %v", err) + } + return nodes, err +} + +// GetNodeByName returns a node, different from the current one, by name +func GetNodeByName(name string) (Node, error) { + if currentNode == nil { + return Node{}, util.NewRecordNotFoundError(errNoClusterNodes.Error()) + } + if name == currentNode.Name { + return Node{}, util.NewValidationError(fmt.Sprintf("%s is the current node, it must refer to other nodes", name)) + } + return provider.getNodeByName(name) +} + +// HasAdmin returns true if the first admin has been created +// and so SFTPGo is ready to be used +func HasAdmin() bool { + return isAdminCreated.Load() +} + +// AddAdmin adds a new SFTPGo admin +func AddAdmin(admin *Admin, executor, ipAddress, role string) error { + admin.Filters.RecoveryCodes = nil + admin.Filters.TOTPConfig = AdminTOTPConfig{ + Enabled: false, + } + admin.Username = config.convertName(admin.Username) + err := provider.addAdmin(admin) + if err == nil { + isAdminCreated.Store(true) + executeAction(operationAdd, executor, ipAddress, actionObjectAdmin, admin.Username, role, admin) + } + return err +} + +// UpdateAdmin updates an existing SFTPGo admin +func UpdateAdmin(admin *Admin, executor, ipAddress, role string) error { + err := provider.updateAdmin(admin) + if err == nil { + executeAction(operationUpdate, executor, ipAddress, actionObjectAdmin, admin.Username, role, admin) + } + return err +} + +// DeleteAdmin deletes an existing SFTPGo admin 
+func DeleteAdmin(username, executor, ipAddress, role string) error { + username = config.convertName(username) + admin, err := provider.adminExists(username) + if err != nil { + return err + } + err = provider.deleteAdmin(admin) + if err == nil { + executeAction(operationDelete, executor, ipAddress, actionObjectAdmin, admin.Username, role, &admin) + cachedAdminPasswords.Remove(username) + } + return err +} + +// AdminExists returns the admin with the given username if it exists +func AdminExists(username string) (Admin, error) { + username = config.convertName(username) + return provider.adminExists(username) +} + +// UserExists checks if the given SFTPGo username exists, returns an error if no match is found +func UserExists(username, role string) (User, error) { + username = config.convertName(username) + return provider.userExists(username, role) +} + +// GetAdminSignature returns the signature for the admin with the specified +// username. +func GetAdminSignature(username string) (string, error) { + username = config.convertName(username) + return provider.getAdminSignature(username) +} + +// GetUserSignature returns the signature for the user with the specified +// username. 
+func GetUserSignature(username string) (string, error) { + username = config.convertName(username) + return provider.getUserSignature(username) +} + +// GetUserWithGroupSettings tries to return the user with the specified username +// loading also the group settings +func GetUserWithGroupSettings(username, role string) (User, error) { + username = config.convertName(username) + user, err := provider.userExists(username, role) + if err != nil { + return user, err + } + err = user.LoadAndApplyGroupSettings() + return user, err +} + +// GetUserVariants tries to return the user with the specified username with and without +// group settings applied +func GetUserVariants(username, role string) (User, User, error) { + username = config.convertName(username) + user, err := provider.userExists(username, role) + if err != nil { + return user, User{}, err + } + userWithGroupSettings := user.getACopy() + err = userWithGroupSettings.LoadAndApplyGroupSettings() + return user, userWithGroupSettings, err +} + +// AddUser adds a new SFTPGo user. 
+func AddUser(user *User, executor, ipAddress, role string) error { + user.Username = config.convertName(user.Username) + err := provider.addUser(user) + if err == nil { + executeAction(operationAdd, executor, ipAddress, actionObjectUser, user.Username, role, user) + } + return err +} + +// UpdateUserPassword updates the user password +func UpdateUserPassword(username, plainPwd, executor, ipAddress, role string) error { + user, err := provider.userExists(username, role) + if err != nil { + return err + } + userCopy := user.getACopy() + userCopy.Password = plainPwd + if err := createUserPasswordHash(&userCopy); err != nil { + return err + } + user.LastPasswordChange = userCopy.LastPasswordChange + user.Password = userCopy.Password + user.Filters.RequirePasswordChange = false + // the last password change is set when validating the user + if err := provider.updateUser(&user); err != nil { + return err + } + webDAVUsersCache.swap(&user, plainPwd) + executeAction(operationUpdate, executor, ipAddress, actionObjectUser, username, role, &user) + return nil +} + +// UpdateUser updates an existing SFTPGo user. +func UpdateUser(user *User, executor, ipAddress, role string) error { + if user.groupSettingsApplied { + return errors.New("cannot save a user with group settings applied") + } + err := provider.updateUser(user) + if err == nil { + webDAVUsersCache.swap(user, "") + executeAction(operationUpdate, executor, ipAddress, actionObjectUser, user.Username, role, user) + } + return err +} + +// DeleteUser deletes an existing SFTPGo user. 
+func DeleteUser(username, executor, ipAddress, role string) error { + username = config.convertName(username) + user, err := provider.userExists(username, role) + if err != nil { + return err + } + err = provider.deleteUser(user, config.IsShared == 1) + if err == nil { + RemoveCachedWebDAVUser(user.Username) + delayedQuotaUpdater.resetUserQuota(user.Username) + cachedUserPasswords.Remove(username) + executeAction(operationDelete, executor, ipAddress, actionObjectUser, user.Username, role, &user) + } + return err +} + +// AddActiveTransfer stores the specified transfer +func AddActiveTransfer(transfer ActiveTransfer) { + if err := provider.addActiveTransfer(transfer); err != nil { + providerLog(logger.LevelError, "unable to add transfer id %v, connection id %v: %v", + transfer.ID, transfer.ConnID, err) + } +} + +// UpdateActiveTransferSizes updates the current upload and download sizes for the specified transfer +func UpdateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) { + if err := provider.updateActiveTransferSizes(ulSize, dlSize, transferID, connectionID); err != nil { + providerLog(logger.LevelError, "unable to update sizes for transfer id %v, connection id %v: %v", + transferID, connectionID, err) + } +} + +// RemoveActiveTransfer removes the specified transfer +func RemoveActiveTransfer(transferID int64, connectionID string) { + if err := provider.removeActiveTransfer(transferID, connectionID); err != nil { + providerLog(logger.LevelError, "unable to delete transfer id %v, connection id %v: %v", + transferID, connectionID, err) + } +} + +// CleanupActiveTransfers removes the transfer before the specified time +func CleanupActiveTransfers(before time.Time) error { + err := provider.cleanupActiveTransfers(before) + if err == nil { + providerLog(logger.LevelDebug, "deleted active transfers updated before: %v", before) + } else { + providerLog(logger.LevelError, "error deleting active transfers updated before %v: %v", before, err) + } 
+ return err +} + +// GetActiveTransfers retrieves the active transfers with an update time after the specified value +func GetActiveTransfers(from time.Time) ([]ActiveTransfer, error) { + return provider.getActiveTransfers(from) +} + +// AddSharedSession stores a new session within the data provider +func AddSharedSession(session Session) error { + err := provider.addSharedSession(session) + if err != nil { + providerLog(logger.LevelError, "unable to add shared session, key %q, type: %v, err: %v", + session.Key, session.Type, err) + } + return err +} + +// DeleteSharedSession deletes the session with the specified key +func DeleteSharedSession(key string, sessionType SessionType) error { + err := provider.deleteSharedSession(key, sessionType) + if err != nil { + providerLog(logger.LevelError, "unable to add shared session, key %q, err: %v", key, err) + } + return err +} + +// GetSharedSession retrieves the session with the specified key +func GetSharedSession(key string, sessionType SessionType) (Session, error) { + return provider.getSharedSession(key, sessionType) +} + +// CleanupSharedSessions removes the shared session with the specified type and +// before the specified time +func CleanupSharedSessions(sessionType SessionType, before time.Time) error { + err := provider.cleanupSharedSessions(sessionType, util.GetTimeAsMsSinceEpoch(before)) + if err == nil { + providerLog(logger.LevelDebug, "deleted shared sessions before: %v, type: %v", before, sessionType) + } else { + providerLog(logger.LevelError, "error deleting shared session before %v, type %v: %v", before, sessionType, err) + } + return err +} + +// ReloadConfig reloads provider configuration. 
// Currently only implemented for memory provider, allows to reload the users
// from the configured file, if defined
func ReloadConfig() error {
	return provider.reloadConfig()
}

// GetShares returns an array of shares respecting limit and offset.
// Results are presumably restricted to the given username — verify against
// the provider implementations.
func GetShares(limit, offset int, order, username string) ([]Share, error) {
	return provider.getShares(limit, offset, order, username)
}

// GetAPIKeys returns an array of API keys respecting limit and offset
func GetAPIKeys(limit, offset int, order string) ([]APIKey, error) {
	return provider.getAPIKeys(limit, offset, order)
}

// GetAdmins returns an array of admins respecting limit and offset
func GetAdmins(limit, offset int, order string) ([]Admin, error) {
	return provider.getAdmins(limit, offset, order)
}

// GetRoles returns an array of roles respecting limit and offset.
// The minimal flag is forwarded to the provider — presumably it requests a
// reduced field set; confirm in the provider implementations.
func GetRoles(limit, offset int, order string, minimal bool) ([]Role, error) {
	return provider.getRoles(limit, offset, order, minimal)
}

// GetGroups returns an array of groups respecting limit and offset
func GetGroups(limit, offset int, order string, minimal bool) ([]Group, error) {
	return provider.getGroups(limit, offset, order, minimal)
}

// GetUsers returns an array of users respecting limit and offset
func GetUsers(limit, offset int, order, role string) ([]User, error) {
	return provider.getUsers(limit, offset, order, role)
}

// GetUsersForQuotaCheck returns the users with the fields required for a quota check
func GetUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
	return provider.getUsersForQuotaCheck(toFetch)
}

// AddFolder adds a new virtual folder.
func AddFolder(folder *vfs.BaseVirtualFolder, executor, ipAddress, role string) error {
	folder.Name = config.convertName(folder.Name)
	err := provider.addFolder(folder)
	if err == nil {
		// run the post-add action only if the folder was persisted
		executeAction(operationAdd, executor, ipAddress, actionObjectFolder, folder.Name, role, &wrappedFolder{Folder: *folder})
	}
	return err
}

// UpdateFolder updates the specified virtual folder
func UpdateFolder(folder *vfs.BaseVirtualFolder, users []string, groups []string, executor, ipAddress, role string) error {
	err := provider.updateFolder(folder)
	if err == nil {
		executeAction(operationUpdate, executor, ipAddress, actionObjectFolder, folder.Name, role, &wrappedFolder{Folder: *folder})
		// the change also affects members of the given groups: merge them
		// with the explicitly passed users
		usersInGroups, errGrp := provider.getUsersInGroups(groups)
		if errGrp == nil {
			users = append(users, usersInGroups...)
			users = util.RemoveDuplicates(users, false)
		} else {
			// best effort: log and continue with the explicit users only
			providerLog(logger.LevelWarn, "unable to get users in groups %+v: %v", groups, errGrp)
		}
		for _, user := range users {
			// bump the user's update time — presumably so stale cached
			// copies elsewhere can be detected
			provider.setUpdatedAt(user)
			u, err := provider.userExists(user, "")
			if err == nil {
				// refresh the cached WebDAV user with the new definition
				webDAVUsersCache.swap(&u, "")
				executeAction(operationUpdate, executor, ipAddress, actionObjectUser, u.Username, u.Role, &u)
			} else {
				// the user cannot be reloaded: drop it from the WebDAV cache
				RemoveCachedWebDAVUser(user)
			}
		}
	}
	return err
}

// DeleteFolder deletes an existing folder.
func DeleteFolder(folderName, executor, ipAddress, role string) error {
	folderName = config.convertName(folderName)
	folder, err := provider.getFolderByName(folderName)
	if err != nil {
		return err
	}
	err = provider.deleteFolder(folder)
	if err == nil {
		executeAction(operationDelete, executor, ipAddress, actionObjectFolder, folder.Name, role, &wrappedFolder{Folder: folder})
		// notify users referencing the folder directly or via their groups
		users := folder.Users
		usersInGroups, errGrp := provider.getUsersInGroups(folder.Groups)
		if errGrp == nil {
			users = append(users, usersInGroups...)
			users = util.RemoveDuplicates(users, false)
		} else {
			providerLog(logger.LevelWarn, "unable to get users in groups %+v: %v", folder.Groups, errGrp)
		}
		for _, user := range users {
			provider.setUpdatedAt(user)
			u, err := provider.userExists(user, "")
			if err == nil {
				executeAction(operationUpdate, executor, ipAddress, actionObjectUser, u.Username, u.Role, &u)
			}
			// unlike UpdateFolder, the cached WebDAV user is always dropped
			// here: the folder it referenced no longer exists
			RemoveCachedWebDAVUser(user)
		}
		delayedQuotaUpdater.resetFolderQuota(folderName)
	}
	return err
}

// GetFolderByName returns the folder with the specified name if any
func GetFolderByName(name string) (vfs.BaseVirtualFolder, error) {
	name = config.convertName(name)
	return provider.getFolderByName(name)
}

// GetFolders returns an array of folders respecting limit and offset
func GetFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) {
	return provider.getFolders(limit, offset, order, minimal)
}

// dumpUsers adds the users to the backup data when the users scope is
// requested. An empty scopes slice means: dump everything.
func dumpUsers(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeUsers) {
		users, err := provider.dumpUsers()
		if err != nil {
			return err
		}
		data.Users = users
	}
	return nil
}

// dumpFolders adds the virtual folders to the backup data when the folders
// scope is requested. An empty scopes slice means: dump everything.
func dumpFolders(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeFolders) {
		folders, err := provider.dumpFolders()
		if err != nil {
			return err
		}
		data.Folders = folders
	}
	return nil
}

// dumpGroups adds the groups to the backup data when the groups scope is
// requested. An empty scopes slice means: dump everything.
func dumpGroups(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeGroups) {
		groups, err := provider.dumpGroups()
		if err != nil {
			return err
		}
		data.Groups = groups
	}
	return nil
}

// dumpAdmins adds the admins to the backup data when the admins scope is
// requested. An empty scopes slice means: dump everything.
func dumpAdmins(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeAdmins) {
		admins, err := provider.dumpAdmins()
		if err != nil {
			return err
		}
		data.Admins = admins
	}
	return nil
}

// dumpAPIKeys adds the API keys to the backup data when the API keys scope is
// requested. An empty scopes slice means: dump everything.
func dumpAPIKeys(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeAPIKeys) {
		apiKeys, err := provider.dumpAPIKeys()
		if err != nil {
			return err
		}
		data.APIKeys = apiKeys
	}
	return nil
}

// dumpShares adds the shares to the backup data when the shares scope is
// requested. An empty scopes slice means: dump everything.
func dumpShares(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeShares) {
		shares, err := provider.dumpShares()
		if err != nil {
			return err
		}
		data.Shares = shares
	}
	return nil
}

// dumpActions adds the event actions to the backup data when the actions
// scope is requested. An empty scopes slice means: dump everything.
func dumpActions(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeActions) {
		actions, err := provider.dumpEventActions()
		if err != nil {
			return err
		}
		data.EventActions = actions
	}
	return nil
}

// dumpRules adds the event rules to the backup data when the rules scope is
// requested. An empty scopes slice means: dump everything.
func dumpRules(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeRules) {
		rules, err := provider.dumpEventRules()
		if err != nil {
			return err
		}
		data.EventRules = rules
	}
	return nil
}

// dumpRoles adds the roles to the backup data when the roles scope is
// requested. An empty scopes slice means: dump everything.
func dumpRoles(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeRoles) {
		roles, err := provider.dumpRoles()
		if err != nil {
			return err
		}
		data.Roles = roles
	}
	return nil
}

// dumpIPLists adds the IP list entries to the backup data when the IP lists
// scope is requested. An empty scopes slice means: dump everything.
func dumpIPLists(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeIPLists) {
		ipLists, err := provider.dumpIPListEntries()
		if err != nil {
			return err
		}
		data.IPLists = ipLists
	}
	return nil
}

// dumpConfigs adds the configs to the backup data when the configs scope is
// requested. An empty scopes slice means: dump everything.
func dumpConfigs(data *BackupData, scopes []string) error {
	if len(scopes) == 0 || slices.Contains(scopes, DumpScopeConfigs) {
		configs, err := provider.getConfigs()
		if err != nil {
			return err
		}
		data.Configs = &configs
	}
	return nil
}

// DumpData returns a dump containing the requested scopes.
+// Empty scopes means all +func DumpData(scopes []string) (BackupData, error) { + data := BackupData{ + Version: DumpVersion, + } + if err := dumpGroups(&data, scopes); err != nil { + return data, err + } + if err := dumpUsers(&data, scopes); err != nil { + return data, err + } + if err := dumpFolders(&data, scopes); err != nil { + return data, err + } + if err := dumpAdmins(&data, scopes); err != nil { + return data, err + } + if err := dumpAPIKeys(&data, scopes); err != nil { + return data, err + } + if err := dumpShares(&data, scopes); err != nil { + return data, err + } + if err := dumpActions(&data, scopes); err != nil { + return data, err + } + if err := dumpRules(&data, scopes); err != nil { + return data, err + } + if err := dumpRoles(&data, scopes); err != nil { + return data, err + } + if err := dumpIPLists(&data, scopes); err != nil { + return data, err + } + if err := dumpConfigs(&data, scopes); err != nil { + return data, err + } + + return data, nil +} + +// ParseDumpData tries to parse data as BackupData +func ParseDumpData(data []byte) (BackupData, error) { + var dump BackupData + err := json.Unmarshal(data, &dump) + if err != nil { + return dump, err + } + if dump.Version < 17 { + providerLog(logger.LevelInfo, "updating placeholders for actions restored from dump version %d", dump.Version) + eventActions, err := updateEventActionPlaceholders(dump.EventActions) + if err != nil { + return dump, fmt.Errorf("unable to update event action placeholders for dump version %d: %w", dump.Version, err) + } + dump.EventActions = eventActions + } + return dump, err +} + +// GetProviderConfig returns the current provider configuration +func GetProviderConfig() Config { + return config +} + +// GetProviderStatus returns an error if the provider is not available +func GetProviderStatus() ProviderStatus { + err := provider.checkAvailability() + status := ProviderStatus{ + Driver: config.Driver, + } + if err == nil { + status.IsActive = true + } else { + 
status.IsActive = false + status.Error = err.Error() + } + return status +} + +// Close releases all provider resources. +// This method is used in test cases. +// Closing an uninitialized provider is not supported +func Close() error { + stopScheduler() + return provider.close() +} + +func createProvider(basePath string) error { + sqlPlaceholders = getSQLPlaceholders() + if err := validateSQLTablesPrefix(); err != nil { + return err + } + logSender = fmt.Sprintf("dataprovider_%v", config.Driver) + + switch config.Driver { + case SQLiteDataProviderName: + return initializeSQLiteProvider(basePath) + case PGSQLDataProviderName, CockroachDataProviderName: + return initializePGSQLProvider() + case MySQLDataProviderName: + return initializeMySQLProvider() + case BoltDataProviderName: + return initializeBoltProvider(basePath) + case MemoryDataProviderName: + if err := initializeMemoryProvider(basePath); err != nil { + logger.Warn(logSender, "", "provider initialized but data loading failed: %v", err) + logger.WarnToConsole("provider initialized but data loading failed: %v", err) + } + return nil + default: + return fmt.Errorf("unsupported data provider: %v", config.Driver) + } +} + +func copyBaseUserFilters(in sdk.BaseUserFilters) sdk.BaseUserFilters { + filters := sdk.BaseUserFilters{} + filters.MaxUploadFileSize = in.MaxUploadFileSize + filters.TLSUsername = in.TLSUsername + filters.UserType = in.UserType + filters.AllowedIP = make([]string, len(in.AllowedIP)) + copy(filters.AllowedIP, in.AllowedIP) + filters.DeniedIP = make([]string, len(in.DeniedIP)) + copy(filters.DeniedIP, in.DeniedIP) + filters.DeniedLoginMethods = make([]string, len(in.DeniedLoginMethods)) + copy(filters.DeniedLoginMethods, in.DeniedLoginMethods) + filters.FilePatterns = make([]sdk.PatternsFilter, len(in.FilePatterns)) + copy(filters.FilePatterns, in.FilePatterns) + filters.DeniedProtocols = make([]string, len(in.DeniedProtocols)) + copy(filters.DeniedProtocols, in.DeniedProtocols) + 
filters.TwoFactorAuthProtocols = make([]string, len(in.TwoFactorAuthProtocols)) + copy(filters.TwoFactorAuthProtocols, in.TwoFactorAuthProtocols) + filters.Hooks.ExternalAuthDisabled = in.Hooks.ExternalAuthDisabled + filters.Hooks.PreLoginDisabled = in.Hooks.PreLoginDisabled + filters.Hooks.CheckPasswordDisabled = in.Hooks.CheckPasswordDisabled + filters.DisableFsChecks = in.DisableFsChecks + filters.StartDirectory = in.StartDirectory + filters.FTPSecurity = in.FTPSecurity + filters.IsAnonymous = in.IsAnonymous + filters.AllowAPIKeyAuth = in.AllowAPIKeyAuth + filters.ExternalAuthCacheTime = in.ExternalAuthCacheTime + filters.DefaultSharesExpiration = in.DefaultSharesExpiration + filters.MaxSharesExpiration = in.MaxSharesExpiration + filters.PasswordExpiration = in.PasswordExpiration + filters.PasswordStrength = in.PasswordStrength + filters.WebClient = make([]string, len(in.WebClient)) + copy(filters.WebClient, in.WebClient) + filters.TLSCerts = make([]string, len(in.TLSCerts)) + copy(filters.TLSCerts, in.TLSCerts) + filters.BandwidthLimits = make([]sdk.BandwidthLimit, 0, len(in.BandwidthLimits)) + for _, limit := range in.BandwidthLimits { + bwLimit := sdk.BandwidthLimit{ + UploadBandwidth: limit.UploadBandwidth, + DownloadBandwidth: limit.DownloadBandwidth, + Sources: make([]string, 0, len(limit.Sources)), + } + bwLimit.Sources = make([]string, len(limit.Sources)) + copy(bwLimit.Sources, limit.Sources) + filters.BandwidthLimits = append(filters.BandwidthLimits, bwLimit) + } + filters.AccessTime = make([]sdk.TimePeriod, 0, len(in.AccessTime)) + for _, period := range in.AccessTime { + filters.AccessTime = append(filters.AccessTime, sdk.TimePeriod{ + DayOfWeek: period.DayOfWeek, + From: period.From, + To: period.To, + }) + } + return filters +} + +func buildUserHomeDir(user *User) { + if user.HomeDir == "" { + if config.UsersBaseDir != "" { + user.HomeDir = filepath.Join(config.UsersBaseDir, user.Username) + return + } + switch user.FsConfig.Provider { + case 
sdk.SFTPFilesystemProvider, sdk.S3FilesystemProvider, sdk.AzureBlobFilesystemProvider, sdk.GCSFilesystemProvider, sdk.HTTPFilesystemProvider: + if tempPath != "" { + user.HomeDir = filepath.Join(tempPath, user.Username) + } else { + user.HomeDir = filepath.Join(os.TempDir(), user.Username) + } + } + } else { + user.HomeDir = filepath.Clean(user.HomeDir) + } +} + +func validateFolderQuotaLimits(folder vfs.VirtualFolder) error { + if folder.QuotaSize < -1 { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid quota_size: %v folder path %q", folder.QuotaSize, folder.MappedPath)), + util.I18nErrorFolderQuotaSizeInvalid, + ) + } + if folder.QuotaFiles < -1 { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid quota_file: %v folder path %q", folder.QuotaFiles, folder.MappedPath)), + util.I18nErrorFolderQuotaFileInvalid, + ) + } + if (folder.QuotaSize == -1 && folder.QuotaFiles != -1) || (folder.QuotaFiles == -1 && folder.QuotaSize != -1) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("virtual folder quota_size and quota_files must be both -1 or >= 0, quota_size: %v quota_files: %v", + folder.QuotaFiles, folder.QuotaSize)), + util.I18nErrorFolderQuotaInvalid, + ) + } + return nil +} + +func validateUserGroups(user *User) error { + if len(user.Groups) == 0 { + return nil + } + hasPrimary := false + groupNames := make(map[string]bool) + + for _, g := range user.Groups { + if g.Type < sdk.GroupTypePrimary || g.Type > sdk.GroupTypeMembership { + return util.NewValidationError(fmt.Sprintf("invalid group type: %v", g.Type)) + } + if g.Type == sdk.GroupTypePrimary { + if hasPrimary { + return util.NewI18nError( + util.NewValidationError("only one primary group is allowed"), + util.I18nErrorPrimaryGroup, + ) + } + hasPrimary = true + } + if groupNames[g.Name] { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("the group %q is duplicated", g.Name)), + util.I18nErrorDuplicateGroup, + ) + } + 
groupNames[g.Name] = true + } + return nil +} + +func validateAssociatedVirtualFolders(vfolders []vfs.VirtualFolder) ([]vfs.VirtualFolder, error) { + if len(vfolders) == 0 { + return []vfs.VirtualFolder{}, nil + } + var virtualFolders []vfs.VirtualFolder + folderNames := make(map[string]bool) + + for _, v := range vfolders { + v.Name = config.convertName(v.Name) + if v.VirtualPath == "" { + return nil, util.NewI18nError( + util.NewValidationError("mount/virtual path is mandatory"), + util.I18nErrorFolderMountPathRequired, + ) + } + cleanedVPath := util.CleanPath(v.VirtualPath) + if err := validateFolderQuotaLimits(v); err != nil { + return nil, err + } + if v.Name == "" { + return nil, util.NewI18nError(util.NewValidationError("folder name is mandatory"), util.I18nErrorFolderNameRequired) + } + if folderNames[v.Name] { + return nil, util.NewI18nError( + util.NewValidationError(fmt.Sprintf("the folder %q is duplicated", v.Name)), + util.I18nErrorDuplicatedFolders, + ) + } + for _, vFolder := range virtualFolders { + if util.IsDirOverlapped(vFolder.VirtualPath, cleanedVPath, false, "/") { + return nil, util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid virtual folder %q, it overlaps with virtual folder %q", + v.VirtualPath, vFolder.VirtualPath)), + util.I18nErrorOverlappedFolders, + ) + } + } + virtualFolders = append(virtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: v.Name, + }, + VirtualPath: cleanedVPath, + QuotaSize: v.QuotaSize, + QuotaFiles: v.QuotaFiles, + }) + folderNames[v.Name] = true + } + return virtualFolders, nil +} + +func validateUserTOTPConfig(c *UserTOTPConfig, username string) error { + if !c.Enabled { + c.ConfigName = "" + c.Secret = kms.NewEmptySecret() + c.Protocols = nil + return nil + } + if c.ConfigName == "" { + return util.NewValidationError("totp: config name is mandatory") + } + if !slices.Contains(mfa.GetAvailableTOTPConfigNames(), c.ConfigName) { + return 
util.NewValidationError(fmt.Sprintf("totp: config name %q not found", c.ConfigName)) + } + if c.Secret.IsEmpty() { + return util.NewValidationError("totp: secret is mandatory") + } + if c.Secret.IsPlain() { + c.Secret.SetAdditionalData(username) + if err := c.Secret.Encrypt(); err != nil { + return util.NewValidationError(fmt.Sprintf("totp: unable to encrypt secret: %v", err)) + } + } + if len(c.Protocols) == 0 { + return util.NewValidationError("totp: specify at least one protocol") + } + for _, protocol := range c.Protocols { + if !slices.Contains(MFAProtocols, protocol) { + return util.NewValidationError(fmt.Sprintf("totp: invalid protocol %q", protocol)) + } + } + return nil +} + +func validateUserRecoveryCodes(user *User) error { + for i := 0; i < len(user.Filters.RecoveryCodes); i++ { + code := &user.Filters.RecoveryCodes[i] + if code.Secret.IsEmpty() { + return util.NewValidationError("mfa: recovery code cannot be empty") + } + if code.Secret.IsPlain() { + code.Secret.SetAdditionalData(user.Username) + if err := code.Secret.Encrypt(); err != nil { + return util.NewValidationError(fmt.Sprintf("mfa: unable to encrypt recovery code: %v", err)) + } + } + } + return nil +} + +func validateUserPermissions(permsToCheck map[string][]string) (map[string][]string, error) { + permissions := make(map[string][]string) + for dir, perms := range permsToCheck { + if len(perms) == 0 && dir == "/" { + return permissions, util.NewValidationError(fmt.Sprintf("no permissions granted for the directory: %q", dir)) + } + if len(perms) > len(ValidPerms) { + return permissions, util.NewValidationError("invalid permissions") + } + for _, p := range perms { + if !slices.Contains(ValidPerms, p) { + return permissions, util.NewValidationError(fmt.Sprintf("invalid permission: %q", p)) + } + } + cleanedDir := filepath.ToSlash(path.Clean(dir)) + if cleanedDir != "/" { + cleanedDir = strings.TrimSuffix(cleanedDir, "/") + } + if !path.IsAbs(cleanedDir) { + return permissions, 
util.NewValidationError(fmt.Sprintf("cannot set permissions for non absolute path: %q", dir)) + } + if dir != cleanedDir && cleanedDir == "/" { + return permissions, util.NewValidationError(fmt.Sprintf("cannot set permissions for invalid subdirectory: %q is an alias for \"/\"", dir)) + } + if slices.Contains(perms, PermAny) { + permissions[cleanedDir] = []string{PermAny} + } else { + permissions[cleanedDir] = util.RemoveDuplicates(perms, false) + } + } + + return permissions, nil +} + +func validatePermissions(user *User) error { + if len(user.Permissions) == 0 { + return util.NewI18nError(util.NewValidationError("please grant some permissions to this user"), util.I18nErrorNoPermission) + } + if _, ok := user.Permissions["/"]; !ok { + return util.NewI18nError(util.NewValidationError("permissions for the root dir \"/\" must be set"), util.I18nErrorNoRootPermission) + } + permissions, err := validateUserPermissions(user.Permissions) + if err != nil { + return util.NewI18nError(err, util.I18nErrorGenericPermission) + } + user.Permissions = permissions + return nil +} + +func validatePublicKeys(user *User) error { + if len(user.PublicKeys) == 0 { + user.PublicKeys = []string{} + } + var validatedKeys []string + for idx, key := range user.PublicKeys { + if key == "" { + continue + } + out, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key)) + if err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("error parsing public key at position %d: %v", idx, err)), + util.I18nErrorPubKeyInvalid, + ) + } + if out.Type() == ssh.InsecureKeyAlgoDSA { //nolint:staticcheck + providerLog(logger.LevelError, "dsa public key not accepted, position: %d", idx) + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("DSA key format is insecure and it is not allowed for key at position %d", idx)), + util.I18nErrorKeyInsecure, + ) + } + if k, ok := out.(ssh.CryptoPublicKey); ok { + cryptoKey := k.CryptoPublicKey() + if rsaKey, ok := 
cryptoKey.(*rsa.PublicKey); ok { + if size := rsaKey.N.BitLen(); size < 2048 { + providerLog(logger.LevelError, "rsa key with size %d at position %d not accepted, minimum 2048", size, idx) + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid size %d for rsa key at position %d, minimum 2048", + size, idx)), + util.I18nErrorKeySizeInvalid, + ) + } + } + } + + validatedKeys = append(validatedKeys, key) + } + user.PublicKeys = util.RemoveDuplicates(validatedKeys, false) + return nil +} + +func validateFiltersPatternExtensions(baseFilters *sdk.BaseUserFilters) error { + if len(baseFilters.FilePatterns) == 0 { + baseFilters.FilePatterns = []sdk.PatternsFilter{} + return nil + } + filteredPaths := []string{} + var filters []sdk.PatternsFilter + for _, f := range baseFilters.FilePatterns { + cleanedPath := filepath.ToSlash(path.Clean(f.Path)) + if !path.IsAbs(cleanedPath) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid path %q for file patterns filter", f.Path)), + util.I18nErrorFilePatternPathInvalid, + ) + } + if slices.Contains(filteredPaths, cleanedPath) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("duplicate file patterns filter for path %q", f.Path)), + util.I18nErrorFilePatternDuplicated, + ) + } + if len(f.AllowedPatterns) == 0 && len(f.DeniedPatterns) == 0 { + return util.NewValidationError(fmt.Sprintf("empty file patterns filter for path %q", f.Path)) + } + if f.DenyPolicy < sdk.DenyPolicyDefault || f.DenyPolicy > sdk.DenyPolicyHide { + return util.NewValidationError(fmt.Sprintf("invalid deny policy %v for path %q", f.DenyPolicy, f.Path)) + } + f.Path = cleanedPath + allowed := make([]string, 0, len(f.AllowedPatterns)) + denied := make([]string, 0, len(f.DeniedPatterns)) + for _, pattern := range f.AllowedPatterns { + _, err := path.Match(pattern, "abc") + if err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid file pattern filter %q", pattern)), + 
util.I18nErrorFilePatternInvalid, + ) + } + allowed = append(allowed, strings.ToLower(pattern)) + } + for _, pattern := range f.DeniedPatterns { + _, err := path.Match(pattern, "abc") + if err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid file pattern filter %q", pattern)), + util.I18nErrorFilePatternInvalid, + ) + } + denied = append(denied, strings.ToLower(pattern)) + } + f.AllowedPatterns = util.RemoveDuplicates(allowed, false) + f.DeniedPatterns = util.RemoveDuplicates(denied, false) + filters = append(filters, f) + filteredPaths = append(filteredPaths, cleanedPath) + } + baseFilters.FilePatterns = filters + return nil +} + +func checkEmptyFiltersStruct(filters *sdk.BaseUserFilters) { + if len(filters.AllowedIP) == 0 { + filters.AllowedIP = []string{} + } + if len(filters.DeniedIP) == 0 { + filters.DeniedIP = []string{} + } + if len(filters.DeniedLoginMethods) == 0 { + filters.DeniedLoginMethods = []string{} + } + if len(filters.DeniedProtocols) == 0 { + filters.DeniedProtocols = []string{} + } +} + +func validateIPFilters(filters *sdk.BaseUserFilters) error { + filters.DeniedIP = util.RemoveDuplicates(filters.DeniedIP, false) + for _, IPMask := range filters.DeniedIP { + _, _, err := net.ParseCIDR(IPMask) + if err != nil { + return util.NewValidationError(fmt.Sprintf("could not parse denied IP/Mask %q: %v", IPMask, err)) + } + } + filters.AllowedIP = util.RemoveDuplicates(filters.AllowedIP, false) + for _, IPMask := range filters.AllowedIP { + _, _, err := net.ParseCIDR(IPMask) + if err != nil { + return util.NewValidationError(fmt.Sprintf("could not parse allowed IP/Mask %q: %v", IPMask, err)) + } + } + return nil +} + +func validateBandwidthLimit(bl sdk.BandwidthLimit) error { + if len(bl.Sources) == 0 { + return util.NewValidationError("no bandwidth limit source specified") + } + for _, source := range bl.Sources { + _, _, err := net.ParseCIDR(source) + if err != nil { + return util.NewValidationError(fmt.Sprintf("could 
not parse bandwidth limit source %q: %v", source, err)) + } + } + return nil +} + +func validateBandwidthLimitsFilter(filters *sdk.BaseUserFilters) error { + for idx, bandwidthLimit := range filters.BandwidthLimits { + if err := validateBandwidthLimit(bandwidthLimit); err != nil { + return err + } + if bandwidthLimit.DownloadBandwidth < 0 { + filters.BandwidthLimits[idx].DownloadBandwidth = 0 + } + if bandwidthLimit.UploadBandwidth < 0 { + filters.BandwidthLimits[idx].UploadBandwidth = 0 + } + } + return nil +} + +func updateFiltersValues(filters *sdk.BaseUserFilters) { + if filters.StartDirectory != "" { + filters.StartDirectory = util.CleanPath(filters.StartDirectory) + if filters.StartDirectory == "/" { + filters.StartDirectory = "" + } + } +} + +func validateFilterProtocols(filters *sdk.BaseUserFilters) error { + if len(filters.DeniedProtocols) >= len(ValidProtocols) { + return util.NewValidationError("invalid denied_protocols") + } + for _, p := range filters.DeniedProtocols { + if !slices.Contains(ValidProtocols, p) { + return util.NewValidationError(fmt.Sprintf("invalid denied protocol %q", p)) + } + } + + for _, p := range filters.TwoFactorAuthProtocols { + if !slices.Contains(MFAProtocols, p) { + return util.NewValidationError(fmt.Sprintf("invalid two factor protocol %q", p)) + } + } + return nil +} + +func validateTLSCerts(certs []string) ([]string, error) { + var validateCerts []string + for idx, cert := range certs { + if cert == "" { + continue + } + derBlock, _ := pem.Decode([]byte(cert)) + if derBlock == nil { + return nil, util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid TLS certificate %d", idx)), + util.I18nErrorInvalidTLSCert, + ) + } + crt, err := x509.ParseCertificate(derBlock.Bytes) + if err != nil { + return nil, util.NewI18nError( + util.NewValidationError(fmt.Sprintf("error parsing TLS certificate %d", idx)), + util.I18nErrorInvalidTLSCert, + ) + } + if crt.PublicKeyAlgorithm == x509.RSA { + if rsaCert, ok := 
crt.PublicKey.(*rsa.PublicKey); ok { + if size := rsaCert.N.BitLen(); size < 2048 { + providerLog(logger.LevelError, "rsa cert with size %d not accepted, minimum 2048", size) + return nil, util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid size %d for rsa cert at position %d, minimum 2048", + size, idx)), + util.I18nErrorKeySizeInvalid, + ) + } + } + } + validateCerts = append(validateCerts, cert) + } + return validateCerts, nil +} + +func validateBaseFilters(filters *sdk.BaseUserFilters) error { + checkEmptyFiltersStruct(filters) + if err := validateIPFilters(filters); err != nil { + return util.NewI18nError(err, util.I18nErrorIPFiltersInvalid) + } + if err := validateBandwidthLimitsFilter(filters); err != nil { + return util.NewI18nError(err, util.I18nErrorSourceBWLimitInvalid) + } + if len(filters.DeniedLoginMethods) >= len(ValidLoginMethods) { + return util.NewValidationError("invalid denied_login_methods") + } + for _, loginMethod := range filters.DeniedLoginMethods { + if !slices.Contains(ValidLoginMethods, loginMethod) { + return util.NewValidationError(fmt.Sprintf("invalid login method: %q", loginMethod)) + } + } + if err := validateFilterProtocols(filters); err != nil { + return err + } + if filters.TLSUsername != "" { + if !slices.Contains(validTLSUsernames, string(filters.TLSUsername)) { + return util.NewValidationError(fmt.Sprintf("invalid TLS username: %q", filters.TLSUsername)) + } + } + certs, err := validateTLSCerts(filters.TLSCerts) + if err != nil { + return err + } + filters.TLSCerts = certs + for _, opts := range filters.WebClient { + if !slices.Contains(sdk.WebClientOptions, opts) { + return util.NewValidationError(fmt.Sprintf("invalid web client options %q", opts)) + } + } + if filters.MaxSharesExpiration > 0 && filters.MaxSharesExpiration < filters.DefaultSharesExpiration { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("default shares expiration: %d must be less than or equal to max shares expiration: %d", + 
filters.DefaultSharesExpiration, filters.MaxSharesExpiration)), + util.I18nErrorShareExpirationInvalid, + ) + } + updateFiltersValues(filters) + + if err := validateAccessTimeFilters(filters); err != nil { + return err + } + + return validateFiltersPatternExtensions(filters) +} + +func isTimeOfDayValid(value string) bool { + if len(value) != 5 { + return false + } + parts := strings.Split(value, ":") + if len(parts) != 2 { + return false + } + hour, err := strconv.Atoi(parts[0]) + if err != nil { + return false + } + if hour < 0 || hour > 23 { + return false + } + minute, err := strconv.Atoi(parts[1]) + if err != nil { + return false + } + if minute < 0 || minute > 59 { + return false + } + return true +} + +func validateAccessTimeFilters(filters *sdk.BaseUserFilters) error { + for _, period := range filters.AccessTime { + if period.DayOfWeek < int(time.Sunday) || period.DayOfWeek > int(time.Saturday) { + return util.NewValidationError(fmt.Sprintf("invalid day of week: %d", period.DayOfWeek)) + } + if !isTimeOfDayValid(period.From) || !isTimeOfDayValid(period.To) { + return util.NewI18nError( + util.NewValidationError("invalid time of day. Supported format: HH:MM"), + util.I18nErrorTimeOfDayInvalid, + ) + } + if period.To <= period.From { + return util.NewI18nError( + util.NewValidationError("invalid time of day. 
The end time cannot be earlier than the start time"), + util.I18nErrorTimeOfDayConflict, + ) + } + } + + return nil +} + +func validateCombinedUserFilters(user *User) error { + if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) { + return util.NewI18nError( + util.NewValidationError("two-factor authentication cannot be disabled for a user with an active configuration"), + util.I18nErrorDisableActive2FA, + ) + } + if user.Filters.RequirePasswordChange && slices.Contains(user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) { + return util.NewI18nError( + util.NewValidationError("you cannot require password change and at the same time disallow it"), + util.I18nErrorPwdChangeConflict, + ) + } + if len(user.Filters.TwoFactorAuthProtocols) > 0 && slices.Contains(user.Filters.WebClient, sdk.WebClientMFADisabled) { + return util.NewI18nError( + util.NewValidationError("you cannot require two-factor authentication and at the same time disallow it"), + util.I18nError2FAConflict, + ) + } + return nil +} + +func validateEmails(user *User) error { + if user.Email != "" && !util.IsEmailValid(user.Email) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("email %q is not valid", user.Email)), + util.I18nErrorInvalidEmail, + ) + } + for _, email := range user.Filters.AdditionalEmails { + if !util.IsEmailValid(email) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("email %q is not valid", email)), + util.I18nErrorInvalidEmail, + ) + } + } + return nil +} + +func validateBaseParams(user *User) error { + if user.Username == "" { + return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired) + } + if !util.IsNameValid(user.Username) { + return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) + } + if err := checkReservedUsernames(user.Username); err != nil { + return util.NewI18nError(err, util.I18nErrorReservedUsername) + } + if 
err := validateEmails(user); err != nil { + return err + } + if config.NamingRules&1 == 0 && !usernameRegex.MatchString(user.Username) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("username %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", user.Username)), + util.I18nErrorInvalidUser, + ) + } + if user.hasRedactedSecret() { + return util.NewValidationError("cannot save a user with a redacted secret") + } + if user.HomeDir == "" { + return util.NewI18nError(util.NewValidationError("home_dir is mandatory"), util.I18nErrorHomeRequired) + } + // we can have users with no passwords and public keys, they can authenticate via SSH user certs or OIDC + /*if user.Password == "" && len(user.PublicKeys) == 0 { + return util.NewValidationError("please set a password or at least a public_key") + }*/ + if !filepath.IsAbs(user.HomeDir) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("home_dir must be an absolute path, actual value: %v", user.HomeDir)), + util.I18nErrorHomeInvalid, + ) + } + if user.DownloadBandwidth < 0 { + user.DownloadBandwidth = 0 + } + if user.UploadBandwidth < 0 { + user.UploadBandwidth = 0 + } + if user.TotalDataTransfer > 0 { + // if a total data transfer is defined we reset the separate upload and download limits + user.UploadDataTransfer = 0 + user.DownloadDataTransfer = 0 + } + if user.Filters.IsAnonymous { + user.setAnonymousSettings() + } + err := user.FsConfig.Validate(user.GetEncryptionAdditionalData()) + if err != nil { + return err + } + return nil +} + +func hashPlainPassword(plainPwd string) (string, error) { + if config.PasswordHashing.Algo == HashingAlgoBcrypt { + pwd, err := bcrypt.GenerateFromPassword([]byte(plainPwd), config.PasswordHashing.BcryptOptions.Cost) + if err != nil { + return "", fmt.Errorf("bcrypt hashing error: %w", err) + } + return util.BytesToString(pwd), nil + } + pwd, err := argon2id.CreateHash(plainPwd, argon2Params) + if err != nil { + return "", 
fmt.Errorf("argon2ID hashing error: %w", err) + } + return pwd, nil +} + +func createUserPasswordHash(user *User) error { + if user.Password != "" && !user.IsPasswordHashed() { + for _, g := range user.Groups { + if g.Type == sdk.GroupTypePrimary { + group, err := GroupExists(g.Name) + if err != nil { + return errors.New("unable to load group password policies") + } + if minEntropy := group.UserSettings.Filters.PasswordStrength; minEntropy > 0 { + if err := passwordvalidator.Validate(user.Password, float64(minEntropy)); err != nil { + return util.NewI18nError(util.NewValidationError(err.Error()), util.I18nErrorPasswordComplexity) + } + } + } + } + if minEntropy := user.getMinPasswordEntropy(); minEntropy > 0 { + if err := passwordvalidator.Validate(user.Password, minEntropy); err != nil { + return util.NewI18nError(util.NewValidationError(err.Error()), util.I18nErrorPasswordComplexity) + } + } + hashedPwd, err := hashPlainPassword(user.Password) + if err != nil { + return err + } + user.Password = hashedPwd + user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now()) + } + return nil +} + +// ValidateFolder returns an error if the folder is not valid +// FIXME: this should be defined as Folder struct method +func ValidateFolder(folder *vfs.BaseVirtualFolder) error { + folder.FsConfig.SetEmptySecretsIfNil() + if folder.Name == "" { + return util.NewI18nError(util.NewValidationError("folder name is mandatory"), util.I18nErrorNameRequired) + } + if !util.IsNameValid(folder.Name) { + return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) + } + if config.NamingRules&1 == 0 && !usernameRegex.MatchString(folder.Name) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("folder name %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", folder.Name)), + util.I18nErrorInvalidName, + ) + } + if folder.FsConfig.Provider == sdk.LocalFilesystemProvider || folder.FsConfig.Provider == sdk.CryptedFilesystemProvider || + 
folder.MappedPath != "" { + cleanedMPath := filepath.Clean(folder.MappedPath) + if !filepath.IsAbs(cleanedMPath) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid folder mapped path %q", folder.MappedPath)), + util.I18nErrorInvalidHomeDir, + ) + } + folder.MappedPath = cleanedMPath + } + if folder.HasRedactedSecret() { + return errors.New("cannot save a folder with a redacted secret") + } + return folder.FsConfig.Validate(folder.GetEncryptionAdditionalData()) +} + +// ValidateUser returns an error if the user is not valid +// FIXME: this should be defined as User struct method +func ValidateUser(user *User) error { + user.OIDCCustomFields = nil + user.HasPassword = false + user.SetEmptySecretsIfNil() + user.applyNamingRules() + buildUserHomeDir(user) + if err := validateBaseParams(user); err != nil { + return err + } + if err := validateUserGroups(user); err != nil { + return err + } + if err := validatePermissions(user); err != nil { + return err + } + if err := validateUserTOTPConfig(&user.Filters.TOTPConfig, user.Username); err != nil { + return util.NewI18nError(err, util.I18nError2FAInvalid) + } + if err := validateUserRecoveryCodes(user); err != nil { + return util.NewI18nError(err, util.I18nErrorRecoveryCodesInvalid) + } + vfolders, err := validateAssociatedVirtualFolders(user.VirtualFolders) + if err != nil { + return err + } + user.VirtualFolders = vfolders + if user.Status < 0 || user.Status > 1 { + return util.NewValidationError(fmt.Sprintf("invalid user status: %v", user.Status)) + } + if err := createUserPasswordHash(user); err != nil { + return err + } + if err := validatePublicKeys(user); err != nil { + return err + } + if err := validateBaseFilters(&user.Filters.BaseUserFilters); err != nil { + return err + } + if !user.HasExternalAuth() { + user.Filters.ExternalAuthCacheTime = 0 + } + return validateCombinedUserFilters(user) +} + +func isPasswordOK(user *User, password string) (bool, error) { + if config.PasswordCaching { 
+ found, match := cachedUserPasswords.Check(user.Username, password, user.Password) + if found { + return match, nil + } + } + + match := false + updatePwd := true + var err error + + switch { + case strings.HasPrefix(user.Password, bcryptPwdPrefix): + if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { + return match, ErrInvalidCredentials + } + match = true + updatePwd = config.PasswordHashing.Algo != HashingAlgoBcrypt + case strings.HasPrefix(user.Password, argonPwdPrefix): + match, err = argon2id.ComparePasswordAndHash(password, user.Password) + if err != nil { + providerLog(logger.LevelError, "error comparing password with argon hash: %v", err) + return match, err + } + updatePwd = config.PasswordHashing.Algo != HashingAlgoArgon2ID + case util.IsStringPrefixInSlice(user.Password, unixPwdPrefixes): + match, err = compareUnixPasswordAndHash(user, password) + if err != nil { + return match, err + } + case util.IsStringPrefixInSlice(user.Password, pbkdfPwdPrefixes): + match, err = comparePbkdf2PasswordAndHash(password, user.Password) + if err != nil { + return match, err + } + case util.IsStringPrefixInSlice(user.Password, digestPwdPrefixes): + match = compareDigestPasswordAndHash(user, password) + } + + if err == nil && match { + cachedUserPasswords.Add(user.Username, password, user.Password) + if updatePwd { + convertUserPassword(user.Username, password) + } + } + return match, err +} + +func convertUserPassword(username, plainPwd string) { + hashedPwd, err := hashPlainPassword(plainPwd) + if err == nil { + err = provider.updateUserPassword(username, hashedPwd) + } + if err != nil { + providerLog(logger.LevelWarn, "unable to convert password for user %s: %v", username, err) + } else { + providerLog(logger.LevelDebug, "password converted for user %s", username) + } +} + +func checkUserAndTLSCertificate(user *User, protocol string, tlsCert *x509.Certificate) (User, error) { + err := user.LoadAndApplyGroupSettings() + if 
err != nil { + return *user, err + } + err = user.CheckLoginConditions() + if err != nil { + return *user, err + } + switch protocol { + case protocolFTP, protocolWebDAV: + for _, cert := range user.Filters.TLSCerts { + derBlock, _ := pem.Decode(util.StringToBytes(cert)) + if derBlock != nil && bytes.Equal(derBlock.Bytes, tlsCert.Raw) { + return *user, nil + } + } + if user.Filters.TLSUsername == sdk.TLSUsernameCN { + if user.Username == tlsCert.Subject.CommonName { + return *user, nil + } + return *user, fmt.Errorf("CN %q does not match username %q", tlsCert.Subject.CommonName, user.Username) + } + return *user, errors.New("TLS certificate is not valid") + default: + return *user, fmt.Errorf("certificate authentication is not supported for protocol %v", protocol) + } +} + +func checkUserAndPass(user *User, password, ip, protocol string) (User, error) { + err := user.LoadAndApplyGroupSettings() + if err != nil { + return *user, err + } + err = user.CheckLoginConditions() + if err != nil { + return *user, err + } + if protocol != protocolHTTP && user.MustChangePassword() { + return *user, errors.New("login not allowed, password change required") + } + if user.Filters.IsAnonymous { + user.setAnonymousSettings() + return *user, nil + } + password, err = checkUserPasscode(user, password, protocol) + if err != nil { + return *user, ErrInvalidCredentials + } + if user.Password == "" || strings.TrimSpace(password) == "" { + return *user, errors.New("credentials cannot be null or empty") + } + if !user.Filters.Hooks.CheckPasswordDisabled { + hookResponse, err := executeCheckPasswordHook(user.Username, password, ip, protocol) + if err != nil { + providerLog(logger.LevelDebug, "error executing check password hook for user %q, ip %v, protocol %v: %v", + user.Username, ip, protocol, err) + return *user, errors.New("unable to check credentials") + } + switch hookResponse.Status { + case -1: + // no hook configured + case 1: + providerLog(logger.LevelDebug, "password accepted by 
check password hook for user %q, ip %v, protocol %v", + user.Username, ip, protocol) + return *user, nil + case 2: + providerLog(logger.LevelDebug, "partial success from check password hook for user %q, ip %v, protocol %v", + user.Username, ip, protocol) + password = hookResponse.ToVerify + default: + providerLog(logger.LevelDebug, "password rejected by check password hook for user %q, ip %v, protocol %v, status: %v", + user.Username, ip, protocol, hookResponse.Status) + return *user, ErrInvalidCredentials + } + } + + match, err := isPasswordOK(user, password) + if !match { + err = ErrInvalidCredentials + } + return *user, err +} + +func checkUserPasscode(user *User, password, protocol string) (string, error) { + if user.Filters.TOTPConfig.Enabled { + switch protocol { + case protocolFTP: + if slices.Contains(user.Filters.TOTPConfig.Protocols, protocol) { + // the TOTP passcode has six digits + pwdLen := len(password) + if pwdLen < 7 { + providerLog(logger.LevelDebug, "password len %v is too short to contain a passcode, user %q, protocol %v", + pwdLen, user.Username, protocol) + return "", util.NewValidationError("password too short, cannot contain the passcode") + } + err := user.Filters.TOTPConfig.Secret.TryDecrypt() + if err != nil { + providerLog(logger.LevelError, "unable to decrypt TOTP secret for user %q, protocol %v, err: %v", + user.Username, protocol, err) + return "", err + } + pwd := password[0:(pwdLen - 6)] + passcode := password[(pwdLen - 6):] + match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, passcode, + user.Filters.TOTPConfig.Secret.GetPayload()) + if !match || err != nil { + providerLog(logger.LevelWarn, "invalid passcode for user %q, protocol %v, err: %v", + user.Username, protocol, err) + return "", util.NewValidationError("invalid passcode") + } + return pwd, nil + } + } + } + return password, nil +} + +func checkUserAndPubKey(user *User, pubKey []byte, isSSHCert bool) (User, string, error) { + err := 
user.LoadAndApplyGroupSettings() + if err != nil { + return *user, "", err + } + err = user.CheckLoginConditions() + if err != nil { + return *user, "", err + } + if isSSHCert { + return *user, "", nil + } + if len(user.PublicKeys) == 0 { + return *user, "", ErrInvalidCredentials + } + for idx, key := range user.PublicKeys { + storedKey, comment, _, _, err := ssh.ParseAuthorizedKey(util.StringToBytes(key)) + if err != nil { + providerLog(logger.LevelError, "error parsing stored public key %d for user %s: %v", idx, user.Username, err) + return *user, "", err + } + if bytes.Equal(storedKey.Marshal(), pubKey) { + return *user, fmt.Sprintf("%s:%s", ssh.FingerprintSHA256(storedKey), comment), nil + } + } + return *user, "", ErrInvalidCredentials +} + +func compareDigestPasswordAndHash(user *User, password string) bool { + if strings.HasPrefix(user.Password, md5DigestPwdPrefix) { + h := md5.New() + h.Write([]byte(password)) + return fmt.Sprintf("%s%x", md5DigestPwdPrefix, h.Sum(nil)) == user.Password + } + if strings.HasPrefix(user.Password, sha256DigestPwdPrefix) { + h := sha256.New() + h.Write([]byte(password)) + return fmt.Sprintf("%s%x", sha256DigestPwdPrefix, h.Sum(nil)) == user.Password + } + if strings.HasPrefix(user.Password, sha512DigestPwdPrefix) { + h := sha512.New() + h.Write([]byte(password)) + return fmt.Sprintf("%s%x", sha512DigestPwdPrefix, h.Sum(nil)) == user.Password + } + return false +} + +func compareUnixPasswordAndHash(user *User, password string) (bool, error) { + if strings.HasPrefix(user.Password, yescryptPwdPrefix) { + return compareYescryptPassword(user.Password, password) + } + var crypter crypt.Crypter + if strings.HasPrefix(user.Password, sha512cryptPwdPrefix) { + crypter = sha512_crypt.New() + } else if strings.HasPrefix(user.Password, sha256cryptPwdPrefix) { + crypter = sha256_crypt.New() + } else if strings.HasPrefix(user.Password, md5cryptPwdPrefix) { + crypter = md5_crypt.New() + } else if strings.HasPrefix(user.Password, 
md5cryptApr1PwdPrefix) { + crypter = apr1_crypt.New() + } else { + return false, errors.New("unix crypt: invalid or unsupported hash format") + } + if err := crypter.Verify(user.Password, []byte(password)); err != nil { + return false, err + } + return true, nil +} + +func comparePbkdf2PasswordAndHash(password, hashedPassword string) (bool, error) { + vals := strings.Split(hashedPassword, "$") + if len(vals) != 5 { + return false, fmt.Errorf("pbkdf2: hash is not in the correct format") + } + iterations, err := strconv.Atoi(vals[2]) + if err != nil { + return false, err + } + expected, err := base64.StdEncoding.DecodeString(vals[4]) + if err != nil { + return false, err + } + var salt []byte + if util.IsStringPrefixInSlice(hashedPassword, pbkdfPwdB64SaltPrefixes) { + salt, err = base64.StdEncoding.DecodeString(vals[3]) + if err != nil { + return false, err + } + } else { + salt = []byte(vals[3]) + } + var hashFunc func() hash.Hash + if strings.HasPrefix(hashedPassword, pbkdf2SHA256Prefix) || strings.HasPrefix(hashedPassword, pbkdf2SHA256B64SaltPrefix) { + hashFunc = sha256.New + } else if strings.HasPrefix(hashedPassword, pbkdf2SHA512Prefix) { + hashFunc = sha512.New + } else if strings.HasPrefix(hashedPassword, pbkdf2SHA1Prefix) { + hashFunc = sha1.New + } else { + return false, fmt.Errorf("pbkdf2: invalid or unsupported hash format %v", vals[1]) + } + df := pbkdf2.Key([]byte(password), salt, iterations, len(expected), hashFunc) + return subtle.ConstantTimeCompare(df, expected) == 1, nil +} + +func getSSLMode() string { + switch config.Driver { + case PGSQLDataProviderName, CockroachDataProviderName: + switch config.SSLMode { + case 0: + return "disable" + case 1: + return "require" + case 2: + return "verify-ca" + case 3: + return "verify-full" + case 4: + return "prefer" + case 5: + return "allow" + } + case MySQLDataProviderName: + if config.requireCustomTLSForMySQL() { + return "custom" + } + switch config.SSLMode { + case 0: + return "false" + case 1: + return 
"true" + case 2: + return "skip-verify" + case 3: + return "preferred" + } + } + return "" +} + +func terminateInteractiveAuthProgram(cmd *exec.Cmd, isFinished bool) { + if isFinished { + return + } + providerLog(logger.LevelInfo, "kill interactive auth program after an unexpected error") + err := cmd.Process.Kill() + if err != nil { + providerLog(logger.LevelDebug, "error killing interactive auth program: %v", err) + } +} + +func sendKeyboardAuthHTTPReq(url string, request *plugin.KeyboardAuthRequest) (*plugin.KeyboardAuthResponse, error) { + reqAsJSON, err := json.Marshal(request) + if err != nil { + providerLog(logger.LevelError, "error serializing keyboard interactive auth request: %v", err) + return nil, err + } + resp, err := httpclient.Post(url, "application/json", bytes.NewBuffer(reqAsJSON)) + if err != nil { + providerLog(logger.LevelError, "error getting keyboard interactive auth hook HTTP response: %v", err) + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("wrong keyboard interactive auth http status code: %v, expected 200", resp.StatusCode) + } + var response plugin.KeyboardAuthResponse + err = render.DecodeJSON(resp.Body, &response) + return &response, err +} + +func doBuiltinKeyboardInteractiveAuth(user *User, client ssh.KeyboardInteractiveChallenge, + ip, protocol string, isPartialAuth bool, +) (int, error) { + if err := user.LoadAndApplyGroupSettings(); err != nil { + return 0, err + } + hasSecondFactor := user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) + if !isPartialAuth || !hasSecondFactor { + answers, err := client("", "", []string{"Password: "}, []bool{false}) + if err != nil { + return 0, err + } + if len(answers) != 1 { + return 0, fmt.Errorf("unexpected number of answers: %d", len(answers)) + } + _, err = checkUserAndPass(user, answers[0], ip, protocol) + if err != nil { + return 0, err + } + } + return 
checkKeyboardInteractiveSecondFactor(user, client, protocol) +} + +func checkKeyboardInteractiveSecondFactor(user *User, client ssh.KeyboardInteractiveChallenge, protocol string) (int, error) { + if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) { + return 1, nil + } + err := user.Filters.TOTPConfig.Secret.TryDecrypt() + if err != nil { + providerLog(logger.LevelError, "unable to decrypt TOTP secret for user %q, protocol %v, err: %v", + user.Username, protocol, err) + return 0, err + } + answers, err := client("", "", []string{"Authentication code: "}, []bool{false}) + if err != nil { + return 0, err + } + if len(answers) != 1 { + return 0, fmt.Errorf("unexpected number of answers: %v", len(answers)) + } + match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, answers[0], + user.Filters.TOTPConfig.Secret.GetPayload()) + if !match || err != nil { + providerLog(logger.LevelWarn, "invalid passcode for user %q, protocol %v, err: %v", + user.Username, protocol, err) + return 0, util.NewValidationError("invalid passcode") + } + return 1, nil +} + +func executeKeyboardInteractivePlugin(user *User, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { + authResult := 0 + requestID := xid.New().String() + authStep := 1 + req := &plugin.KeyboardAuthRequest{ + Username: user.Username, + IP: ip, + Password: user.Password, + RequestID: requestID, + Step: authStep, + } + var response *plugin.KeyboardAuthResponse + var err error + for { + response, err = plugin.Handler.ExecuteKeyboardInteractiveStep(req) + if err != nil { + return authResult, err + } + if response.AuthResult != 0 { + return response.AuthResult, err + } + if err = response.Validate(); err != nil { + providerLog(logger.LevelInfo, "invalid response from keyboard interactive plugin: %v", err) + return authResult, err + } + answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol) + if err != nil 
{ + return authResult, err + } + authStep++ + req = &plugin.KeyboardAuthRequest{ + RequestID: requestID, + Step: authStep, + Username: user.Username, + Password: user.Password, + Answers: answers, + Questions: response.Questions, + } + } +} + +func executeKeyboardInteractiveHTTPHook(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { + authResult := 0 + requestID := xid.New().String() + authStep := 1 + req := &plugin.KeyboardAuthRequest{ + Username: user.Username, + IP: ip, + Password: user.Password, + RequestID: requestID, + Step: authStep, + } + var response *plugin.KeyboardAuthResponse + var err error + for { + response, err = sendKeyboardAuthHTTPReq(authHook, req) + if err != nil { + return authResult, err + } + if response.AuthResult != 0 { + return response.AuthResult, err + } + if err = response.Validate(); err != nil { + providerLog(logger.LevelInfo, "invalid response from keyboard interactive http hook: %v", err) + return authResult, err + } + answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol) + if err != nil { + return authResult, err + } + authStep++ + req = &plugin.KeyboardAuthRequest{ + RequestID: requestID, + Step: authStep, + Username: user.Username, + Password: user.Password, + Answers: answers, + Questions: response.Questions, + } + } +} + +func getKeyboardInteractiveAnswers(client ssh.KeyboardInteractiveChallenge, response *plugin.KeyboardAuthResponse, + user *User, ip, protocol string, +) ([]string, error) { + questions := response.Questions + answers, err := client("", response.Instruction, questions, response.Echos) + if err != nil { + providerLog(logger.LevelInfo, "error getting interactive auth client response: %v", err) + return answers, err + } + if len(answers) != len(questions) { + err = fmt.Errorf("client answers does not match questions, expected: %v actual: %v", questions, answers) + providerLog(logger.LevelInfo, "keyboard interactive auth error: 
%v", err) + return answers, err + } + if len(answers) == 1 && response.CheckPwd > 0 { + if response.CheckPwd == 2 { + if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, protocolSSH) { + providerLog(logger.LevelInfo, "keyboard interactive auth error: unable to check TOTP passcode, TOTP is not enabled for user %q", + user.Username) + return answers, errors.New("TOTP not enabled for SSH protocol") + } + err := user.Filters.TOTPConfig.Secret.TryDecrypt() + if err != nil { + providerLog(logger.LevelError, "unable to decrypt TOTP secret for user %q, protocol %v, err: %v", + user.Username, protocol, err) + return answers, fmt.Errorf("unable to decrypt TOTP secret: %w", err) + } + match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, answers[0], + user.Filters.TOTPConfig.Secret.GetPayload()) + if !match || err != nil { + providerLog(logger.LevelInfo, "keyboard interactive auth error: unable to validate passcode for user %q, match? %v, err: %v", + user.Username, match, err) + return answers, errors.New("unable to validate TOTP passcode") + } + } else { + _, err = checkUserAndPass(user, answers[0], ip, protocol) + providerLog(logger.LevelInfo, "interactive auth hook requested password validation for user %q, validation error: %v", + user.Username, err) + if err != nil { + return answers, err + } + } + answers[0] = "OK" + } + return answers, err +} + +func handleProgramInteractiveQuestions(client ssh.KeyboardInteractiveChallenge, response *plugin.KeyboardAuthResponse, + user *User, stdin io.WriteCloser, ip, protocol string, +) error { + answers, err := getKeyboardInteractiveAnswers(client, response, user, ip, protocol) + if err != nil { + return err + } + for _, answer := range answers { + if runtime.GOOS == "windows" { + answer += "\r" + } + answer += "\n" + _, err = stdin.Write([]byte(answer)) + if err != nil { + providerLog(logger.LevelError, "unable to write client answer to keyboard interactive program: %v", 
err) + return err + } + } + return nil +} + +func executeKeyboardInteractiveProgram(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, ip, protocol string) (int, error) { + authResult := 0 + timeout, env, args := command.GetConfig(authHook, command.HookKeyboardInteractive) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, authHook, args...) + cmd.Env = append(env, + fmt.Sprintf("SFTPGO_AUTHD_USERNAME=%s", user.Username), + fmt.Sprintf("SFTPGO_AUTHD_IP=%s", ip), + fmt.Sprintf("SFTPGO_AUTHD_PASSWORD=%s", user.Password)) + stdout, err := cmd.StdoutPipe() + if err != nil { + return authResult, err + } + stdin, err := cmd.StdinPipe() + if err != nil { + return authResult, err + } + err = cmd.Start() + if err != nil { + return authResult, err + } + var once sync.Once + scanner := bufio.NewScanner(stdout) + for scanner.Scan() { + var response plugin.KeyboardAuthResponse + err = json.Unmarshal(scanner.Bytes(), &response) + if err != nil { + providerLog(logger.LevelInfo, "interactive auth error parsing response: %v", err) + once.Do(func() { terminateInteractiveAuthProgram(cmd, false) }) + break + } + if response.AuthResult != 0 { + authResult = response.AuthResult + break + } + if err = response.Validate(); err != nil { + providerLog(logger.LevelInfo, "invalid response from keyboard interactive program: %v", err) + once.Do(func() { terminateInteractiveAuthProgram(cmd, false) }) + break + } + go func() { + err := handleProgramInteractiveQuestions(client, &response, user, stdin, ip, protocol) + if err != nil { + once.Do(func() { terminateInteractiveAuthProgram(cmd, false) }) + } + }() + } + stdin.Close() + once.Do(func() { terminateInteractiveAuthProgram(cmd, true) }) + go func() { + _, err := cmd.Process.Wait() + if err != nil { + providerLog(logger.LevelWarn, "error waiting for %q process to exit: %v", authHook, err) + } + }() + + return authResult, err +} + +func 
doKeyboardInteractiveAuth(user *User, authHook string, client ssh.KeyboardInteractiveChallenge, + ip, protocol string, isPartialAuth bool, +) (User, error) { + if err := user.LoadAndApplyGroupSettings(); err != nil { + return *user, err + } + var authResult int + var err error + if !user.Filters.Hooks.ExternalAuthDisabled { + if plugin.Handler.HasAuthScope(plugin.AuthScopeKeyboardInteractive) { + authResult, err = executeKeyboardInteractivePlugin(user, client, ip, protocol) + if authResult == 1 && err == nil { + authResult, err = checkKeyboardInteractiveSecondFactor(user, client, protocol) + } + } else if authHook != "" { + if strings.HasPrefix(authHook, "http") { + authResult, err = executeKeyboardInteractiveHTTPHook(user, authHook, client, ip, protocol) + } else { + authResult, err = executeKeyboardInteractiveProgram(user, authHook, client, ip, protocol) + } + } else { + authResult, err = doBuiltinKeyboardInteractiveAuth(user, client, ip, protocol, isPartialAuth) + } + } else { + authResult, err = doBuiltinKeyboardInteractiveAuth(user, client, ip, protocol, isPartialAuth) + } + if err != nil { + return *user, err + } + if authResult != 1 { + return *user, fmt.Errorf("keyboard interactive auth failed, result: %v", authResult) + } + err = user.CheckLoginConditions() + if err != nil { + return *user, err + } + return *user, nil +} + +func isCheckPasswordHookDefined(protocol string) bool { + if config.CheckPasswordHook == "" { + return false + } + if config.CheckPasswordScope == 0 { + return true + } + switch protocol { + case protocolSSH: + return config.CheckPasswordScope&1 != 0 + case protocolFTP: + return config.CheckPasswordScope&2 != 0 + case protocolWebDAV: + return config.CheckPasswordScope&4 != 0 + default: + return false + } +} + +func getPasswordHookResponse(username, password, ip, protocol string) ([]byte, error) { + if strings.HasPrefix(config.CheckPasswordHook, "http") { + var result []byte + req := checkPasswordRequest{ + Username: username, + 
Password: password, + IP: ip, + Protocol: protocol, + } + reqAsJSON, err := json.Marshal(req) + if err != nil { + return result, err + } + resp, err := httpclient.Post(config.CheckPasswordHook, "application/json", bytes.NewBuffer(reqAsJSON)) + if err != nil { + providerLog(logger.LevelError, "error getting check password hook response: %v", err) + return result, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("wrong http status code from check password hook: %v, expected 200", resp.StatusCode) + } + return io.ReadAll(io.LimitReader(resp.Body, maxHookResponseSize)) + } + timeout, env, args := command.GetConfig(config.CheckPasswordHook, command.HookCheckPassword) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, config.CheckPasswordHook, args...) + cmd.Env = append(env, + fmt.Sprintf("SFTPGO_AUTHD_USERNAME=%s", username), + fmt.Sprintf("SFTPGO_AUTHD_PASSWORD=%s", password), + fmt.Sprintf("SFTPGO_AUTHD_IP=%s", ip), + fmt.Sprintf("SFTPGO_AUTHD_PROTOCOL=%s", protocol), + ) + return getCmdOutput(cmd, "check_password_hook") +} + +func executeCheckPasswordHook(username, password, ip, protocol string) (checkPasswordResponse, error) { + var response checkPasswordResponse + + if !isCheckPasswordHookDefined(protocol) { + response.Status = -1 + return response, nil + } + + startTime := time.Now() + out, err := getPasswordHookResponse(username, password, ip, protocol) + providerLog(logger.LevelDebug, "check password hook executed, error: %v, elapsed: %v", err, time.Since(startTime)) + if err != nil { + return response, err + } + err = json.Unmarshal(out, &response) + return response, err +} + +func getPreLoginHookResponse(loginMethod, ip, protocol string, userAsJSON []byte) ([]byte, error) { + if strings.HasPrefix(config.PreLoginHook, "http") { + var url *url.URL + var result []byte + url, err := url.Parse(config.PreLoginHook) + if err != nil { + 
providerLog(logger.LevelError, "invalid url for pre-login hook %q, error: %v", config.PreLoginHook, err) + return result, err + } + q := url.Query() + q.Add("login_method", loginMethod) + q.Add("ip", ip) + q.Add("protocol", protocol) + url.RawQuery = q.Encode() + + resp, err := httpclient.Post(url.String(), "application/json", bytes.NewBuffer(userAsJSON)) + if err != nil { + providerLog(logger.LevelWarn, "error getting pre-login hook response: %v", err) + return result, err + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNoContent { + return result, nil + } + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("wrong pre-login hook http status code: %v, expected 200", resp.StatusCode) + } + return io.ReadAll(io.LimitReader(resp.Body, maxHookResponseSize)) + } + timeout, env, args := command.GetConfig(config.PreLoginHook, command.HookPreLogin) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, config.PreLoginHook, args...) 
+ cmd.Env = append(env, + fmt.Sprintf("SFTPGO_LOGIND_USER=%s", userAsJSON), + fmt.Sprintf("SFTPGO_LOGIND_METHOD=%s", loginMethod), + fmt.Sprintf("SFTPGO_LOGIND_IP=%s", ip), + fmt.Sprintf("SFTPGO_LOGIND_PROTOCOL=%s", protocol), + ) + return getCmdOutput(cmd, "pre_login_hook") +} + +func executePreLoginHook(username, loginMethod, ip, protocol string, oidcTokenFields *map[string]any) (User, error) { + var user User + + u, mergedUser, userAsJSON, err := getUserAndJSONForHook(username, oidcTokenFields) + if err != nil { + return u, err + } + if mergedUser.Filters.Hooks.PreLoginDisabled { + return u, nil + } + startTime := time.Now() + out, err := getPreLoginHookResponse(loginMethod, ip, protocol, userAsJSON) + if err != nil { + return u, fmt.Errorf("pre-login hook error: %v, username %q, ip %v, protocol %v elapsed %v", + err, username, ip, protocol, time.Since(startTime)) + } + providerLog(logger.LevelDebug, "pre-login hook completed, elapsed: %s", time.Since(startTime)) + if util.IsByteArrayEmpty(out) { + providerLog(logger.LevelDebug, "empty response from pre-login hook, no modification requested for user %q id: %d", + username, u.ID) + if u.ID == 0 { + return u, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) + } + return u, nil + } + err = json.Unmarshal(out, &user) + if err != nil { + return u, fmt.Errorf("invalid pre-login hook response %q, error: %v", out, err) + } + if u.ID > 0 { + user.ID = u.ID + user.UsedQuotaSize = u.UsedQuotaSize + user.UsedQuotaFiles = u.UsedQuotaFiles + user.UsedUploadDataTransfer = u.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer + user.LastQuotaUpdate = u.LastQuotaUpdate + user.LastLogin = u.LastLogin + user.LastPasswordChange = u.LastPasswordChange + user.FirstDownload = u.FirstDownload + user.FirstUpload = u.FirstUpload + // preserve TOTP config and recovery codes + user.Filters.TOTPConfig = u.Filters.TOTPConfig + user.Filters.RecoveryCodes = u.Filters.RecoveryCodes 
+ if err := provider.updateUser(&user); err != nil { + return u, err + } + } else { + if err := provider.addUser(&user); err != nil { + return u, err + } + } + user, err = provider.userExists(user.Username, "") + if err != nil { + return u, err + } + providerLog(logger.LevelDebug, "user %q added/updated from pre-login hook response, id: %d", username, u.ID) + if u.ID > 0 { + webDAVUsersCache.swap(&user, "") + } + return user, nil +} + +// ExecutePostLoginHook executes the post login hook if defined +func ExecutePostLoginHook(user *User, loginMethod, ip, protocol string, err error) { + if config.PostLoginHook == "" { + return + } + if config.PostLoginScope == 1 && err == nil { + return + } + if config.PostLoginScope == 2 && err != nil { + return + } + + go func() { + actionsConcurrencyGuard <- struct{}{} + defer func() { + <-actionsConcurrencyGuard + }() + + status := "0" + if err == nil { + status = "1" + } + + user.PrepareForRendering() + userAsJSON, err := json.Marshal(user) + if err != nil { + providerLog(logger.LevelError, "error serializing user in post login hook: %v", err) + return + } + if strings.HasPrefix(config.PostLoginHook, "http") { + var url *url.URL + url, err := url.Parse(config.PostLoginHook) + if err != nil { + providerLog(logger.LevelDebug, "Invalid post-login hook %q", config.PostLoginHook) + return + } + q := url.Query() + q.Add("login_method", loginMethod) + q.Add("ip", ip) + q.Add("protocol", protocol) + q.Add("status", status) + url.RawQuery = q.Encode() + + startTime := time.Now() + respCode := 0 + resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(userAsJSON)) + if err == nil { + respCode = resp.StatusCode + resp.Body.Close() + } + providerLog(logger.LevelDebug, "post login hook executed for user %q, ip %v, protocol %v, response code: %v, elapsed: %v err: %v", + user.Username, ip, protocol, respCode, time.Since(startTime), err) + return + } + timeout, env, args := 
command.GetConfig(config.PostLoginHook, command.HookPostLogin) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, config.PostLoginHook, args...) + cmd.Env = append(env, + fmt.Sprintf("SFTPGO_LOGIND_USER=%s", userAsJSON), + fmt.Sprintf("SFTPGO_LOGIND_IP=%s", ip), + fmt.Sprintf("SFTPGO_LOGIND_METHOD=%s", loginMethod), + fmt.Sprintf("SFTPGO_LOGIND_STATUS=%s", status), + fmt.Sprintf("SFTPGO_LOGIND_PROTOCOL=%s", protocol)) + startTime := time.Now() + err = cmd.Run() + providerLog(logger.LevelDebug, "post login hook executed for user %q, ip %v, protocol %v, elapsed %v err: %v", + user.Username, ip, protocol, time.Since(startTime), err) + }() +} + +func getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, protocol string, cert *x509.Certificate, + user User, +) ([]byte, error) { + var tlsCert string + if cert != nil { + var err error + tlsCert, err = util.EncodeTLSCertToPem(cert) + if err != nil { + return nil, err + } + } + if strings.HasPrefix(config.ExternalAuthHook, "http") { + var result []byte + authRequest := make(map[string]any) + authRequest["username"] = username + authRequest["ip"] = ip + authRequest["password"] = password + authRequest["public_key"] = pkey + authRequest["protocol"] = protocol + authRequest["keyboard_interactive"] = keyboardInteractive + authRequest["tls_cert"] = tlsCert + if user.ID > 0 { + authRequest["user"] = user + } + authRequestAsJSON, err := json.Marshal(authRequest) + if err != nil { + providerLog(logger.LevelError, "error serializing external auth request: %v", err) + return result, err + } + resp, err := httpclient.Post(config.ExternalAuthHook, "application/json", bytes.NewBuffer(authRequestAsJSON)) + if err != nil { + providerLog(logger.LevelWarn, "error getting external auth hook HTTP response: %v", err) + return result, err + } + defer resp.Body.Close() + providerLog(logger.LevelDebug, "external auth hook executed, response code: %v", 
resp.StatusCode) + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("wrong external auth http status code: %v, expected 200", resp.StatusCode) + } + + return io.ReadAll(io.LimitReader(resp.Body, maxHookResponseSize)) + } + var userAsJSON []byte + var err error + if user.ID > 0 { + userAsJSON, err = json.Marshal(user) + if err != nil { + return nil, fmt.Errorf("unable to serialize user as JSON: %w", err) + } + } + timeout, env, args := command.GetConfig(config.ExternalAuthHook, command.HookExternalAuth) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, config.ExternalAuthHook, args...) + cmd.Env = append(env, + fmt.Sprintf("SFTPGO_AUTHD_USERNAME=%s", username), + fmt.Sprintf("SFTPGO_AUTHD_USER=%s", userAsJSON), + fmt.Sprintf("SFTPGO_AUTHD_IP=%s", ip), + fmt.Sprintf("SFTPGO_AUTHD_PASSWORD=%s", password), + fmt.Sprintf("SFTPGO_AUTHD_PUBLIC_KEY=%s", pkey), + fmt.Sprintf("SFTPGO_AUTHD_PROTOCOL=%s", protocol), + fmt.Sprintf("SFTPGO_AUTHD_TLS_CERT=%s", strings.ReplaceAll(tlsCert, "\n", "\\n")), + fmt.Sprintf("SFTPGO_AUTHD_KEYBOARD_INTERACTIVE=%v", keyboardInteractive)) + + return getCmdOutput(cmd, "external_auth_hook") +} + +func updateUserFromExtAuthResponse(user *User, password, pkey string) { + if password != "" { + user.Password = password + } + if pkey != "" && !util.IsStringPrefixInSlice(pkey, user.PublicKeys) { + user.PublicKeys = append(user.PublicKeys, pkey) + } + user.LastPasswordChange = 0 +} + +func checkPasswordAfterEmptyExtAuthResponse(user *User, plainPwd, protocol string) error { + if plainPwd == "" { + return nil + } + match, err := isPasswordOK(user, plainPwd) + if match && err == nil { + return nil + } + + hashedPwd, err := hashPlainPassword(plainPwd) + if err != nil { + providerLog(logger.LevelError, "unable to hash password for user %q after empty external response: %v", + user.Username, err) + return err + } + err = provider.updateUserPassword(user.Username, 
hashedPwd) + if err != nil { + providerLog(logger.LevelError, "unable to update password for user %q after empty external response: %v", + user.Username, err) + } + user.Password = hashedPwd + cachedUserPasswords.Add(user.Username, plainPwd, user.Password) + if protocol != protocolWebDAV { + webDAVUsersCache.swap(user, plainPwd) + } + providerLog(logger.LevelDebug, "updated password for user %q after empty external auth response", user.Username) + return nil +} + +func doExternalAuth(username, password string, pubKey []byte, keyboardInteractive, ip, protocol string, + tlsCert *x509.Certificate, +) (User, error) { + var user User + + u, mergedUser, err := getUserForHook(username, nil) + if err != nil { + return user, err + } + + if mergedUser.skipExternalAuth() { + return u, nil + } + + pkey, err := util.GetSSHPublicKeyAsString(pubKey) + if err != nil { + return user, err + } + + startTime := time.Now() + out, err := getExternalAuthResponse(username, password, pkey, keyboardInteractive, ip, protocol, tlsCert, u) + if err != nil { + return user, fmt.Errorf("external auth error for user %q, elapsed: %s: %w", username, time.Since(startTime), err) + } + providerLog(logger.LevelDebug, "external auth completed for user %q, elapsed: %s", username, time.Since(startTime)) + if util.IsByteArrayEmpty(out) { + providerLog(logger.LevelDebug, "empty response from external hook, no modification requested for user %q, id: %d", + username, u.ID) + if u.ID == 0 { + return u, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) + } + err = checkPasswordAfterEmptyExtAuthResponse(&u, password, protocol) + return u, err + } + err = json.Unmarshal(out, &user) + if err != nil { + return user, fmt.Errorf("invalid external auth response: %v", err) + } + // an empty username means authentication failure + if user.Username == "" { + return user, ErrInvalidCredentials + } + updateUserFromExtAuthResponse(&user, password, pkey) + // some users want to map multiple 
login usernames with a single SFTPGo account + // for example an SFTP user logins using "user1" or "user2" and the external auth + // returns "user" in both cases, so we use the username returned from + // external auth and not the one used to login + if user.Username != username { + u, err = provider.userExists(user.Username, "") + } + if u.ID > 0 && err == nil { + user.ID = u.ID + user.UsedQuotaSize = u.UsedQuotaSize + user.UsedQuotaFiles = u.UsedQuotaFiles + user.UsedUploadDataTransfer = u.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer + user.LastQuotaUpdate = u.LastQuotaUpdate + user.LastLogin = u.LastLogin + user.LastPasswordChange = u.LastPasswordChange + user.FirstDownload = u.FirstDownload + user.FirstUpload = u.FirstUpload + user.CreatedAt = u.CreatedAt + user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + // preserve TOTP config and recovery codes + user.Filters.TOTPConfig = u.Filters.TOTPConfig + user.Filters.RecoveryCodes = u.Filters.RecoveryCodes + user, err = updateUserAfterExternalAuth(&user) + if err == nil { + if protocol != protocolWebDAV { + webDAVUsersCache.swap(&user, password) + } + cachedUserPasswords.Add(user.Username, password, user.Password) + } + return user, err + } + err = provider.addUser(&user) + if err != nil { + return user, err + } + return provider.userExists(user.Username, "") +} + +func doPluginAuth(username, password string, pubKey []byte, ip, protocol string, + tlsCert *x509.Certificate, authScope int, +) (User, error) { + var user User + + u, mergedUser, userAsJSON, err := getUserAndJSONForHook(username, nil) + if err != nil { + return user, err + } + + if mergedUser.skipExternalAuth() { + return u, nil + } + + pkey, err := util.GetSSHPublicKeyAsString(pubKey) + if err != nil { + return user, err + } + + startTime := time.Now() + + out, err := plugin.Handler.Authenticate(username, password, ip, protocol, pkey, tlsCert, authScope, userAsJSON) + if err != nil { + return user, 
fmt.Errorf("plugin auth error for user %q: %v, elapsed: %v, auth scope: %d", + username, err, time.Since(startTime), authScope) + } + providerLog(logger.LevelDebug, "plugin auth completed for user %q, elapsed: %v, auth scope: %d", + username, time.Since(startTime), authScope) + if util.IsByteArrayEmpty(out) { + providerLog(logger.LevelDebug, "empty response from plugin auth, no modification requested for user %q id: %d, auth scope: %d", + username, u.ID, authScope) + if u.ID == 0 { + return u, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) + } + err = checkPasswordAfterEmptyExtAuthResponse(&u, password, protocol) + return u, err + } + err = json.Unmarshal(out, &user) + if err != nil { + return user, fmt.Errorf("invalid plugin auth response: %v", err) + } + updateUserFromExtAuthResponse(&user, password, pkey) + if u.ID > 0 { + user.ID = u.ID + user.UsedQuotaSize = u.UsedQuotaSize + user.UsedQuotaFiles = u.UsedQuotaFiles + user.UsedUploadDataTransfer = u.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer + user.LastQuotaUpdate = u.LastQuotaUpdate + user.LastLogin = u.LastLogin + user.LastPasswordChange = u.LastPasswordChange + user.FirstDownload = u.FirstDownload + user.FirstUpload = u.FirstUpload + // preserve TOTP config and recovery codes + user.Filters.TOTPConfig = u.Filters.TOTPConfig + user.Filters.RecoveryCodes = u.Filters.RecoveryCodes + user, err = updateUserAfterExternalAuth(&user) + if err == nil { + if protocol != protocolWebDAV { + webDAVUsersCache.swap(&user, password) + } + cachedUserPasswords.Add(user.Username, password, user.Password) + } + return user, err + } + err = provider.addUser(&user) + if err != nil { + return user, err + } + return provider.userExists(user.Username, "") +} + +func updateUserAfterExternalAuth(user *User) (User, error) { + if err := provider.updateUser(user); err != nil { + return *user, err + } + return provider.userExists(user.Username, "") +} + +func 
getUserForHook(username string, oidcTokenFields *map[string]any) (User, User, error) { + u, err := provider.userExists(username, "") + if err != nil { + if !errors.Is(err, util.ErrNotFound) { + return u, u, err + } + u = User{ + BaseUser: sdk.BaseUser{ + ID: 0, + Username: username, + }, + } + } + mergedUser := u.getACopy() + err = mergedUser.LoadAndApplyGroupSettings() + if err != nil { + return u, mergedUser, err + } + + u.OIDCCustomFields = oidcTokenFields + return u, mergedUser, err +} + +func getUserAndJSONForHook(username string, oidcTokenFields *map[string]any) (User, User, []byte, error) { + u, mergedUser, err := getUserForHook(username, oidcTokenFields) + if err != nil { + return u, mergedUser, nil, err + } + userAsJSON, err := json.Marshal(u) + if err != nil { + return u, mergedUser, userAsJSON, err + } + return u, mergedUser, userAsJSON, err +} + +func isLastActivityRecent(lastActivity int64, minDelay time.Duration) bool { + lastActivityTime := util.GetTimeFromMsecSinceEpoch(lastActivity) + diff := -time.Until(lastActivityTime) + if diff < -10*time.Second { + return false + } + return diff < minDelay +} + +func isExternalAuthConfigured(loginMethod string) bool { + if config.ExternalAuthHook != "" { + if config.ExternalAuthScope == 0 { + return true + } + switch loginMethod { + case LoginMethodPassword: + return config.ExternalAuthScope&1 != 0 + case LoginMethodTLSCertificate: + return config.ExternalAuthScope&8 != 0 + case LoginMethodTLSCertificateAndPwd: + return config.ExternalAuthScope&1 != 0 || config.ExternalAuthScope&8 != 0 + } + } + switch loginMethod { + case LoginMethodPassword: + return plugin.Handler.HasAuthScope(plugin.AuthScopePassword) + case LoginMethodTLSCertificate: + return plugin.Handler.HasAuthScope(plugin.AuthScopeTLSCertificate) + case LoginMethodTLSCertificateAndPwd: + return plugin.Handler.HasAuthScope(plugin.AuthScopePassword) || + plugin.Handler.HasAuthScope(plugin.AuthScopeTLSCertificate) + default: + return false + } +} + 
+// replaceTemplateVars rewrites legacy placeholders of the form {{Name}} into
+// Go text/template variables {{.Name}}. A "{{" that is immediately followed by
+// a space, dot or minus, or that has no matching "}}", is copied through
+// unchanged so already-converted or non-placeholder text is left alone.
+func replaceTemplateVars(input string) string {
+	var result strings.Builder
+	i := 0
+	for i < len(input) {
+		if i+2 <= len(input) && input[i:i+2] == "{{" {
+			if i+2 < len(input) {
+				nextChar := input[i+2]
+				if nextChar == ' ' || nextChar == '.' || nextChar == '-' {
+					// Don't replace if followed by space, dot or minus.
+					result.WriteString("{{")
+					i += 2
+					continue
+				}
+			}
+
+			// Find the closing "}}"
+			closing := strings.Index(input[i:], "}}")
+			if closing != -1 {
+				// Replace with {{. only if it's a proper template variable.
+				result.WriteString("{{.")
+				result.WriteString(input[i+2 : i+closing])
+				result.WriteString("}}")
+				// closing is relative to input[i:], so this lands just past "}}".
+				i += closing + 2
+				continue
+			}
+		}
+		result.WriteByte(input[i])
+		i++
+	}
+	return result.String()
+}
+
+// updateEventActionPlaceholders converts legacy placeholders inside the
+// JSON-serialized Options of each action (see replaceTemplateVars) and returns
+// the updated actions. It fails if an action's options cannot be marshaled or
+// the converted JSON cannot be unmarshaled back.
+func updateEventActionPlaceholders(actions []BaseEventAction) ([]BaseEventAction, error) {
+	var result []BaseEventAction
+
+	for _, action := range actions {
+		options, err := json.Marshal(action.Options)
+		if err != nil {
+			return nil, err
+		}
+		convertedOptions := replaceTemplateVars(string(options))
+		var opts BaseEventActionOptions
+		err = json.Unmarshal([]byte(convertedOptions), &opts)
+		if err != nil {
+			return nil, err
+		}
+		action.Options = opts
+		result = append(result, action)
+	}
+
+	return result, nil
+}
+
+// getConfigPath resolves a configured file name: invalid inputs yield "",
+// non-empty relative names are joined to configDir, and absolute (or empty)
+// names are returned as-is.
+func getConfigPath(name, configDir string) string {
+	if !util.IsFileInputValid(name) {
+		return ""
+	}
+	if name != "" && !filepath.IsAbs(name) {
+		return filepath.Join(configDir, name)
+	}
+	return name
+}
+
+// checkReservedUsernames returns a validation error if the username is in the
+// reservedUsers list, nil otherwise.
+func checkReservedUsernames(username string) error {
+	if slices.Contains(reservedUsers, username) {
+		return util.NewValidationError("this username is reserved")
+	}
+	return nil
+}
+
+// errSchemaVersionTooOld builds the error returned when the database schema
+// version is no longer supported, pointing the user at the upgrade docs.
+func errSchemaVersionTooOld(version int) error {
+	return fmt.Errorf("database schema version %d is too old, please see the upgrading docs: https://docs.sftpgo.com/latest/data-provider/#upgrading", version)
+}
+
+// getCmdOutput runs the given command and returns its captured stdout.
+// Stderr is streamed line by line to the logger under the given sender while
+// the command runs, so diagnostics are not lost and the child cannot block on
+// a full stderr pipe. The error from cmd.Wait (if any) is returned alongside
+// whatever stdout was produced.
+func getCmdOutput(cmd *exec.Cmd, sender string) ([]byte, error) {
+	var stdout bytes.Buffer
+	cmd.Stdout = &stdout
+
+	stderr, err := cmd.StderrPipe()
+	if err != nil {
+		return nil, err
+	}
+
+	err = cmd.Start()
+	if err != nil {
+		return nil, err
+	}
+
+	scanner := bufio.NewScanner(stderr)
+
+	// Drain stderr concurrently; the goroutine exits when the pipe is closed
+	// by cmd.Wait below.
+	go func() {
+		for scanner.Scan() {
+			if out := scanner.Text(); out != "" {
+				logger.Log(logger.LevelWarn, sender, "", "%s", out)
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			logger.Log(logger.LevelError, sender, "", "error reading stderr: %v", err)
+		}
+	}()
+
+	err = cmd.Wait()
+	return stdout.Bytes(), err
+}
+
+// providerLog logs a message for the dataprovider log sender at the given level.
+func providerLog(level logger.LogLevel, format string, v ...any) {
+	logger.Log(level, logSender, "", format, v...)
+}
diff --git a/internal/dataprovider/eventrule.go b/internal/dataprovider/eventrule.go
new file mode 100644
index 00000000..959f874c
--- /dev/null
+++ b/internal/dataprovider/eventrule.go
@@ -0,0 +1,1948 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +package dataprovider + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "net/http" + "path" + "path/filepath" + "slices" + "strings" + "time" + + "github.com/robfig/cron/v3" + + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +// Supported event actions +const ( + ActionTypeHTTP = iota + 1 + ActionTypeCommand + ActionTypeEmail + ActionTypeBackup + ActionTypeUserQuotaReset + ActionTypeFolderQuotaReset + ActionTypeTransferQuotaReset + ActionTypeDataRetentionCheck + ActionTypeFilesystem + actionTypeReserved + ActionTypePasswordExpirationCheck + ActionTypeUserExpirationCheck + ActionTypeIDPAccountCheck + ActionTypeUserInactivityCheck + ActionTypeRotateLogs +) + +var ( + supportedEventActions = []int{ActionTypeHTTP, ActionTypeCommand, ActionTypeEmail, ActionTypeFilesystem, + ActionTypeBackup, ActionTypeUserQuotaReset, ActionTypeFolderQuotaReset, ActionTypeTransferQuotaReset, + ActionTypeDataRetentionCheck, ActionTypePasswordExpirationCheck, ActionTypeUserExpirationCheck, + ActionTypeUserInactivityCheck, ActionTypeIDPAccountCheck, ActionTypeRotateLogs} + // EnabledActionCommands defines the system commands that can be executed via EventManager, + // an empty list means that no command is allowed to be executed. 
+ EnabledActionCommands []string +) + +func isActionTypeValid(action int) bool { + return slices.Contains(supportedEventActions, action) +} + +func getActionTypeAsString(action int) string { + switch action { + case ActionTypeHTTP: + return util.I18nActionTypeHTTP + case ActionTypeEmail: + return util.I18nActionTypeEmail + case ActionTypeBackup: + return util.I18nActionTypeBackup + case ActionTypeUserQuotaReset: + return util.I18nActionTypeUserQuotaReset + case ActionTypeFolderQuotaReset: + return util.I18nActionTypeFolderQuotaReset + case ActionTypeTransferQuotaReset: + return util.I18nActionTypeTransferQuotaReset + case ActionTypeDataRetentionCheck: + return util.I18nActionTypeDataRetentionCheck + case ActionTypeFilesystem: + return util.I18nActionTypeFilesystem + case ActionTypePasswordExpirationCheck: + return util.I18nActionTypePwdExpirationCheck + case ActionTypeUserExpirationCheck: + return util.I18nActionTypeUserExpirationCheck + case ActionTypeUserInactivityCheck: + return util.I18nActionTypeUserInactivityCheck + case ActionTypeIDPAccountCheck: + return util.I18nActionTypeIDPCheck + case ActionTypeRotateLogs: + return util.I18nActionTypeRotateLogs + default: + return util.I18nActionTypeCommand + } +} + +// Supported event triggers +const ( + // Filesystem events such as upload, download, mkdir ... 
+ EventTriggerFsEvent = iota + 1 + // Provider events such as add, update, delete + EventTriggerProviderEvent + EventTriggerSchedule + EventTriggerIPBlocked + EventTriggerCertificate + EventTriggerOnDemand + EventTriggerIDPLogin +) + +var ( + supportedEventTriggers = []int{EventTriggerFsEvent, EventTriggerProviderEvent, EventTriggerSchedule, + EventTriggerIPBlocked, EventTriggerCertificate, EventTriggerIDPLogin, EventTriggerOnDemand} +) + +func isEventTriggerValid(trigger int) bool { + return slices.Contains(supportedEventTriggers, trigger) +} + +func getTriggerTypeAsString(trigger int) string { + switch trigger { + case EventTriggerFsEvent: + return util.I18nTriggerFsEvent + case EventTriggerProviderEvent: + return util.I18nTriggerProviderEvent + case EventTriggerIPBlocked: + return util.I18nTriggerIPBlockedEvent + case EventTriggerCertificate: + return util.I18nTriggerCertificateRenewEvent + case EventTriggerOnDemand: + return util.I18nTriggerOnDemandEvent + case EventTriggerIDPLogin: + return util.I18nTriggerIDPLoginEvent + default: + return util.I18nTriggerScheduleEvent + } +} + +// Supported IDP login events +const ( + IDPLoginAny = iota + IDPLoginUser + IDPLoginAdmin +) + +var ( + supportedIDPLoginEvents = []int{IDPLoginAny, IDPLoginUser, IDPLoginAdmin} +) + +// Supported filesystem actions +const ( + FilesystemActionRename = iota + 1 + FilesystemActionDelete + FilesystemActionMkdirs + FilesystemActionExist + FilesystemActionCompress + FilesystemActionCopy +) + +const ( + // RetentionReportPlaceHolder defines the placeholder for data retention reports + RetentionReportPlaceHolder = "{{RetentionReports}}" +) + +var ( + supportedFsActions = []int{FilesystemActionRename, FilesystemActionDelete, FilesystemActionMkdirs, + FilesystemActionCopy, FilesystemActionCompress, FilesystemActionExist} +) + +func isFilesystemActionValid(value int) bool { + return slices.Contains(supportedFsActions, value) +} + +func getFsActionTypeAsString(value int) string { + switch value 
{ + case FilesystemActionRename: + return util.I18nActionFsTypeRename + case FilesystemActionDelete: + return util.I18nActionFsTypeDelete + case FilesystemActionExist: + return util.I18nActionFsTypePathExists + case FilesystemActionCompress: + return util.I18nActionFsTypeCompress + case FilesystemActionCopy: + return util.I18nActionFsTypeCopy + default: + return util.I18nActionFsTypeCreateDirs + } +} + +// TODO: replace the copied strings with shared constants +var ( + // SupportedFsEvents defines the supported filesystem events + SupportedFsEvents = []string{"upload", "pre-upload", "first-upload", "download", "pre-download", + "first-download", "delete", "pre-delete", "rename", "mkdir", "rmdir", "copy", "ssh_cmd"} + // SupportedProviderEvents defines the supported provider events + SupportedProviderEvents = []string{operationAdd, operationUpdate, operationDelete} + // SupportedRuleConditionProtocols defines the supported protcols for rule conditions + SupportedRuleConditionProtocols = []string{"SFTP", "SCP", "SSH", "FTP", "DAV", "HTTP", "HTTPShare", + "OIDC"} + // SupporteRuleConditionProviderObjects defines the supported provider objects for rule conditions + SupporteRuleConditionProviderObjects = []string{actionObjectUser, actionObjectFolder, actionObjectGroup, + actionObjectAdmin, actionObjectAPIKey, actionObjectShare, actionObjectEventRule, actionObjectEventAction} + // SupportedHTTPActionMethods defines the supported methods for HTTP actions + SupportedHTTPActionMethods = []string{http.MethodPost, http.MethodGet, http.MethodPut, http.MethodDelete} + allowedSyncFsEvents = []string{"upload", "pre-upload", "pre-download", "pre-delete"} + mandatorySyncFsEvents = []string{"pre-upload", "pre-download", "pre-delete"} +) + +// enum mappings +var ( + EventActionTypes []EnumMapping + EventTriggerTypes []EnumMapping + FsActionTypes []EnumMapping +) + +func init() { + for _, t := range supportedEventActions { + EventActionTypes = append(EventActionTypes, EnumMapping{ + 
Value: t, + Name: getActionTypeAsString(t), + }) + } + for _, t := range supportedEventTriggers { + EventTriggerTypes = append(EventTriggerTypes, EnumMapping{ + Value: t, + Name: getTriggerTypeAsString(t), + }) + } + for _, t := range supportedFsActions { + FsActionTypes = append(FsActionTypes, EnumMapping{ + Value: t, + Name: getFsActionTypeAsString(t), + }) + } +} + +// EnumMapping defines a mapping between enum values and names +type EnumMapping struct { + Name string + Value int +} + +// KeyValue defines a key/value pair +type KeyValue struct { + Key string `json:"key"` + Value string `json:"value"` +} + +func (k *KeyValue) isNotValid() bool { + return k.Key == "" || k.Value == "" +} + +// HTTPPart defines a part for HTTP multipart requests +type HTTPPart struct { + Name string `json:"name,omitempty"` + Filepath string `json:"filepath,omitempty"` + Headers []KeyValue `json:"headers,omitempty"` + Body string `json:"body,omitempty"` + Order int `json:"-"` +} + +func (p *HTTPPart) validate() error { + if p.Name == "" { + return util.NewI18nError(util.NewValidationError("HTTP part name is required"), util.I18nErrorHTTPPartNameRequired) + } + for _, kv := range p.Headers { + if kv.isNotValid() { + return util.NewValidationError("invalid HTTP part headers") + } + } + if p.Filepath == "" { + if p.Body == "" { + return util.NewI18nError( + util.NewValidationError("HTTP part body is required if no file path is provided"), + util.I18nErrorHTTPPartBodyRequired, + ) + } + } else { + p.Body = "" + if p.Filepath != RetentionReportPlaceHolder { + p.Filepath = util.CleanPath(p.Filepath) + } + } + return nil +} + +// EventActionHTTPConfig defines the configuration for an HTTP event target +type EventActionHTTPConfig struct { + Endpoint string `json:"endpoint,omitempty"` + Username string `json:"username,omitempty"` + Password *kms.Secret `json:"password,omitempty"` + Headers []KeyValue `json:"headers,omitempty"` + Timeout int `json:"timeout,omitempty"` + SkipTLSVerify bool 
`json:"skip_tls_verify,omitempty"` + Method string `json:"method,omitempty"` + QueryParameters []KeyValue `json:"query_parameters,omitempty"` + Body string `json:"body,omitempty"` + Parts []HTTPPart `json:"parts,omitempty"` +} + +// HasJSONBody returns true if the content type header indicates a JSON body +func (c *EventActionHTTPConfig) HasJSONBody() bool { + for _, h := range c.Headers { + if http.CanonicalHeaderKey(h.Key) == "Content-Type" { + return strings.Contains(strings.ToLower(h.Value), "application/json") + } + } + return false +} + +func (c *EventActionHTTPConfig) isTimeoutNotValid() bool { + if c.HasMultipartFiles() { + return false + } + return c.Timeout < 1 || c.Timeout > 180 +} + +func (c *EventActionHTTPConfig) validateMultiparts() error { + filePaths := make(map[string]bool) + for idx := range c.Parts { + if err := c.Parts[idx].validate(); err != nil { + return err + } + if filePath := c.Parts[idx].Filepath; filePath != "" { + if filePaths[filePath] { + return util.NewI18nError(fmt.Errorf("filepath %q is duplicated", filePath), util.I18nErrorPathDuplicated) + } + filePaths[filePath] = true + } + } + if len(c.Parts) > 0 { + if c.Body != "" { + return util.NewI18nError( + util.NewValidationError("multipart requests require no body. 
The request body is build from the specified parts"), + util.I18nErrorMultipartBody, + ) + } + for _, k := range c.Headers { + if strings.EqualFold(k.Key, "content-type") { + return util.NewI18nError( + util.NewValidationError("content type is automatically set for multipart requests"), + util.I18nErrorMultipartCType, + ) + } + } + } + return nil +} + +func (c *EventActionHTTPConfig) validate(additionalData string) error { + if c.Endpoint == "" { + return util.NewI18nError(util.NewValidationError("HTTP endpoint is required"), util.I18nErrorURLRequired) + } + if !util.IsStringPrefixInSlice(c.Endpoint, []string{"http://", "https://"}) { + return util.NewI18nError( + util.NewValidationError("invalid HTTP endpoint schema: http and https are supported"), + util.I18nErrorURLInvalid, + ) + } + if c.isTimeoutNotValid() { + return util.NewValidationError(fmt.Sprintf("invalid HTTP timeout %d", c.Timeout)) + } + for _, kv := range c.Headers { + if kv.isNotValid() { + return util.NewValidationError("invalid HTTP headers") + } + } + if err := c.validateMultiparts(); err != nil { + return err + } + if c.Password.IsRedacted() { + return util.NewValidationError("cannot save HTTP configuration with a redacted secret") + } + if c.Password.IsPlain() { + c.Password.SetAdditionalData(additionalData) + err := c.Password.Encrypt() + if err != nil { + return util.NewValidationError(fmt.Sprintf("could not encrypt HTTP password: %v", err)) + } + } + if !slices.Contains(SupportedHTTPActionMethods, c.Method) { + return util.NewValidationError(fmt.Sprintf("unsupported HTTP method: %s", c.Method)) + } + for _, kv := range c.QueryParameters { + if kv.isNotValid() { + return util.NewValidationError("invalid HTTP query parameters") + } + } + return nil +} + +// GetContext returns the context and the cancel func to use for the HTTP request +func (c *EventActionHTTPConfig) GetContext() (context.Context, context.CancelFunc) { + if c.HasMultipartFiles() { + return 
context.WithCancel(context.Background()) + } + return context.WithTimeout(context.Background(), time.Duration(c.Timeout)*time.Second) +} + +// HasObjectData returns true if the {{ObjectData}} placeholder is defined +func (c *EventActionHTTPConfig) HasObjectData() bool { + if strings.Contains(c.Body, "{{ObjectData}}") || strings.Contains(c.Body, "{{ObjectDataString}}") { + return true + } + for _, part := range c.Parts { + if strings.Contains(part.Body, "{{ObjectData}}") || strings.Contains(part.Body, "{{ObjectDataString}}") { + return true + } + } + return false +} + +// HasMultipartFiles returns true if at least a file must be uploaded via a multipart request +func (c *EventActionHTTPConfig) HasMultipartFiles() bool { + for _, part := range c.Parts { + if part.Filepath != "" && part.Filepath != RetentionReportPlaceHolder { + return true + } + } + return false +} + +// TryDecryptPassword decrypts the password if encryptet +func (c *EventActionHTTPConfig) TryDecryptPassword() error { + if c.Password != nil && !c.Password.IsEmpty() { + if err := c.Password.TryDecrypt(); err != nil { + return fmt.Errorf("unable to decrypt HTTP password: %w", err) + } + } + return nil +} + +// GetHTTPClient returns an HTTP client based on the config +func (c *EventActionHTTPConfig) GetHTTPClient() *http.Client { + client := &http.Client{} + if c.SkipTLSVerify { + transport := http.DefaultTransport.(*http.Transport).Clone() + if transport.TLSClientConfig != nil { + transport.TLSClientConfig.InsecureSkipVerify = true + } else { + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + client.Transport = transport + } + return client +} + +// IsActionCommandAllowed returns true if the specified command is allowed +func IsActionCommandAllowed(cmd string) bool { + return slices.Contains(EnabledActionCommands, cmd) +} + +// EventActionCommandConfig defines the configuration for a command event target +type EventActionCommandConfig struct { + Cmd string 
`json:"cmd,omitempty"` + Args []string `json:"args,omitempty"` + Timeout int `json:"timeout,omitempty"` + EnvVars []KeyValue `json:"env_vars,omitempty"` +} + +func (c *EventActionCommandConfig) validate() error { + if c.Cmd == "" { + return util.NewI18nError(util.NewValidationError("command is required"), util.I18nErrorCommandRequired) + } + if !IsActionCommandAllowed(c.Cmd) { + return util.NewValidationError(fmt.Sprintf("command %q is not allowed", c.Cmd)) + } + if !filepath.IsAbs(c.Cmd) { + return util.NewI18nError( + util.NewValidationError("invalid command, it must be an absolute path"), + util.I18nErrorCommandInvalid, + ) + } + if c.Timeout < 1 || c.Timeout > 120 { + return util.NewValidationError(fmt.Sprintf("invalid command action timeout %d", c.Timeout)) + } + for _, kv := range c.EnvVars { + if kv.isNotValid() { + return util.NewValidationError("invalid command env vars") + } + } + c.Args = util.RemoveDuplicates(c.Args, true) + for _, arg := range c.Args { + if arg == "" { + return util.NewValidationError("invalid command args") + } + } + return nil +} + +// GetArgumentsAsString returns the list of command arguments as comma separated string +func (c EventActionCommandConfig) GetArgumentsAsString() string { + return strings.Join(c.Args, ",") +} + +// EventActionEmailConfig defines the configuration options for SMTP event actions +type EventActionEmailConfig struct { + Recipients []string `json:"recipients,omitempty"` + Bcc []string `json:"bcc,omitempty"` + Subject string `json:"subject,omitempty"` + Body string `json:"body,omitempty"` + Attachments []string `json:"attachments,omitempty"` + ContentType int `json:"content_type,omitempty"` +} + +// GetRecipientsAsString returns the list of recipients as comma separated string +func (c EventActionEmailConfig) GetRecipientsAsString() string { + return strings.Join(c.Recipients, ",") +} + +// GetBccAsString returns the list of bcc as comma separated string +func (c EventActionEmailConfig) GetBccAsString() string 
{ + return strings.Join(c.Bcc, ",") +} + +// GetAttachmentsAsString returns the list of attachments as comma separated string +func (c EventActionEmailConfig) GetAttachmentsAsString() string { + return strings.Join(c.Attachments, ",") +} + +func (c *EventActionEmailConfig) hasFilesAttachments() bool { + for _, a := range c.Attachments { + if a != RetentionReportPlaceHolder { + return true + } + } + return false +} + +func (c *EventActionEmailConfig) validate() error { + if len(c.Recipients) == 0 { + return util.NewI18nError( + util.NewValidationError("at least one email recipient is required"), + util.I18nErrorEmailRecipientRequired, + ) + } + c.Recipients = util.RemoveDuplicates(c.Recipients, false) + for _, r := range c.Recipients { + if r == "" { + return util.NewValidationError("invalid email recipients") + } + } + c.Bcc = util.RemoveDuplicates(c.Bcc, false) + for _, r := range c.Bcc { + if r == "" { + return util.NewValidationError("invalid email bcc") + } + } + if c.Subject == "" { + return util.NewI18nError( + util.NewValidationError("email subject is required"), + util.I18nErrorEmailSubjectRequired, + ) + } + if c.Body == "" { + return util.NewI18nError( + util.NewValidationError("email body is required"), + util.I18nErrorEmailBodyRequired, + ) + } + if c.ContentType < 0 || c.ContentType > 1 { + return util.NewValidationError("invalid email content type") + } + for idx, val := range c.Attachments { + val = strings.TrimSpace(val) + if val == "" { + return util.NewValidationError("invalid path to attach") + } + if val == RetentionReportPlaceHolder { + c.Attachments[idx] = val + } else { + c.Attachments[idx] = util.CleanPath(val) + } + } + c.Attachments = util.RemoveDuplicates(c.Attachments, false) + return nil +} + +// FolderRetention defines a folder retention configuration +type FolderRetention struct { + // Path is the virtual directory path, if no other specific retention is defined, + // the retention applies for sub directories too. 
For example if retention is defined + // for the paths "/" and "/sub" then the retention for "/" is applied for any file outside + // the "/sub" directory + Path string `json:"path"` + // Retention time in hours. 0 means exclude this path + Retention int `json:"retention"` + // DeleteEmptyDirs defines if empty directories will be deleted. + // The user need the delete permission + DeleteEmptyDirs bool `json:"delete_empty_dirs,omitempty"` +} + +// Validate returns an error if the configuration is not valid +func (f *FolderRetention) Validate() error { + f.Path = util.CleanPath(f.Path) + if f.Retention < 0 { + return util.NewValidationError(fmt.Sprintf("invalid folder retention %v, it must be greater or equal to zero", + f.Retention)) + } + return nil +} + +// EventActionDataRetentionConfig defines the configuration for a data retention check +type EventActionDataRetentionConfig struct { + Folders []FolderRetention `json:"folders,omitempty"` +} + +func (c *EventActionDataRetentionConfig) validate() error { + folderPaths := make(map[string]bool) + nothingToDo := true + for idx := range c.Folders { + f := &c.Folders[idx] + if err := f.Validate(); err != nil { + return err + } + if f.Retention > 0 { + nothingToDo = false + } + if _, ok := folderPaths[f.Path]; ok { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("duplicated folder path %q", f.Path)), + util.I18nErrorPathDuplicated, + ) + } + folderPaths[f.Path] = true + } + if nothingToDo { + return util.NewI18nError( + util.NewValidationError("nothing to delete!"), + util.I18nErrorRetentionDirRequired, + ) + } + return nil +} + +// EventActionFsCompress defines the configuration for the compress filesystem action +type EventActionFsCompress struct { + // Archive path + Name string `json:"name,omitempty"` + // Paths to compress + Paths []string `json:"paths,omitempty"` +} + +func (c *EventActionFsCompress) validate() error { + if c.Name == "" { + return util.NewI18nError(util.NewValidationError("archive 
name is mandatory"), util.I18nErrorArchiveNameRequired) + } + c.Name = util.CleanPath(strings.TrimSpace(c.Name)) + if c.Name == "/" { + return util.NewI18nError(util.NewValidationError("invalid archive name"), util.I18nErrorRootNotAllowed) + } + if len(c.Paths) == 0 { + return util.NewI18nError(util.NewValidationError("no path to compress specified"), util.I18nErrorPathRequired) + } + for idx, val := range c.Paths { + val = strings.TrimSpace(val) + if val == "" { + return util.NewValidationError("invalid path to compress") + } + c.Paths[idx] = util.CleanPath(val) + } + c.Paths = util.RemoveDuplicates(c.Paths, false) + return nil +} + +// RenameConfig defines the configuration for a filesystem rename +type RenameConfig struct { + // key is the source and target the value + KeyValue + // This setting only applies to storage providers that support + // changing modification times. + UpdateModTime bool `json:"update_modtime,omitempty"` +} + +// EventActionFilesystemConfig defines the configuration for filesystem actions +type EventActionFilesystemConfig struct { + // Filesystem actions, see the above enum + Type int `json:"type,omitempty"` + // files/dirs to rename + Renames []RenameConfig `json:"renames,omitempty"` + // directories to create + MkDirs []string `json:"mkdirs,omitempty"` + // files/dirs to delete + Deletes []string `json:"deletes,omitempty"` + // file/dirs to check for existence + Exist []string `json:"exist,omitempty"` + // files/dirs to copy, key is the source and target the value + Copy []KeyValue `json:"copy,omitempty"` + // paths to compress and archive name + Compress EventActionFsCompress `json:"compress"` +} + +// GetDeletesAsString returns the list of items to delete as comma separated string. 
+// Using a pointer receiver will not work in web templates +func (c EventActionFilesystemConfig) GetDeletesAsString() string { + return strings.Join(c.Deletes, ",") +} + +// GetMkDirsAsString returns the list of directories to create as comma separated string. +// Using a pointer receiver will not work in web templates +func (c EventActionFilesystemConfig) GetMkDirsAsString() string { + return strings.Join(c.MkDirs, ",") +} + +// GetExistAsString returns the list of items to check for existence as comma separated string. +// Using a pointer receiver will not work in web templates +func (c EventActionFilesystemConfig) GetExistAsString() string { + return strings.Join(c.Exist, ",") +} + +// GetCompressPathsAsString returns the list of items to compress as comma separated string. +// Using a pointer receiver will not work in web templates +func (c EventActionFilesystemConfig) GetCompressPathsAsString() string { + return strings.Join(c.Compress.Paths, ",") +} + +func (c *EventActionFilesystemConfig) validateRenames() error { + if len(c.Renames) == 0 { + return util.NewI18nError(util.NewValidationError("no path to rename specified"), util.I18nErrorPathRequired) + } + for idx, cfg := range c.Renames { + key := strings.TrimSpace(cfg.Key) + value := strings.TrimSpace(cfg.Value) + if key == "" || value == "" { + return util.NewValidationError("invalid paths to rename") + } + key = util.CleanPath(key) + value = util.CleanPath(value) + if key == value { + return util.NewI18nError( + util.NewValidationError("rename source and target cannot be equal"), + util.I18nErrorSourceDestMatch, + ) + } + if key == "/" || value == "/" { + return util.NewI18nError( + util.NewValidationError("renaming the root directory is not allowed"), + util.I18nErrorRootNotAllowed, + ) + } + c.Renames[idx] = RenameConfig{ + KeyValue: KeyValue{ + Key: key, + Value: value, + }, + UpdateModTime: cfg.UpdateModTime, + } + } + return nil +} + +func (c *EventActionFilesystemConfig) validateCopy() error { + if 
len(c.Copy) == 0 { + return util.NewI18nError(util.NewValidationError("no path to copy specified"), util.I18nErrorPathRequired) + } + for idx, kv := range c.Copy { + key := strings.TrimSpace(kv.Key) + value := strings.TrimSpace(kv.Value) + if key == "" || value == "" { + return util.NewValidationError("invalid paths to copy") + } + key = util.CleanPath(key) + value = util.CleanPath(value) + if key == value { + return util.NewI18nError( + util.NewValidationError("copy source and target cannot be equal"), + util.I18nErrorSourceDestMatch, + ) + } + if key == "/" || value == "/" { + return util.NewI18nError( + util.NewValidationError("copying the root directory is not allowed"), + util.I18nErrorRootNotAllowed, + ) + } + if strings.HasSuffix(c.Copy[idx].Key, "/") { + key += "/" + } + if strings.HasSuffix(c.Copy[idx].Value, "/") { + value += "/" + } + c.Copy[idx] = KeyValue{ + Key: key, + Value: value, + } + } + return nil +} + +func (c *EventActionFilesystemConfig) validateDeletes() error { + if len(c.Deletes) == 0 { + return util.NewI18nError(util.NewValidationError("no path to delete specified"), util.I18nErrorPathRequired) + } + for idx, val := range c.Deletes { + val = strings.TrimSpace(val) + if val == "" { + return util.NewValidationError("invalid path to delete") + } + c.Deletes[idx] = util.CleanPath(val) + } + c.Deletes = util.RemoveDuplicates(c.Deletes, false) + return nil +} + +func (c *EventActionFilesystemConfig) validateMkdirs() error { + if len(c.MkDirs) == 0 { + return util.NewI18nError(util.NewValidationError("no directory to create specified"), util.I18nErrorPathRequired) + } + for idx, val := range c.MkDirs { + val = strings.TrimSpace(val) + if val == "" { + return util.NewValidationError("invalid directory to create") + } + c.MkDirs[idx] = util.CleanPath(val) + } + c.MkDirs = util.RemoveDuplicates(c.MkDirs, false) + return nil +} + +func (c *EventActionFilesystemConfig) validateExist() error { + if len(c.Exist) == 0 { + return 
util.NewI18nError(util.NewValidationError("no path to check for existence specified"), util.I18nErrorPathRequired) + } + for idx, val := range c.Exist { + val = strings.TrimSpace(val) + if val == "" { + return util.NewValidationError("invalid path to check for existence") + } + c.Exist[idx] = util.CleanPath(val) + } + c.Exist = util.RemoveDuplicates(c.Exist, false) + return nil +} + +func (c *EventActionFilesystemConfig) validate() error { + if !isFilesystemActionValid(c.Type) { + return util.NewValidationError(fmt.Sprintf("invalid filesystem action type: %d", c.Type)) + } + switch c.Type { + case FilesystemActionRename: + c.MkDirs = nil + c.Deletes = nil + c.Exist = nil + c.Copy = nil + c.Compress = EventActionFsCompress{} + if err := c.validateRenames(); err != nil { + return err + } + case FilesystemActionDelete: + c.Renames = nil + c.MkDirs = nil + c.Exist = nil + c.Copy = nil + c.Compress = EventActionFsCompress{} + if err := c.validateDeletes(); err != nil { + return err + } + case FilesystemActionMkdirs: + c.Renames = nil + c.Deletes = nil + c.Exist = nil + c.Copy = nil + c.Compress = EventActionFsCompress{} + if err := c.validateMkdirs(); err != nil { + return err + } + case FilesystemActionExist: + c.Renames = nil + c.Deletes = nil + c.MkDirs = nil + c.Copy = nil + c.Compress = EventActionFsCompress{} + if err := c.validateExist(); err != nil { + return err + } + case FilesystemActionCompress: + c.Renames = nil + c.MkDirs = nil + c.Deletes = nil + c.Exist = nil + c.Copy = nil + if err := c.Compress.validate(); err != nil { + return err + } + case FilesystemActionCopy: + c.Renames = nil + c.Deletes = nil + c.MkDirs = nil + c.Exist = nil + c.Compress = EventActionFsCompress{} + if err := c.validateCopy(); err != nil { + return err + } + } + return nil +} + +func (c *EventActionFilesystemConfig) getACopy() EventActionFilesystemConfig { + mkdirs := make([]string, len(c.MkDirs)) + copy(mkdirs, c.MkDirs) + deletes := make([]string, len(c.Deletes)) + 
copy(deletes, c.Deletes) + exist := make([]string, len(c.Exist)) + copy(exist, c.Exist) + compressPaths := make([]string, len(c.Compress.Paths)) + copy(compressPaths, c.Compress.Paths) + + return EventActionFilesystemConfig{ + Type: c.Type, + Renames: cloneRenameConfigs(c.Renames), + MkDirs: mkdirs, + Deletes: deletes, + Exist: exist, + Copy: cloneKeyValues(c.Copy), + Compress: EventActionFsCompress{ + Paths: compressPaths, + Name: c.Compress.Name, + }, + } +} + +// EventActionPasswordExpiration defines the configuration for password expiration actions +type EventActionPasswordExpiration struct { + // An email notification will be generated for users whose password expires in a number + // of days less than or equal to this threshold + Threshold int `json:"threshold,omitempty"` +} + +func (c *EventActionPasswordExpiration) validate() error { + if c.Threshold <= 0 { + return util.NewValidationError("threshold must be greater than 0") + } + return nil +} + +// EventActionUserInactivity defines the configuration for user inactivity checks. 
+type EventActionUserInactivity struct { + // DisableThreshold defines inactivity in days, since the last login before disabling the account + DisableThreshold int `json:"disable_threshold,omitempty"` + // DeleteThreshold defines inactivity in days, since the last login before deleting the account + DeleteThreshold int `json:"delete_threshold,omitempty"` +} + +func (c *EventActionUserInactivity) validate() error { + if c.DeleteThreshold < 0 { + c.DeleteThreshold = 0 + } + if c.DisableThreshold < 0 { + c.DisableThreshold = 0 + } + if c.DisableThreshold == 0 && c.DeleteThreshold == 0 { + return util.NewI18nError( + util.NewValidationError("at least a threshold must be defined"), + util.I18nActionThresholdRequired, + ) + } + if c.DeleteThreshold > 0 && c.DisableThreshold > 0 { + if c.DeleteThreshold <= c.DisableThreshold { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("deletion threshold %d must be greater than deactivation threshold: %d", c.DeleteThreshold, c.DisableThreshold)), + util.I18nActionThresholdsInvalid, + ) + } + } + return nil +} + +// EventActionIDPAccountCheck defines the check to execute after a successful IDP login +type EventActionIDPAccountCheck struct { + // 0 create/update, 1 create the account if it doesn't exist + Mode int `json:"mode,omitempty"` + TemplateUser string `json:"template_user,omitempty"` + TemplateAdmin string `json:"template_admin,omitempty"` +} + +func (c *EventActionIDPAccountCheck) validate() error { + if c.TemplateAdmin == "" && c.TemplateUser == "" { + return util.NewI18nError( + util.NewValidationError("at least a template must be set"), + util.I18nErrorIDPTemplateRequired, + ) + } + if c.Mode < 0 || c.Mode > 1 { + return util.NewValidationError(fmt.Sprintf("invalid account check mode: %d", c.Mode)) + } + return nil +} + +// BaseEventActionOptions defines the supported configuration options for a base event actions +type BaseEventActionOptions struct { + HTTPConfig EventActionHTTPConfig 
`json:"http_config"` + CmdConfig EventActionCommandConfig `json:"cmd_config"` + EmailConfig EventActionEmailConfig `json:"email_config"` + RetentionConfig EventActionDataRetentionConfig `json:"retention_config"` + FsConfig EventActionFilesystemConfig `json:"fs_config"` + PwdExpirationConfig EventActionPasswordExpiration `json:"pwd_expiration_config"` + UserInactivityConfig EventActionUserInactivity `json:"user_inactivity_config"` + IDPConfig EventActionIDPAccountCheck `json:"idp_config"` +} + +func (o *BaseEventActionOptions) getACopy() BaseEventActionOptions { + o.SetEmptySecretsIfNil() + emailRecipients := make([]string, len(o.EmailConfig.Recipients)) + copy(emailRecipients, o.EmailConfig.Recipients) + emailBcc := make([]string, len(o.EmailConfig.Bcc)) + copy(emailBcc, o.EmailConfig.Bcc) + emailAttachments := make([]string, len(o.EmailConfig.Attachments)) + copy(emailAttachments, o.EmailConfig.Attachments) + cmdArgs := make([]string, len(o.CmdConfig.Args)) + copy(cmdArgs, o.CmdConfig.Args) + folders := make([]FolderRetention, 0, len(o.RetentionConfig.Folders)) + for _, folder := range o.RetentionConfig.Folders { + folders = append(folders, FolderRetention{ + Path: folder.Path, + Retention: folder.Retention, + DeleteEmptyDirs: folder.DeleteEmptyDirs, + }) + } + httpParts := make([]HTTPPart, 0, len(o.HTTPConfig.Parts)) + for _, part := range o.HTTPConfig.Parts { + httpParts = append(httpParts, HTTPPart{ + Name: part.Name, + Filepath: part.Filepath, + Headers: cloneKeyValues(part.Headers), + Body: part.Body, + }) + } + + return BaseEventActionOptions{ + HTTPConfig: EventActionHTTPConfig{ + Endpoint: o.HTTPConfig.Endpoint, + Username: o.HTTPConfig.Username, + Password: o.HTTPConfig.Password.Clone(), + Headers: cloneKeyValues(o.HTTPConfig.Headers), + Timeout: o.HTTPConfig.Timeout, + SkipTLSVerify: o.HTTPConfig.SkipTLSVerify, + Method: o.HTTPConfig.Method, + QueryParameters: cloneKeyValues(o.HTTPConfig.QueryParameters), + Body: o.HTTPConfig.Body, + Parts: httpParts, + 
}, + CmdConfig: EventActionCommandConfig{ + Cmd: o.CmdConfig.Cmd, + Args: cmdArgs, + Timeout: o.CmdConfig.Timeout, + EnvVars: cloneKeyValues(o.CmdConfig.EnvVars), + }, + EmailConfig: EventActionEmailConfig{ + Recipients: emailRecipients, + Bcc: emailBcc, + Subject: o.EmailConfig.Subject, + ContentType: o.EmailConfig.ContentType, + Body: o.EmailConfig.Body, + Attachments: emailAttachments, + }, + RetentionConfig: EventActionDataRetentionConfig{ + Folders: folders, + }, + PwdExpirationConfig: EventActionPasswordExpiration{ + Threshold: o.PwdExpirationConfig.Threshold, + }, + UserInactivityConfig: EventActionUserInactivity{ + DisableThreshold: o.UserInactivityConfig.DisableThreshold, + DeleteThreshold: o.UserInactivityConfig.DeleteThreshold, + }, + IDPConfig: EventActionIDPAccountCheck{ + Mode: o.IDPConfig.Mode, + TemplateUser: o.IDPConfig.TemplateUser, + TemplateAdmin: o.IDPConfig.TemplateAdmin, + }, + FsConfig: o.FsConfig.getACopy(), + } +} + +// SetEmptySecretsIfNil sets the secrets to empty if nil +func (o *BaseEventActionOptions) SetEmptySecretsIfNil() { + if o.HTTPConfig.Password == nil { + o.HTTPConfig.Password = kms.NewEmptySecret() + } +} + +func (o *BaseEventActionOptions) setNilSecretsIfEmpty() { + if o.HTTPConfig.Password != nil && o.HTTPConfig.Password.IsEmpty() { + o.HTTPConfig.Password = nil + } +} + +func (o *BaseEventActionOptions) hideConfidentialData() { + if o.HTTPConfig.Password != nil { + o.HTTPConfig.Password.Hide() + } +} + +func (o *BaseEventActionOptions) validate(action int, name string) error { + o.SetEmptySecretsIfNil() + switch action { + case ActionTypeHTTP: + o.CmdConfig = EventActionCommandConfig{} + o.EmailConfig = EventActionEmailConfig{} + o.RetentionConfig = EventActionDataRetentionConfig{} + o.FsConfig = EventActionFilesystemConfig{} + o.PwdExpirationConfig = EventActionPasswordExpiration{} + o.IDPConfig = EventActionIDPAccountCheck{} + o.UserInactivityConfig = EventActionUserInactivity{} + return o.HTTPConfig.validate(name) + 
case ActionTypeCommand: + o.HTTPConfig = EventActionHTTPConfig{} + o.EmailConfig = EventActionEmailConfig{} + o.RetentionConfig = EventActionDataRetentionConfig{} + o.FsConfig = EventActionFilesystemConfig{} + o.PwdExpirationConfig = EventActionPasswordExpiration{} + o.IDPConfig = EventActionIDPAccountCheck{} + o.UserInactivityConfig = EventActionUserInactivity{} + return o.CmdConfig.validate() + case ActionTypeEmail: + o.HTTPConfig = EventActionHTTPConfig{} + o.CmdConfig = EventActionCommandConfig{} + o.RetentionConfig = EventActionDataRetentionConfig{} + o.FsConfig = EventActionFilesystemConfig{} + o.PwdExpirationConfig = EventActionPasswordExpiration{} + o.IDPConfig = EventActionIDPAccountCheck{} + o.UserInactivityConfig = EventActionUserInactivity{} + return o.EmailConfig.validate() + case ActionTypeDataRetentionCheck: + o.HTTPConfig = EventActionHTTPConfig{} + o.CmdConfig = EventActionCommandConfig{} + o.EmailConfig = EventActionEmailConfig{} + o.FsConfig = EventActionFilesystemConfig{} + o.PwdExpirationConfig = EventActionPasswordExpiration{} + o.IDPConfig = EventActionIDPAccountCheck{} + o.UserInactivityConfig = EventActionUserInactivity{} + return o.RetentionConfig.validate() + case ActionTypeFilesystem: + o.HTTPConfig = EventActionHTTPConfig{} + o.CmdConfig = EventActionCommandConfig{} + o.EmailConfig = EventActionEmailConfig{} + o.RetentionConfig = EventActionDataRetentionConfig{} + o.PwdExpirationConfig = EventActionPasswordExpiration{} + o.IDPConfig = EventActionIDPAccountCheck{} + o.UserInactivityConfig = EventActionUserInactivity{} + return o.FsConfig.validate() + case ActionTypePasswordExpirationCheck: + o.HTTPConfig = EventActionHTTPConfig{} + o.CmdConfig = EventActionCommandConfig{} + o.EmailConfig = EventActionEmailConfig{} + o.RetentionConfig = EventActionDataRetentionConfig{} + o.FsConfig = EventActionFilesystemConfig{} + o.IDPConfig = EventActionIDPAccountCheck{} + o.UserInactivityConfig = EventActionUserInactivity{} + return 
o.PwdExpirationConfig.validate() + case ActionTypeUserInactivityCheck: + o.HTTPConfig = EventActionHTTPConfig{} + o.CmdConfig = EventActionCommandConfig{} + o.EmailConfig = EventActionEmailConfig{} + o.RetentionConfig = EventActionDataRetentionConfig{} + o.FsConfig = EventActionFilesystemConfig{} + o.IDPConfig = EventActionIDPAccountCheck{} + o.PwdExpirationConfig = EventActionPasswordExpiration{} + return o.UserInactivityConfig.validate() + case ActionTypeIDPAccountCheck: + o.HTTPConfig = EventActionHTTPConfig{} + o.CmdConfig = EventActionCommandConfig{} + o.EmailConfig = EventActionEmailConfig{} + o.RetentionConfig = EventActionDataRetentionConfig{} + o.FsConfig = EventActionFilesystemConfig{} + o.PwdExpirationConfig = EventActionPasswordExpiration{} + o.UserInactivityConfig = EventActionUserInactivity{} + return o.IDPConfig.validate() + default: + o.HTTPConfig = EventActionHTTPConfig{} + o.CmdConfig = EventActionCommandConfig{} + o.EmailConfig = EventActionEmailConfig{} + o.RetentionConfig = EventActionDataRetentionConfig{} + o.FsConfig = EventActionFilesystemConfig{} + o.PwdExpirationConfig = EventActionPasswordExpiration{} + o.IDPConfig = EventActionIDPAccountCheck{} + o.UserInactivityConfig = EventActionUserInactivity{} + } + return nil +} + +// BaseEventAction defines the common fields for an event action +type BaseEventAction struct { + // Data provider unique identifier + ID int64 `json:"id"` + // Action name + Name string `json:"name"` + // optional description + Description string `json:"description,omitempty"` + // ActionType, see the above enum + Type int `json:"type"` + // Configuration options specific for the action type + Options BaseEventActionOptions `json:"options"` + // list of rule names associated with this event action + Rules []string `json:"rules,omitempty"` +} + +func (a *BaseEventAction) getACopy() BaseEventAction { + rules := make([]string, len(a.Rules)) + copy(rules, a.Rules) + return BaseEventAction{ + ID: a.ID, + Name: a.Name, + 
Description: a.Description, + Type: a.Type, + Options: a.Options.getACopy(), + Rules: rules, + } +} + +// GetTypeAsString returns the action type as string +func (a *BaseEventAction) GetTypeAsString() string { + return getActionTypeAsString(a.Type) +} + +// GetRulesAsString returns the list of rules as comma separated string +func (a *BaseEventAction) GetRulesAsString() string { + return strings.Join(a.Rules, ",") +} + +// PrepareForRendering prepares a BaseEventAction for rendering. +// It hides confidential data and set to nil the empty secrets +// so they are not serialized +func (a *BaseEventAction) PrepareForRendering() { + a.Options.setNilSecretsIfEmpty() + a.Options.hideConfidentialData() +} + +// RenderAsJSON implements the renderer interface used within plugins +func (a *BaseEventAction) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + action, err := provider.eventActionExists(a.Name) + if err != nil { + providerLog(logger.LevelError, "unable to reload event action before rendering as json: %v", err) + return nil, err + } + action.PrepareForRendering() + return json.Marshal(action) + } + a.PrepareForRendering() + return json.Marshal(a) +} + +func (a *BaseEventAction) validate() error { + if a.Name == "" { + return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired) + } + if !util.IsNameValid(a.Name) { + return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) + } + if config.NamingRules&1 == 0 && !usernameRegex.MatchString(a.Name) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("name %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", a.Name)), + util.I18nErrorInvalidUser, + ) + } + if !isActionTypeValid(a.Type) { + return util.NewValidationError(fmt.Sprintf("invalid action type: %d", a.Type)) + } + return a.Options.validate(a.Type, a.Name) +} + +// EventActionOptions defines the supported configuration options for an event action +type EventActionOptions 
struct { + IsFailureAction bool `json:"is_failure_action"` + StopOnFailure bool `json:"stop_on_failure"` + ExecuteSync bool `json:"execute_sync"` +} + +// EventAction defines an event action +type EventAction struct { + BaseEventAction + // Order defines the execution order + Order int `json:"order,omitempty"` + Options EventActionOptions `json:"relation_options"` +} + +func (a *EventAction) getACopy() EventAction { + return EventAction{ + BaseEventAction: a.BaseEventAction.getACopy(), + Order: a.Order, + Options: EventActionOptions{ + IsFailureAction: a.Options.IsFailureAction, + StopOnFailure: a.Options.StopOnFailure, + ExecuteSync: a.Options.ExecuteSync, + }, + } +} + +func (a *EventAction) validateAssociation(trigger int, fsEvents []string) error { + if a.Options.IsFailureAction { + if a.Options.ExecuteSync { + return util.NewI18nError( + util.NewValidationError("sync execution is not supported for failure actions"), + util.I18nErrorEvSyncFailureActions, + ) + } + } + if a.Options.ExecuteSync { + if trigger != EventTriggerFsEvent && trigger != EventTriggerIDPLogin { + return util.NewI18nError( + util.NewValidationError("sync execution is only supported for some filesystem events and Identity Provider logins"), + util.I18nErrorEvSyncUnsupported, + ) + } + if trigger == EventTriggerFsEvent { + for _, ev := range fsEvents { + if !slices.Contains(allowedSyncFsEvents, ev) { + return util.NewI18nError( + util.NewValidationError("sync execution is only supported for upload and pre-* events"), + util.I18nErrorEvSyncUnsupportedFs, + ) + } + } + } + } + return nil +} + +// ConditionPattern defines a pattern for condition filters +type ConditionPattern struct { + Pattern string `json:"pattern,omitempty"` + InverseMatch bool `json:"inverse_match,omitempty"` +} + +func (p *ConditionPattern) validate() error { + if p.Pattern == "" { + return util.NewValidationError("empty condition pattern not allowed") + } + _, err := path.Match(p.Pattern, "abc") + if err != nil { + return 
util.NewValidationError(fmt.Sprintf("invalid condition pattern %q", p.Pattern)) + } + return nil +} + +// ConditionOptions defines options for event conditions +type ConditionOptions struct { + // Usernames or folder names + Names []ConditionPattern `json:"names,omitempty"` + // Group names + GroupNames []ConditionPattern `json:"group_names,omitempty"` + // Role names + RoleNames []ConditionPattern `json:"role_names,omitempty"` + // Virtual paths + FsPaths []ConditionPattern `json:"fs_paths,omitempty"` + Protocols []string `json:"protocols,omitempty"` + ProviderObjects []string `json:"provider_objects,omitempty"` + MinFileSize int64 `json:"min_size,omitempty"` + MaxFileSize int64 `json:"max_size,omitempty"` + EventStatuses []int `json:"event_statuses,omitempty"` + // allow to execute scheduled tasks concurrently from multiple instances + ConcurrentExecution bool `json:"concurrent_execution,omitempty"` +} + +func (f *ConditionOptions) getACopy() ConditionOptions { + protocols := make([]string, len(f.Protocols)) + copy(protocols, f.Protocols) + providerObjects := make([]string, len(f.ProviderObjects)) + copy(providerObjects, f.ProviderObjects) + statuses := make([]int, len(f.EventStatuses)) + copy(statuses, f.EventStatuses) + + return ConditionOptions{ + Names: cloneConditionPatterns(f.Names), + GroupNames: cloneConditionPatterns(f.GroupNames), + RoleNames: cloneConditionPatterns(f.RoleNames), + FsPaths: cloneConditionPatterns(f.FsPaths), + Protocols: protocols, + ProviderObjects: providerObjects, + MinFileSize: f.MinFileSize, + MaxFileSize: f.MaxFileSize, + EventStatuses: statuses, + ConcurrentExecution: f.ConcurrentExecution, + } +} + +func (f *ConditionOptions) validateStatuses() error { + for _, status := range f.EventStatuses { + if status < 0 || status > 3 { + return util.NewValidationError(fmt.Sprintf("invalid event_status %d", status)) + } + } + return nil +} + +func (f *ConditionOptions) validate() error { + if err := validateConditionPatterns(f.Names); err 
!= nil { + return err + } + if err := validateConditionPatterns(f.GroupNames); err != nil { + return err + } + if err := validateConditionPatterns(f.RoleNames); err != nil { + return err + } + if err := validateConditionPatterns(f.FsPaths); err != nil { + return err + } + + for _, p := range f.Protocols { + if !slices.Contains(SupportedRuleConditionProtocols, p) { + return util.NewValidationError(fmt.Sprintf("unsupported rule condition protocol: %q", p)) + } + } + for _, p := range f.ProviderObjects { + if !slices.Contains(SupporteRuleConditionProviderObjects, p) { + return util.NewValidationError(fmt.Sprintf("unsupported provider object: %q", p)) + } + } + if f.MinFileSize > 0 && f.MaxFileSize > 0 { + if f.MaxFileSize <= f.MinFileSize { + return util.NewValidationError(fmt.Sprintf("invalid max file size %s, it is lesser or equal than min file size %s", + util.ByteCountSI(f.MaxFileSize), util.ByteCountSI(f.MinFileSize))) + } + } + if err := f.validateStatuses(); err != nil { + return err + } + if config.IsShared == 0 { + f.ConcurrentExecution = false + } + return nil +} + +// Schedule defines an event schedule +type Schedule struct { + Hours string `json:"hour"` + DayOfWeek string `json:"day_of_week"` + DayOfMonth string `json:"day_of_month"` + Month string `json:"month"` +} + +// GetCronSpec returns the cron compatible schedule string +func (s *Schedule) GetCronSpec() string { + return fmt.Sprintf("0 %s %s %s %s", s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek) +} + +func (s *Schedule) validate() error { + _, err := cron.ParseStandard(s.GetCronSpec()) + if err != nil { + return util.NewValidationError(fmt.Sprintf("invalid schedule, hour: %q, day of month: %q, month: %q, day of week: %q", + s.Hours, s.DayOfMonth, s.Month, s.DayOfWeek)) + } + return nil +} + +// EventConditions defines the conditions for an event rule +type EventConditions struct { + // Only one between FsEvents, ProviderEvents and Schedule is allowed + FsEvents []string `json:"fs_events,omitempty"` + 
ProviderEvents []string `json:"provider_events,omitempty"` + Schedules []Schedule `json:"schedules,omitempty"` + // 0 any, 1 user, 2 admin + IDPLoginEvent int `json:"idp_login_event,omitempty"` + Options ConditionOptions `json:"options"` +} + +func (c *EventConditions) getACopy() EventConditions { + fsEvents := make([]string, len(c.FsEvents)) + copy(fsEvents, c.FsEvents) + providerEvents := make([]string, len(c.ProviderEvents)) + copy(providerEvents, c.ProviderEvents) + schedules := make([]Schedule, 0, len(c.Schedules)) + for _, schedule := range c.Schedules { + schedules = append(schedules, Schedule{ + Hours: schedule.Hours, + DayOfWeek: schedule.DayOfWeek, + DayOfMonth: schedule.DayOfMonth, + Month: schedule.Month, + }) + } + + return EventConditions{ + FsEvents: fsEvents, + ProviderEvents: providerEvents, + Schedules: schedules, + IDPLoginEvent: c.IDPLoginEvent, + Options: c.Options.getACopy(), + } +} + +func (c *EventConditions) validateSchedules() error { + if len(c.Schedules) == 0 { + return util.NewI18nError( + util.NewValidationError("at least one schedule is required"), + util.I18nErrorRuleScheduleRequired, + ) + } + for _, schedule := range c.Schedules { + if err := schedule.validate(); err != nil { + return util.NewI18nError(err, util.I18nErrorRuleScheduleInvalid) + } + } + return nil +} + +func (c *EventConditions) validate(trigger int) error { + switch trigger { + case EventTriggerFsEvent: + c.ProviderEvents = nil + c.Schedules = nil + c.Options.ProviderObjects = nil + c.IDPLoginEvent = 0 + if len(c.FsEvents) == 0 { + return util.NewI18nError( + util.NewValidationError("at least one filesystem event is required"), + util.I18nErrorRuleFsEventRequired, + ) + } + for _, ev := range c.FsEvents { + if !slices.Contains(SupportedFsEvents, ev) { + return util.NewValidationError(fmt.Sprintf("unsupported fs event: %q", ev)) + } + } + case EventTriggerProviderEvent: + c.FsEvents = nil + c.Schedules = nil + c.Options.FsPaths = nil + c.Options.Protocols = nil + 
c.Options.EventStatuses = nil + c.Options.MinFileSize = 0 + c.Options.MaxFileSize = 0 + c.IDPLoginEvent = 0 + if len(c.ProviderEvents) == 0 { + return util.NewI18nError( + util.NewValidationError("at least one provider event is required"), + util.I18nErrorRuleProviderEventRequired, + ) + } + for _, ev := range c.ProviderEvents { + if !slices.Contains(SupportedProviderEvents, ev) { + return util.NewValidationError(fmt.Sprintf("unsupported provider event: %q", ev)) + } + } + case EventTriggerSchedule: + c.FsEvents = nil + c.ProviderEvents = nil + c.Options.FsPaths = nil + c.Options.Protocols = nil + c.Options.EventStatuses = nil + c.Options.MinFileSize = 0 + c.Options.MaxFileSize = 0 + c.Options.ProviderObjects = nil + c.IDPLoginEvent = 0 + if err := c.validateSchedules(); err != nil { + return err + } + case EventTriggerIPBlocked, EventTriggerCertificate: + c.FsEvents = nil + c.ProviderEvents = nil + c.Options.Names = nil + c.Options.GroupNames = nil + c.Options.RoleNames = nil + c.Options.FsPaths = nil + c.Options.Protocols = nil + c.Options.EventStatuses = nil + c.Options.MinFileSize = 0 + c.Options.MaxFileSize = 0 + c.Schedules = nil + c.IDPLoginEvent = 0 + case EventTriggerOnDemand: + c.FsEvents = nil + c.ProviderEvents = nil + c.Options.FsPaths = nil + c.Options.Protocols = nil + c.Options.EventStatuses = nil + c.Options.MinFileSize = 0 + c.Options.MaxFileSize = 0 + c.Options.ProviderObjects = nil + c.Schedules = nil + c.IDPLoginEvent = 0 + c.Options.ConcurrentExecution = false + case EventTriggerIDPLogin: + c.FsEvents = nil + c.ProviderEvents = nil + c.Options.GroupNames = nil + c.Options.RoleNames = nil + c.Options.FsPaths = nil + c.Options.Protocols = nil + c.Options.EventStatuses = nil + c.Options.MinFileSize = 0 + c.Options.MaxFileSize = 0 + c.Schedules = nil + if !slices.Contains(supportedIDPLoginEvents, c.IDPLoginEvent) { + return util.NewValidationError(fmt.Sprintf("invalid Identity Provider login event %d", c.IDPLoginEvent)) + } + default: + c.FsEvents 
= nil + c.ProviderEvents = nil + c.Options.GroupNames = nil + c.Options.RoleNames = nil + c.Options.FsPaths = nil + c.Options.Protocols = nil + c.Options.EventStatuses = nil + c.Options.MinFileSize = 0 + c.Options.MaxFileSize = 0 + c.Schedules = nil + c.IDPLoginEvent = 0 + } + + return c.Options.validate() +} + +// EventRule defines the trigger, conditions and actions for an event +type EventRule struct { + // Data provider unique identifier + ID int64 `json:"id"` + // Rule name + Name string `json:"name"` + // 1 enabled, 0 disabled + Status int `json:"status"` + // optional description + Description string `json:"description,omitempty"` + // Creation time as unix timestamp in milliseconds + CreatedAt int64 `json:"created_at"` + // last update time as unix timestamp in milliseconds + UpdatedAt int64 `json:"updated_at"` + // Event trigger + Trigger int `json:"trigger"` + // Event conditions + Conditions EventConditions `json:"conditions"` + // actions to execute + Actions []EventAction `json:"actions"` + // in multi node setups we mark the rule as deleted to be able to update the cache + DeletedAt int64 `json:"-"` +} + +func (r *EventRule) getACopy() EventRule { + actions := make([]EventAction, 0, len(r.Actions)) + for _, action := range r.Actions { + actions = append(actions, action.getACopy()) + } + + return EventRule{ + ID: r.ID, + Name: r.Name, + Status: r.Status, + Description: r.Description, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + Trigger: r.Trigger, + Conditions: r.Conditions.getACopy(), + Actions: actions, + DeletedAt: r.DeletedAt, + } +} + +// GuardFromConcurrentExecution returns true if the rule cannot be executed concurrently +// from multiple instances +func (r *EventRule) GuardFromConcurrentExecution() bool { + if config.IsShared == 0 { + return false + } + return !r.Conditions.Options.ConcurrentExecution +} + +// GetTriggerAsString returns the rule trigger as string +func (r *EventRule) GetTriggerAsString() string { + return 
getTriggerTypeAsString(r.Trigger) +} + +// GetActionsAsString returns the list of action names as comma separated string +func (r *EventRule) GetActionsAsString() string { + actions := make([]string, 0, len(r.Actions)) + for _, action := range r.Actions { + actions = append(actions, action.Name) + } + return strings.Join(actions, ",") +} + +func (r *EventRule) isStatusValid() bool { + return r.Status >= 0 && r.Status <= 1 +} + +func (r *EventRule) validate() error { //nolint:gocyclo + if r.Name == "" { + return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired) + } + if !util.IsNameValid(r.Name) { + return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) + } + if config.NamingRules&1 == 0 && !usernameRegex.MatchString(r.Name) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("name %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", r.Name)), + util.I18nErrorInvalidUser, + ) + } + if !r.isStatusValid() { + return util.NewValidationError(fmt.Sprintf("invalid event rule status: %d", r.Status)) + } + if !isEventTriggerValid(r.Trigger) { + return util.NewValidationError(fmt.Sprintf("invalid event rule trigger: %d", r.Trigger)) + } + if err := r.Conditions.validate(r.Trigger); err != nil { + return err + } + if len(r.Actions) == 0 { + return util.NewI18nError(util.NewValidationError("at least one action is required"), util.I18nErrorRuleActionRequired) + } + actionNames := make(map[string]bool) + actionOrders := make(map[int]bool) + failureActions := 0 + hasSyncAction := false + for idx := range r.Actions { + if r.Actions[idx].Name == "" { + return util.NewValidationError(fmt.Sprintf("invalid action at position %d, name not specified", idx)) + } + if actionNames[r.Actions[idx].Name] { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("duplicated action %q", r.Actions[idx].Name)), + util.I18nErrorRuleDuplicateActions, + ) + } + if actionOrders[r.Actions[idx].Order] { 
+ return util.NewValidationError(fmt.Sprintf("duplicated order %d for action %q", + r.Actions[idx].Order, r.Actions[idx].Name)) + } + if err := r.Actions[idx].validateAssociation(r.Trigger, r.Conditions.FsEvents); err != nil { + return err + } + if r.Actions[idx].Options.IsFailureAction { + failureActions++ + } + if r.Actions[idx].Options.ExecuteSync { + hasSyncAction = true + } + actionNames[r.Actions[idx].Name] = true + actionOrders[r.Actions[idx].Order] = true + } + if len(r.Actions) == failureActions { + return util.NewI18nError( + util.NewValidationError("at least a non-failure action is required"), + util.I18nErrorRuleFailureActionsOnly, + ) + } + if !hasSyncAction { + return r.validateMandatorySyncActions() + } + return nil +} + +func (r *EventRule) validateMandatorySyncActions() error { + if r.Trigger != EventTriggerFsEvent { + return nil + } + for _, ev := range r.Conditions.FsEvents { + if slices.Contains(mandatorySyncFsEvents, ev) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("event %q requires at least a sync action", ev)), + util.I18nErrorRuleSyncActionRequired, + util.I18nErrorArgs(map[string]any{ + "val": ev, + }), + ) + } + } + return nil +} + +func (r *EventRule) checkIPBlockedAndCertificateActions() error { + unavailableActions := []int{ActionTypeUserQuotaReset, ActionTypeFolderQuotaReset, ActionTypeTransferQuotaReset, + ActionTypeDataRetentionCheck, ActionTypeFilesystem, ActionTypePasswordExpirationCheck, + ActionTypeUserExpirationCheck} + for _, action := range r.Actions { + if slices.Contains(unavailableActions, action.Type) { + return fmt.Errorf("action %q, type %q is not supported for event trigger %q", + action.Name, getActionTypeAsString(action.Type), getTriggerTypeAsString(r.Trigger)) + } + } + return nil +} + +func (r *EventRule) checkProviderEventActions(providerObjectType string) error { + // user quota reset, transfer quota reset, data retention check and filesystem actions + // can be executed only if we modify a 
user. They will be executed for the + // affected user. Folder quota reset can be executed only for folders. + userSpecificActions := []int{ActionTypeUserQuotaReset, ActionTypeTransferQuotaReset, + ActionTypeDataRetentionCheck, ActionTypeFilesystem, + ActionTypePasswordExpirationCheck, ActionTypeUserExpirationCheck} + for _, action := range r.Actions { + if slices.Contains(userSpecificActions, action.Type) && providerObjectType != actionObjectUser { + return fmt.Errorf("action %q, type %q is only supported for provider user events", + action.Name, getActionTypeAsString(action.Type)) + } + if action.Type == ActionTypeFolderQuotaReset && providerObjectType != actionObjectFolder { + return fmt.Errorf("action %q, type %q is only supported for provider folder events", + action.Name, getActionTypeAsString(action.Type)) + } + } + return nil +} + +func (r *EventRule) hasUserAssociated(providerObjectType string) bool { + switch r.Trigger { + case EventTriggerProviderEvent: + return providerObjectType == actionObjectUser + case EventTriggerFsEvent: + return true + default: + if len(r.Actions) > 0 { + // should we allow schedules where backup is not the first action? 
+ // maybe we could pass the action index and check before that index + return r.Actions[0].Type == ActionTypeBackup + } + } + return false +} + +func (r *EventRule) checkActions(providerObjectType string) error { + numSyncAction := 0 + hasIDPAccountCheck := false + for _, action := range r.Actions { + if action.Options.ExecuteSync { + numSyncAction++ + } + if action.Type == ActionTypeEmail && action.BaseEventAction.Options.EmailConfig.hasFilesAttachments() { + if !r.hasUserAssociated(providerObjectType) { + return errors.New("cannot send an email with attachments for a rule with no user associated") + } + } + if action.Type == ActionTypeHTTP && action.BaseEventAction.Options.HTTPConfig.HasMultipartFiles() { + if !r.hasUserAssociated(providerObjectType) { + return errors.New("cannot upload file/s for a rule with no user associated") + } + } + if action.Type == ActionTypeIDPAccountCheck { + if r.Trigger != EventTriggerIDPLogin { + return errors.New("IDP account check action is only supported for IDP login trigger") + } + if !action.Options.ExecuteSync { + return errors.New("IDP account check must be a sync action") + } + hasIDPAccountCheck = true + } + } + if hasIDPAccountCheck && numSyncAction != 1 { + return errors.New("IDP account check must be the only sync action") + } + return nil +} + +// CheckActionsConsistency returns an error if the actions cannot be executed +func (r *EventRule) CheckActionsConsistency(providerObjectType string) error { + switch r.Trigger { + case EventTriggerProviderEvent: + if err := r.checkProviderEventActions(providerObjectType); err != nil { + return err + } + case EventTriggerFsEvent: + // folder quota reset cannot be executed + for _, action := range r.Actions { + if action.Type == ActionTypeFolderQuotaReset { + return fmt.Errorf("action %q, type %q is not supported for filesystem events", + action.Name, getActionTypeAsString(action.Type)) + } + } + case EventTriggerIPBlocked, EventTriggerCertificate: + if err := 
r.checkIPBlockedAndCertificateActions(); err != nil { + return err + } + } + return r.checkActions(providerObjectType) +} + +// PrepareForRendering prepares an EventRule for rendering. +// It hides confidential data and set to nil the empty secrets +// so they are not serialized +func (r *EventRule) PrepareForRendering() { + for idx := range r.Actions { + r.Actions[idx].PrepareForRendering() + } +} + +// RenderAsJSON implements the renderer interface used within plugins +func (r *EventRule) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + rule, err := provider.eventRuleExists(r.Name) + if err != nil { + providerLog(logger.LevelError, "unable to reload event rule before rendering as json: %v", err) + return nil, err + } + rule.PrepareForRendering() + return json.Marshal(rule) + } + r.PrepareForRendering() + return json.Marshal(r) +} + +func cloneRenameConfigs(renames []RenameConfig) []RenameConfig { + res := make([]RenameConfig, 0, len(renames)) + for _, c := range renames { + res = append(res, RenameConfig{ + KeyValue: KeyValue{ + Key: c.Key, + Value: c.Value, + }, + UpdateModTime: c.UpdateModTime, + }) + } + return res +} + +func cloneKeyValues(keyVals []KeyValue) []KeyValue { + res := make([]KeyValue, 0, len(keyVals)) + for _, kv := range keyVals { + res = append(res, KeyValue{ + Key: kv.Key, + Value: kv.Value, + }) + } + return res +} + +func cloneConditionPatterns(patterns []ConditionPattern) []ConditionPattern { + res := make([]ConditionPattern, 0, len(patterns)) + for _, p := range patterns { + res = append(res, ConditionPattern{ + Pattern: p.Pattern, + InverseMatch: p.InverseMatch, + }) + } + return res +} + +func validateConditionPatterns(patterns []ConditionPattern) error { + for _, name := range patterns { + if err := name.validate(); err != nil { + return err + } + } + return nil +} + +// Task stores the state for a scheduled task +type Task struct { + Name string `json:"name"` + UpdateAt int64 `json:"updated_at"` + Version int64 
`json:"version"` +} diff --git a/internal/dataprovider/group.go b/internal/dataprovider/group.go new file mode 100644 index 00000000..fa09fda4 --- /dev/null +++ b/internal/dataprovider/group.go @@ -0,0 +1,250 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package dataprovider + +import ( + "encoding/json" + "fmt" + "path/filepath" + "strings" + + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +// GroupUserSettings defines the settings to apply to users +type GroupUserSettings struct { + sdk.BaseGroupUserSettings + // Filesystem configuration details + FsConfig vfs.Filesystem `json:"filesystem"` +} + +// Group defines an SFTPGo group. 
+// Groups are used to easily configure similar users +type Group struct { + sdk.BaseGroup + // settings to apply to users for whom this is a primary group + UserSettings GroupUserSettings `json:"user_settings,omitempty"` + // Mapping between virtual paths and virtual folders + VirtualFolders []vfs.VirtualFolder `json:"virtual_folders,omitempty"` +} + +// GetPermissions returns the permissions as list +func (g *Group) GetPermissions() []sdk.DirectoryPermissions { + result := make([]sdk.DirectoryPermissions, 0, len(g.UserSettings.Permissions)) + for k, v := range g.UserSettings.Permissions { + result = append(result, sdk.DirectoryPermissions{ + Path: k, + Permissions: v, + }) + } + return result +} + +// GetAllowedIPAsString returns the allowed IP as comma separated string +func (g *Group) GetAllowedIPAsString() string { + return strings.Join(g.UserSettings.Filters.AllowedIP, ",") +} + +// GetDeniedIPAsString returns the denied IP as comma separated string +func (g *Group) GetDeniedIPAsString() string { + return strings.Join(g.UserSettings.Filters.DeniedIP, ",") +} + +// HasExternalAuth returns true if the external authentication is globally enabled +// and it is not disabled for this group +func (g *Group) HasExternalAuth() bool { + if g.UserSettings.Filters.Hooks.ExternalAuthDisabled { + return false + } + if config.ExternalAuthHook != "" { + return true + } + return plugin.Handler.HasAuthenticators() +} + +// SetEmptySecretsIfNil sets the secrets to empty if nil +func (g *Group) SetEmptySecretsIfNil() { + g.UserSettings.FsConfig.SetEmptySecretsIfNil() + for idx := range g.VirtualFolders { + vfolder := &g.VirtualFolders[idx] + vfolder.FsConfig.SetEmptySecretsIfNil() + } +} + +// PrepareForRendering prepares a group for rendering. 
+// It hides confidential data and set to nil the empty secrets +// so they are not serialized +func (g *Group) PrepareForRendering() { + g.UserSettings.FsConfig.HideConfidentialData() + g.UserSettings.FsConfig.SetNilSecretsIfEmpty() + for idx := range g.VirtualFolders { + folder := &g.VirtualFolders[idx] + folder.PrepareForRendering() + } +} + +// RenderAsJSON implements the renderer interface used within plugins +func (g *Group) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + group, err := provider.groupExists(g.Name) + if err != nil { + providerLog(logger.LevelError, "unable to reload group before rendering as json: %v", err) + return nil, err + } + group.PrepareForRendering() + return json.Marshal(group) + } + g.PrepareForRendering() + return json.Marshal(g) +} + +// GetEncryptionAdditionalData returns the additional data to use for AEAD +func (g *Group) GetEncryptionAdditionalData() string { + return fmt.Sprintf("group_%v", g.Name) +} + +// HasRedactedSecret returns true if the user has a redacted secret +func (g *Group) hasRedactedSecret() bool { + for idx := range g.VirtualFolders { + folder := &g.VirtualFolders[idx] + if folder.HasRedactedSecret() { + return true + } + } + + return g.UserSettings.FsConfig.HasRedactedSecret() +} + +func (g *Group) applyNamingRules() { + g.Name = config.convertName(g.Name) + for idx := range g.VirtualFolders { + g.VirtualFolders[idx].Name = config.convertName(g.VirtualFolders[idx].Name) + } +} + +func (g *Group) validate() error { + g.SetEmptySecretsIfNil() + g.applyNamingRules() + if g.Name == "" { + return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired) + } + if !util.IsNameValid(g.Name) { + return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) + } + if config.NamingRules&1 == 0 && !usernameRegex.MatchString(g.Name) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("name %q is not valid, the following characters are allowed: 
a-zA-Z0-9-_.~", g.Name)), + util.I18nErrorInvalidName, + ) + } + if g.hasRedactedSecret() { + return util.NewValidationError("cannot save a group with a redacted secret") + } + vfolders, err := validateAssociatedVirtualFolders(g.VirtualFolders) + if err != nil { + return err + } + g.VirtualFolders = vfolders + return g.validateUserSettings() +} + +func (g *Group) validateUserSettings() error { + if g.UserSettings.HomeDir != "" { + g.UserSettings.HomeDir = filepath.Clean(g.UserSettings.HomeDir) + if !filepath.IsAbs(g.UserSettings.HomeDir) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("home_dir must be an absolute path, actual value: %v", g.UserSettings.HomeDir)), + util.I18nErrorInvalidHomeDir, + ) + } + } + if err := g.UserSettings.FsConfig.Validate(g.GetEncryptionAdditionalData()); err != nil { + return err + } + if g.UserSettings.TotalDataTransfer > 0 { + // if a total data transfer is defined we reset the separate upload and download limits + g.UserSettings.UploadDataTransfer = 0 + g.UserSettings.DownloadDataTransfer = 0 + } + if len(g.UserSettings.Permissions) > 0 { + permissions, err := validateUserPermissions(g.UserSettings.Permissions) + if err != nil { + return util.NewI18nError(err, util.I18nErrorGenericPermission) + } + g.UserSettings.Permissions = permissions + } + g.UserSettings.Filters.TLSCerts = nil + if err := validateBaseFilters(&g.UserSettings.Filters); err != nil { + return err + } + if !g.HasExternalAuth() { + g.UserSettings.Filters.ExternalAuthCacheTime = 0 + } + g.UserSettings.Filters.UserType = "" + return nil +} + +func (g *Group) getACopy() Group { + users := make([]string, len(g.Users)) + copy(users, g.Users) + admins := make([]string, len(g.Admins)) + copy(admins, g.Admins) + virtualFolders := make([]vfs.VirtualFolder, 0, len(g.VirtualFolders)) + for idx := range g.VirtualFolders { + vfolder := g.VirtualFolders[idx].GetACopy() + virtualFolders = append(virtualFolders, vfolder) + } + permissions := 
make(map[string][]string) + for k, v := range g.UserSettings.Permissions { + perms := make([]string, len(v)) + copy(perms, v) + permissions[k] = perms + } + + return Group{ + BaseGroup: sdk.BaseGroup{ + ID: g.ID, + Name: g.Name, + Description: g.Description, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, + Users: users, + Admins: admins, + }, + UserSettings: GroupUserSettings{ + BaseGroupUserSettings: sdk.BaseGroupUserSettings{ + HomeDir: g.UserSettings.HomeDir, + MaxSessions: g.UserSettings.MaxSessions, + QuotaSize: g.UserSettings.QuotaSize, + QuotaFiles: g.UserSettings.QuotaFiles, + Permissions: permissions, + UploadBandwidth: g.UserSettings.UploadBandwidth, + DownloadBandwidth: g.UserSettings.DownloadBandwidth, + UploadDataTransfer: g.UserSettings.UploadDataTransfer, + DownloadDataTransfer: g.UserSettings.DownloadDataTransfer, + TotalDataTransfer: g.UserSettings.TotalDataTransfer, + ExpiresIn: g.UserSettings.ExpiresIn, + Filters: copyBaseUserFilters(g.UserSettings.Filters), + }, + FsConfig: g.UserSettings.FsConfig.GetACopy(), + }, + VirtualFolders: virtualFolders, + } +} diff --git a/internal/dataprovider/iplist.go b/internal/dataprovider/iplist.go new file mode 100644 index 00000000..bc3b5b99 --- /dev/null +++ b/internal/dataprovider/iplist.go @@ -0,0 +1,498 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package dataprovider + +import ( + "encoding/json" + "fmt" + "net" + "net/netip" + "slices" + "strings" + "sync" + "sync/atomic" + + "github.com/yl2chen/cidranger" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + // maximum number of entries to match in memory + // if the list contains more elements than this limit a + // database query will be executed + ipListMemoryLimit = 15000 +) + +var ( + inMemoryLists map[IPListType]*IPList +) + +func init() { + inMemoryLists = map[IPListType]*IPList{} +} + +// IPListType is the enumerable for the supported IP list types +type IPListType int + +// AsString returns the string representation for the list type +func (t IPListType) AsString() string { + switch t { + case IPListTypeAllowList: + return "Allow list" + case IPListTypeDefender: + return "Defender" + case IPListTypeRateLimiterSafeList: + return "Rate limiters safe list" + default: + return "" + } +} + +// Supported IP list types +const ( + IPListTypeAllowList IPListType = iota + 1 + IPListTypeDefender + IPListTypeRateLimiterSafeList +) + +// Supported IP list modes +const ( + ListModeAllow = iota + 1 + ListModeDeny +) + +const ( + ipTypeV4 = iota + 1 + ipTypeV6 +) + +var ( + supportedIPListType = []IPListType{IPListTypeAllowList, IPListTypeDefender, IPListTypeRateLimiterSafeList} +) + +// CheckIPListType returns an error if the provided IP list type is not valid +func CheckIPListType(t IPListType) error { + if !slices.Contains(supportedIPListType, t) { + return util.NewValidationError(fmt.Sprintf("invalid list type %d", t)) + } + return nil +} + +// IPListEntry defines an entry for the IP addresses list +type IPListEntry struct { + IPOrNet string `json:"ipornet"` + Description string `json:"description,omitempty"` + Type IPListType `json:"type"` + Mode int `json:"mode"` + // Defines the protocols the entry applies to + // - 0 all the supported protocols + // - 1 SSH + // - 2 FTP + // - 4 WebDAV + // - 
8 HTTP + // Protocols can be combined + Protocols int `json:"protocols"` + First []byte `json:"first,omitempty"` + Last []byte `json:"last,omitempty"` + IPType int `json:"ip_type,omitempty"` + // Creation time as unix timestamp in milliseconds + CreatedAt int64 `json:"created_at"` + // last update time as unix timestamp in milliseconds + UpdatedAt int64 `json:"updated_at"` + // in multi node setups we mark the rule as deleted to be able to update the cache + DeletedAt int64 `json:"-"` +} + +// PrepareForRendering prepares an IP list entry for rendering. +// It hides internal fields +func (e *IPListEntry) PrepareForRendering() { + e.First = nil + e.Last = nil + e.IPType = 0 +} + +// HasProtocol returns true if the specified protocol is defined +func (e *IPListEntry) HasProtocol(proto string) bool { + switch proto { + case protocolSSH: + return e.Protocols&1 != 0 + case protocolFTP: + return e.Protocols&2 != 0 + case protocolWebDAV: + return e.Protocols&4 != 0 + case protocolHTTP: + return e.Protocols&8 != 0 + default: + return false + } +} + +// RenderAsJSON implements the renderer interface used within plugins +func (e *IPListEntry) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + entry, err := provider.ipListEntryExists(e.IPOrNet, e.Type) + if err != nil { + providerLog(logger.LevelError, "unable to reload IP list entry before rendering as json: %v", err) + return nil, err + } + entry.PrepareForRendering() + return json.Marshal(entry) + } + e.PrepareForRendering() + return json.Marshal(e) +} + +func (e *IPListEntry) getKey() string { + return fmt.Sprintf("%d_%s", e.Type, e.IPOrNet) +} + +func (e *IPListEntry) getName() string { + return e.Type.AsString() + "-" + e.IPOrNet +} + +func (e *IPListEntry) getFirst() netip.Addr { + if e.IPType == ipTypeV4 { + var a4 [4]byte + copy(a4[:], e.First) + return netip.AddrFrom4(a4) + } + var a16 [16]byte + copy(a16[:], e.First) + return netip.AddrFrom16(a16) +} + +func (e *IPListEntry) getLast() netip.Addr { + if 
e.IPType == ipTypeV4 { + var a4 [4]byte + copy(a4[:], e.Last) + return netip.AddrFrom4(a4) + } + var a16 [16]byte + copy(a16[:], e.Last) + return netip.AddrFrom16(a16) +} + +func (e *IPListEntry) checkProtocols() { + for _, proto := range ValidProtocols { + if !e.HasProtocol(proto) { + return + } + } + e.Protocols = 0 +} + +func (e *IPListEntry) validate() error { + if err := CheckIPListType(e.Type); err != nil { + return err + } + e.checkProtocols() + switch e.Type { + case IPListTypeDefender: + if e.Mode < ListModeAllow || e.Mode > ListModeDeny { + return util.NewValidationError(fmt.Sprintf("invalid list mode: %d", e.Mode)) + } + default: + if e.Mode != ListModeAllow { + return util.NewValidationError("invalid list mode") + } + } + e.PrepareForRendering() + if !strings.Contains(e.IPOrNet, "/") { + // parse as IP + parsed, err := netip.ParseAddr(e.IPOrNet) + if err != nil { + return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid IP %q", e.IPOrNet)), util.I18nErrorIPInvalid) + } + if parsed.Is4() { + e.IPOrNet += "/32" + } else if parsed.Is4In6() { + e.IPOrNet = netip.AddrFrom4(parsed.As4()).String() + "/32" + } else { + e.IPOrNet += "/128" + } + } + prefix, err := netip.ParsePrefix(e.IPOrNet) + if err != nil { + return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid network %q: %v", e.IPOrNet, err)), util.I18nErrorNetInvalid) + } + prefix = prefix.Masked() + if prefix.Addr().Is4In6() { + e.IPOrNet = fmt.Sprintf("%s/%d", netip.AddrFrom4(prefix.Addr().As4()).String(), prefix.Bits()-96) + } + // TODO: to remove when the in memory ranger switch to netip + _, _, err = net.ParseCIDR(e.IPOrNet) + if err != nil { + return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid network: %v", err)), util.I18nErrorNetInvalid) + } + if prefix.Addr().Is4() || prefix.Addr().Is4In6() { + e.IPType = ipTypeV4 + first := prefix.Addr().As4() + last := util.GetLastIPForPrefix(prefix).As4() + e.First = first[:] + e.Last = last[:] + } else { + 
e.IPType = ipTypeV6 + first := prefix.Addr().As16() + last := util.GetLastIPForPrefix(prefix).As16() + e.First = first[:] + e.Last = last[:] + } + return nil +} + +func (e *IPListEntry) getACopy() IPListEntry { + first := make([]byte, len(e.First)) + copy(first, e.First) + last := make([]byte, len(e.Last)) + copy(last, e.Last) + + return IPListEntry{ + IPOrNet: e.IPOrNet, + Description: e.Description, + Type: e.Type, + Mode: e.Mode, + First: first, + Last: last, + IPType: e.IPType, + Protocols: e.Protocols, + CreatedAt: e.CreatedAt, + UpdatedAt: e.UpdatedAt, + DeletedAt: e.DeletedAt, + } +} + +// getAsRangerEntry returns the entry as cidranger.RangerEntry +func (e *IPListEntry) getAsRangerEntry() (cidranger.RangerEntry, error) { + _, network, err := net.ParseCIDR(e.IPOrNet) + if err != nil { + return nil, err + } + entry := e.getACopy() + return &rangerEntry{ + entry: &entry, + network: *network, + }, nil +} + +func (e IPListEntry) satisfySearchConstraints(filter, from, order string) bool { + if filter != "" && !strings.HasPrefix(e.IPOrNet, filter) { + return false + } + if from != "" { + if order == OrderASC { + return e.IPOrNet > from + } + return e.IPOrNet < from + } + return true +} + +type rangerEntry struct { + entry *IPListEntry + network net.IPNet +} + +func (e *rangerEntry) Network() net.IPNet { + return e.network +} + +// IPList defines an IP list +type IPList struct { + isInMemory atomic.Bool + listType IPListType + mu sync.RWMutex + Ranges cidranger.Ranger +} + +func (l *IPList) addEntry(e *IPListEntry) { + if l.listType != e.Type { + return + } + if !l.isInMemory.Load() { + return + } + entry, err := e.getAsRangerEntry() + if err != nil { + providerLog(logger.LevelError, "unable to get entry to add %q for list type %d, disabling memory mode, err: %v", + e.IPOrNet, l.listType, err) + l.isInMemory.Store(false) + return + } + l.mu.Lock() + defer l.mu.Unlock() + + if err := l.Ranges.Insert(entry); err != nil { + providerLog(logger.LevelError, "unable to 
add entry %q for list type %d, disabling memory mode, err: %v", + e.IPOrNet, l.listType, err) + l.isInMemory.Store(false) + return + } + if l.Ranges.Len() >= ipListMemoryLimit { + providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling memory mode", l.listType) + l.isInMemory.Store(false) + } +} + +func (l *IPList) removeEntry(e *IPListEntry) { + if l.listType != e.Type { + return + } + if !l.isInMemory.Load() { + return + } + entry, err := e.getAsRangerEntry() + if err != nil { + providerLog(logger.LevelError, "unable to get entry to remove %q for list type %d, disabling memory mode, err: %v", + e.IPOrNet, l.listType, err) + l.isInMemory.Store(false) + return + } + l.mu.Lock() + defer l.mu.Unlock() + + if _, err := l.Ranges.Remove(entry.Network()); err != nil { + providerLog(logger.LevelError, "unable to remove entry %q for list type %d, disabling memory mode, err: %v", + e.IPOrNet, l.listType, err) + l.isInMemory.Store(false) + } +} + +func (l *IPList) updateEntry(e *IPListEntry) { + if l.listType != e.Type { + return + } + if !l.isInMemory.Load() { + return + } + entry, err := e.getAsRangerEntry() + if err != nil { + providerLog(logger.LevelError, "unable to get entry to update %q for list type %d, disabling memory mode, err: %v", + e.IPOrNet, l.listType, err) + l.isInMemory.Store(false) + return + } + l.mu.Lock() + defer l.mu.Unlock() + + if _, err := l.Ranges.Remove(entry.Network()); err != nil { + providerLog(logger.LevelError, "unable to remove entry to update %q for list type %d, disabling memory mode, err: %v", + e.IPOrNet, l.listType, err) + l.isInMemory.Store(false) + return + } + if err := l.Ranges.Insert(entry); err != nil { + providerLog(logger.LevelError, "unable to add entry to update %q for list type %d, disabling memory mode, err: %v", + e.IPOrNet, l.listType, err) + l.isInMemory.Store(false) + } + if l.Ranges.Len() >= ipListMemoryLimit { + providerLog(logger.LevelError, "memory limit exceeded for list type %d, disabling 
memory mode", l.listType) + l.isInMemory.Store(false) + } +} + +// DisableMemoryMode disables memory mode forcing database queries +func (l *IPList) DisableMemoryMode() { + l.isInMemory.Store(false) +} + +// IsListed checks if there is a match for the specified IP and protocol. +// If there are multiple matches, the first one is returned, in no particular order, +// so the behavior is undefined +func (l *IPList) IsListed(ip, protocol string) (bool, int, error) { + if l.isInMemory.Load() { + l.mu.RLock() + defer l.mu.RUnlock() + + if l.Ranges.Len() == 0 { + return false, 0, nil + } + + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + return false, 0, fmt.Errorf("invalid IP %s", ip) + } + + entries, err := l.Ranges.ContainingNetworks(parsedIP) + if err != nil { + return false, 0, fmt.Errorf("unable to find containing networks for ip %q: %w", ip, err) + } + for _, e := range entries { + entry, ok := e.(*rangerEntry) + if ok { + if entry.entry.Protocols == 0 || entry.entry.HasProtocol(protocol) { + return true, entry.entry.Mode, nil + } + } + } + + return false, 0, nil + } + + entries, err := provider.getListEntriesForIP(ip, l.listType) + if err != nil { + return false, 0, err + } + for _, e := range entries { + if e.Protocols == 0 || e.HasProtocol(protocol) { + return true, e.Mode, nil + } + } + + return false, 0, nil +} + +// NewIPList returns a new IP list for the specified type +func NewIPList(listType IPListType) (*IPList, error) { + delete(inMemoryLists, listType) + count, err := provider.countIPListEntries(listType) + if err != nil { + return nil, err + } + if count < ipListMemoryLimit { + providerLog(logger.LevelInfo, "using in-memory matching for list type %d, num entries: %d", listType, count) + entries, err := provider.getIPListEntries(listType, "", "", OrderASC, 0) + if err != nil { + return nil, err + } + ipList := &IPList{ + listType: listType, + Ranges: cidranger.NewPCTrieRanger(), + } + for idx := range entries { + e := entries[idx] + entry, err := 
e.getAsRangerEntry() + if err != nil { + return nil, fmt.Errorf("unable to get ranger for entry %q: %w", e.IPOrNet, err) + } + if err := ipList.Ranges.Insert(entry); err != nil { + return nil, fmt.Errorf("unable to add ranger for entry %q: %w", e.IPOrNet, err) + } + } + ipList.isInMemory.Store(true) + inMemoryLists[listType] = ipList + + return ipList, nil + } + providerLog(logger.LevelInfo, "list type %d has %d entries, in-memory matching disabled", listType, count) + ipList := &IPList{ + listType: listType, + Ranges: nil, + } + ipList.isInMemory.Store(false) + return ipList, nil +} diff --git a/internal/dataprovider/memory.go b/internal/dataprovider/memory.go new file mode 100644 index 00000000..e8595291 --- /dev/null +++ b/internal/dataprovider/memory.go @@ -0,0 +1,3345 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package dataprovider + +import ( + "bytes" + "crypto/x509" + "errors" + "fmt" + "net/netip" + "os" + "path/filepath" + "slices" + "sort" + "strconv" + "sync" + "time" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +var ( + errMemoryProviderClosed = errors.New("memory provider is closed") +) + +type memoryProviderHandle struct { + // configuration file to use for loading users + configFile string + sync.Mutex + isClosed bool + // slice with ordered usernames + usernames []string + // map for users, username is the key + users map[string]User + // slice with ordered group names + groupnames []string + // map for group, group name is the key + groups map[string]Group + // map for virtual folders, folder name is the key + vfolders map[string]vfs.BaseVirtualFolder + // slice with ordered folder names + vfoldersNames []string + // map for admins, username is the key + admins map[string]Admin + // slice with ordered admins + adminsUsernames []string + // map for API keys, keyID is the key + apiKeys map[string]APIKey + // slice with ordered API keys KeyID + apiKeysIDs []string + // map for shares, shareID is the key + shares map[string]Share + // slice with ordered shares shareID + sharesIDs []string + // map for event actions, name is the key + actions map[string]BaseEventAction + // slice with ordered actions + actionsNames []string + // map for event actions, name is the key + rules map[string]EventRule + // slice with ordered rules + rulesNames []string + // map for roles, name is the key + roles map[string]Role + // slice with ordered roles + roleNames []string + // map for IP List entry + ipListEntries map[string]IPListEntry + // slice with ordered IP list entries + ipListEntriesKeys []string + // configurations + configs Configs +} + +// MemoryProvider defines the auth provider for a memory store +type MemoryProvider struct { + dbHandle *memoryProviderHandle +} + 
+func initializeMemoryProvider(basePath string) error { + configFile := "" + if util.IsFileInputValid(config.Name) { + configFile = config.Name + if !filepath.IsAbs(configFile) { + configFile = filepath.Join(basePath, configFile) + } + } + provider = &MemoryProvider{ + dbHandle: &memoryProviderHandle{ + isClosed: false, + usernames: []string{}, + users: make(map[string]User), + groupnames: []string{}, + groups: make(map[string]Group), + vfolders: make(map[string]vfs.BaseVirtualFolder), + vfoldersNames: []string{}, + admins: make(map[string]Admin), + adminsUsernames: []string{}, + apiKeys: make(map[string]APIKey), + apiKeysIDs: []string{}, + shares: make(map[string]Share), + sharesIDs: []string{}, + actions: make(map[string]BaseEventAction), + actionsNames: []string{}, + rules: make(map[string]EventRule), + rulesNames: []string{}, + roles: map[string]Role{}, + roleNames: []string{}, + ipListEntries: map[string]IPListEntry{}, + ipListEntriesKeys: []string{}, + configs: Configs{}, + configFile: configFile, + }, + } + return provider.reloadConfig() +} + +func (p *MemoryProvider) checkAvailability() error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + return nil +} + +func (p *MemoryProvider) close() error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + p.dbHandle.isClosed = true + return nil +} + +func (p *MemoryProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) { + var user User + if tlsCert == nil { + return user, errors.New("TLS certificate cannot be null or empty") + } + user, err := p.userExists(username, "") + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) + return user, err + } + return checkUserAndTLSCertificate(&user, protocol, tlsCert) +} + +func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol 
string) (User, error) { + user, err := p.userExists(username, "") + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) + return user, err + } + return checkUserAndPass(&user, password, ip, protocol) +} + +func (p *MemoryProvider) validateUserAndPubKey(username string, pubKey []byte, isSSHCert bool) (User, string, error) { + var user User + if len(pubKey) == 0 { + return user, "", errors.New("credentials cannot be null or empty") + } + user, err := p.userExists(username, "") + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) + return user, "", err + } + return checkUserAndPubKey(&user, pubKey, isSSHCert) +} + +func (p *MemoryProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { + admin, err := p.adminExists(username) + if err != nil { + providerLog(logger.LevelWarn, "error authenticating admin %q: %v", username, err) + return admin, err + } + err = admin.checkUserAndPass(password, ip) + return admin, err +} + +func (p *MemoryProvider) updateAPIKeyLastUse(keyID string) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + apiKey, err := p.apiKeyExistsInternal(keyID) + if err != nil { + return err + } + apiKey.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.apiKeys[apiKey.KeyID] = apiKey + return nil +} + +func (p *MemoryProvider) getAdminSignature(username string) (string, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return "", errMemoryProviderClosed + } + admin, err := p.adminExistsInternal(username) + if err != nil { + return "", err + } + return strconv.FormatInt(admin.UpdatedAt, 10), nil +} + +func (p *MemoryProvider) getUserSignature(username string) (string, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return "", errMemoryProviderClosed + } + user, err := 
p.userExistsInternal(username) + if err != nil { + return "", err + } + return strconv.FormatInt(user.UpdatedAt, 10), nil +} + +func (p *MemoryProvider) setUpdatedAt(username string) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return + } + user, err := p.userExistsInternal(username) + if err != nil { + return + } + user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.users[user.Username] = user + setLastUserUpdate() +} + +func (p *MemoryProvider) updateLastLogin(username string) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + return err + } + user.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.users[user.Username] = user + return nil +} + +func (p *MemoryProvider) updateAdminLastLogin(username string) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + admin, err := p.adminExistsInternal(username) + if err != nil { + return err + } + admin.LastLogin = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.admins[admin.Username] = admin + return nil +} + +func (p *MemoryProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + providerLog(logger.LevelError, "unable to update transfer quota for user %q error: %v", username, err) + return err + } + if reset { + user.UsedUploadDataTransfer = uploadSize + user.UsedDownloadDataTransfer = downloadSize + } else { + user.UsedUploadDataTransfer += uploadSize + user.UsedDownloadDataTransfer += downloadSize + } + user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) + providerLog(logger.LevelDebug, "transfer quota updated for user %q, 
ul increment: %v dl increment: %v is reset? %v", + username, uploadSize, downloadSize, reset) + p.dbHandle.users[user.Username] = user + return nil +} + +func (p *MemoryProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + providerLog(logger.LevelError, "unable to update quota for user %q error: %v", username, err) + return err + } + if reset { + user.UsedQuotaSize = sizeAdd + user.UsedQuotaFiles = filesAdd + } else { + user.UsedQuotaSize += sizeAdd + user.UsedQuotaFiles += filesAdd + } + user.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) + providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %v size increment: %v is reset? %v", + username, filesAdd, sizeAdd, reset) + p.dbHandle.users[user.Username] = user + return nil +} + +func (p *MemoryProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return 0, 0, 0, 0, errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + providerLog(logger.LevelError, "unable to get quota for user %q error: %v", username, err) + return 0, 0, 0, 0, err + } + return user.UsedQuotaFiles, user.UsedQuotaSize, user.UsedUploadDataTransfer, user.UsedDownloadDataTransfer, err +} + +func (p *MemoryProvider) addUser(user *User) error { + err := ValidateUser(user) + if err != nil { + return err + } + + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + + _, err = p.userExistsInternal(user.Username) + if err == nil { + return util.NewI18nError( + fmt.Errorf("%w: username %v already exists", ErrDuplicatedKey, user.Username), + util.I18nErrorDuplicatedUsername, + ) + } + user.ID = p.getNextID() + 
user.LastQuotaUpdate = 0 + user.UsedQuotaSize = 0 + user.UsedQuotaFiles = 0 + user.UsedUploadDataTransfer = 0 + user.UsedDownloadDataTransfer = 0 + user.LastLogin = 0 + user.FirstUpload = 0 + user.FirstDownload = 0 + user.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + if err := p.addUserToRole(user.Username, user.Role); err != nil { + return err + } + var mappedGroups []string + for idx := range user.Groups { + if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil { + // try to remove group mapping + for _, g := range mappedGroups { + p.removeUserFromGroupMapping(user.Username, g) + } + return err + } + mappedGroups = append(mappedGroups, user.Groups[idx].Name) + } + var mappedFolders []string + for idx := range user.VirtualFolders { + if err = p.addUserToFolderMapping(user.Username, user.VirtualFolders[idx].Name); err != nil { + // try to remove folder mapping + for _, f := range mappedFolders { + p.removeRelationFromFolderMapping(f, user.Username, "") + } + return err + } + mappedFolders = append(mappedFolders, user.VirtualFolders[idx].Name) + } + p.dbHandle.users[user.Username] = user.getACopy() + p.dbHandle.usernames = append(p.dbHandle.usernames, user.Username) + sort.Strings(p.dbHandle.usernames) + return nil +} + +func (p *MemoryProvider) updateUser(user *User) error { //nolint:gocyclo + err := ValidateUser(user) + if err != nil { + return err + } + + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + + u, err := p.userExistsInternal(user.Username) + if err != nil { + return err + } + p.removeUserFromRole(u.Username, u.Role) + if err := p.addUserToRole(user.Username, user.Role); err != nil { + // try ro add old role + if errRollback := p.addUserToRole(u.Username, u.Role); errRollback != nil { + providerLog(logger.LevelError, "unable to rollback old role %q for user %q, error: %v", + u.Role, u.Username, 
errRollback) + } + return err + } + for idx := range u.Groups { + p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name) + } + for idx := range user.Groups { + if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil { + // try to add old mapping + for _, g := range u.Groups { + if errRollback := p.addUserToGroupMapping(user.Username, g.Name); errRollback != nil { + providerLog(logger.LevelError, "unable to rollback old group mapping %q for user %q, error: %v", + g.Name, user.Username, errRollback) + } + } + return err + } + } + for _, oldFolder := range u.VirtualFolders { + p.removeRelationFromFolderMapping(oldFolder.Name, u.Username, "") + } + for idx := range user.VirtualFolders { + if err = p.addUserToFolderMapping(user.Username, user.VirtualFolders[idx].Name); err != nil { + // try to add old mapping + for _, f := range u.VirtualFolders { + if errRollback := p.addUserToFolderMapping(user.Username, f.Name); errRollback != nil { + providerLog(logger.LevelError, "unable to rollback old folder mapping %q for user %q, error: %v", + f.Name, user.Username, errRollback) + } + } + return err + } + } + user.LastQuotaUpdate = u.LastQuotaUpdate + user.UsedQuotaSize = u.UsedQuotaSize + user.UsedQuotaFiles = u.UsedQuotaFiles + user.UsedUploadDataTransfer = u.UsedUploadDataTransfer + user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer + user.LastLogin = u.LastLogin + user.FirstDownload = u.FirstDownload + user.FirstUpload = u.FirstUpload + user.CreatedAt = u.CreatedAt + user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + user.ID = u.ID + // pre-login and external auth hook will use the passed *user so save a copy + p.dbHandle.users[user.Username] = user.getACopy() + setLastUserUpdate() + return nil +} + +func (p *MemoryProvider) deleteUser(user User, _ bool) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + u, err := p.userExistsInternal(user.Username) + if 
err != nil { + return err + } + p.removeUserFromRole(u.Username, u.Role) + for _, oldFolder := range u.VirtualFolders { + p.removeRelationFromFolderMapping(oldFolder.Name, u.Username, "") + } + for idx := range u.Groups { + p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name) + } + delete(p.dbHandle.users, user.Username) + // this could be more efficient + p.dbHandle.usernames = make([]string, 0, len(p.dbHandle.users)) + for username := range p.dbHandle.users { + p.dbHandle.usernames = append(p.dbHandle.usernames, username) + } + sort.Strings(p.dbHandle.usernames) + p.deleteAPIKeysWithUser(user.Username) + p.deleteSharesWithUser(user.Username) + return nil +} + +func (p *MemoryProvider) updateUserPassword(username, password string) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + + user, err := p.userExistsInternal(username) + if err != nil { + return err + } + user.Password = password + user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.users[username] = user + return nil +} + +func (p *MemoryProvider) dumpUsers() ([]User, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + users := make([]User, 0, len(p.dbHandle.usernames)) + var err error + if p.dbHandle.isClosed { + return users, errMemoryProviderClosed + } + for _, username := range p.dbHandle.usernames { + u := p.dbHandle.users[username] + user := u.getACopy() + p.addVirtualFoldersToUser(&user) + users = append(users, user) + } + return users, err +} + +func (p *MemoryProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + folders := make([]vfs.BaseVirtualFolder, 0, len(p.dbHandle.vfoldersNames)) + if p.dbHandle.isClosed { + return folders, errMemoryProviderClosed + } + for _, f := range p.dbHandle.vfolders { + folders = append(folders, f) + } + return folders, nil +} + +func (p *MemoryProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) { + 
if getLastUserUpdate() < after { + return nil, nil + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + users := make([]User, 0, 10) + for _, username := range p.dbHandle.usernames { + u := p.dbHandle.users[username] + if u.UpdatedAt < after { + continue + } + user := u.getACopy() + p.addVirtualFoldersToUser(&user) + if len(user.Groups) > 0 { + groupMapping := make(map[string]Group) + for idx := range user.Groups { + group, err := p.groupExistsInternal(user.Groups[idx].Name) + if err != nil { + continue + } + groupMapping[group.Name] = group + } + user.applyGroupSettings(groupMapping) + } + + user.SetEmptySecretsIfNil() + users = append(users, user) + } + + return users, nil +} + +func (p *MemoryProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { + users := make([]User, 0, 30) + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return users, errMemoryProviderClosed + } + for _, username := range p.dbHandle.usernames { + if needFolders, ok := toFetch[username]; ok { + u := p.dbHandle.users[username] + user := u.getACopy() + if needFolders { + p.addVirtualFoldersToUser(&user) + } + if len(user.Groups) > 0 { + groupMapping := make(map[string]Group) + for idx := range user.Groups { + group, err := p.groupExistsInternal(user.Groups[idx].Name) + if err != nil { + continue + } + groupMapping[group.Name] = group + } + user.applyGroupSettings(groupMapping) + } + user.SetEmptySecretsIfNil() + user.PrepareForRendering() + users = append(users, user) + } + } + + return users, nil +} + +func (p *MemoryProvider) getUsers(limit int, offset int, order, role string) ([]User, error) { + users := make([]User, 0, limit) + var err error + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return users, errMemoryProviderClosed + } + if limit <= 0 { + return users, err + } + itNum := 0 + if order == OrderASC { + for _, username := range 
p.dbHandle.usernames { + itNum++ + if itNum <= offset { + continue + } + u := p.dbHandle.users[username] + user := u.getACopy() + if !user.hasRole(role) { + continue + } + p.addVirtualFoldersToUser(&user) + user.PrepareForRendering() + users = append(users, user) + if len(users) >= limit { + break + } + } + } else { + for i := len(p.dbHandle.usernames) - 1; i >= 0; i-- { + itNum++ + if itNum <= offset { + continue + } + username := p.dbHandle.usernames[i] + u := p.dbHandle.users[username] + user := u.getACopy() + if !user.hasRole(role) { + continue + } + p.addVirtualFoldersToUser(&user) + user.PrepareForRendering() + users = append(users, user) + if len(users) >= limit { + break + } + } + } + return users, err +} + +func (p *MemoryProvider) userExists(username, role string) (User, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return User{}, errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + return user, err + } + if !user.hasRole(role) { + return User{}, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) + } + p.addVirtualFoldersToUser(&user) + return user, nil +} + +func (p *MemoryProvider) userExistsInternal(username string) (User, error) { + if val, ok := p.dbHandle.users[username]; ok { + return val.getACopy(), nil + } + return User{}, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username)) +} + +func (p *MemoryProvider) groupExistsInternal(name string) (Group, error) { + if val, ok := p.dbHandle.groups[name]; ok { + return val.getACopy(), nil + } + return Group{}, util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", name)) +} + +func (p *MemoryProvider) actionExistsInternal(name string) (BaseEventAction, error) { + if val, ok := p.dbHandle.actions[name]; ok { + return val.getACopy(), nil + } + return BaseEventAction{}, util.NewRecordNotFoundError(fmt.Sprintf("event action %q does not exist", name)) +} + +func 
(p *MemoryProvider) ruleExistsInternal(name string) (EventRule, error) { + if val, ok := p.dbHandle.rules[name]; ok { + return val.getACopy(), nil + } + return EventRule{}, util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", name)) +} + +func (p *MemoryProvider) roleExistsInternal(name string) (Role, error) { + if val, ok := p.dbHandle.roles[name]; ok { + return val.getACopy(), nil + } + return Role{}, util.NewRecordNotFoundError(fmt.Sprintf("role %q does not exist", name)) +} + +func (p *MemoryProvider) ipListEntryExistsInternal(entry *IPListEntry) (IPListEntry, error) { + if val, ok := p.dbHandle.ipListEntries[entry.getKey()]; ok { + return val.getACopy(), nil + } + return IPListEntry{}, util.NewRecordNotFoundError(fmt.Sprintf("IP list entry %q does not exist", entry.getName())) +} + +func (p *MemoryProvider) addAdmin(admin *Admin) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + err := admin.validate() + if err != nil { + return err + } + _, err = p.adminExistsInternal(admin.Username) + if err == nil { + return util.NewI18nError( + fmt.Errorf("%w: admin %q already exists", ErrDuplicatedKey, admin.Username), + util.I18nErrorDuplicatedUsername, + ) + } + admin.ID = p.getNextAdminID() + admin.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + admin.LastLogin = 0 + if err := p.addAdminToRole(admin.Username, admin.Role); err != nil { + return err + } + var mappedAdmins []string + for idx := range admin.Groups { + if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil { + // try to remove group mapping + for _, g := range mappedAdmins { + p.removeAdminFromGroupMapping(admin.Username, g) + } + return err + } + mappedAdmins = append(mappedAdmins, admin.Groups[idx].Name) + } + p.dbHandle.admins[admin.Username] = admin.getACopy() + p.dbHandle.adminsUsernames = 
append(p.dbHandle.adminsUsernames, admin.Username) + sort.Strings(p.dbHandle.adminsUsernames) + return nil +} + +func (p *MemoryProvider) updateAdmin(admin *Admin) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + err := admin.validate() + if err != nil { + return err + } + a, err := p.adminExistsInternal(admin.Username) + if err != nil { + return err + } + p.removeAdminFromRole(a.Username, a.Role) + if err := p.addAdminToRole(admin.Username, admin.Role); err != nil { + // try ro add old role + if errRollback := p.addAdminToRole(a.Username, a.Role); errRollback != nil { + providerLog(logger.LevelError, "unable to rollback old role %q for admin %q, error: %v", + a.Role, a.Username, errRollback) + } + return err + } + for idx := range a.Groups { + p.removeAdminFromGroupMapping(a.Username, a.Groups[idx].Name) + } + for idx := range admin.Groups { + if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil { + // try to add old mapping + for _, oldGroup := range a.Groups { + if errRollback := p.addAdminToGroupMapping(a.Username, oldGroup.Name); errRollback != nil { + providerLog(logger.LevelError, "unable to rollback old group mapping %q for admin %q, error: %v", + oldGroup.Name, a.Username, errRollback) + } + } + return err + } + } + admin.ID = a.ID + admin.CreatedAt = a.CreatedAt + admin.LastLogin = a.LastLogin + admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.admins[admin.Username] = admin.getACopy() + return nil +} + +func (p *MemoryProvider) deleteAdmin(admin Admin) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + a, err := p.adminExistsInternal(admin.Username) + if err != nil { + return err + } + p.removeAdminFromRole(a.Username, a.Role) + for idx := range a.Groups { + p.removeAdminFromGroupMapping(a.Username, a.Groups[idx].Name) + } + + delete(p.dbHandle.admins, 
admin.Username) + // this could be more efficient + p.dbHandle.adminsUsernames = make([]string, 0, len(p.dbHandle.admins)) + for username := range p.dbHandle.admins { + p.dbHandle.adminsUsernames = append(p.dbHandle.adminsUsernames, username) + } + sort.Strings(p.dbHandle.adminsUsernames) + p.deleteAPIKeysWithAdmin(admin.Username) + return nil +} + +func (p *MemoryProvider) adminExists(username string) (Admin, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return Admin{}, errMemoryProviderClosed + } + return p.adminExistsInternal(username) +} + +func (p *MemoryProvider) adminExistsInternal(username string) (Admin, error) { + if val, ok := p.dbHandle.admins[username]; ok { + return val.getACopy(), nil + } + return Admin{}, util.NewRecordNotFoundError(fmt.Sprintf("admin %q does not exist", username)) +} + +func (p *MemoryProvider) dumpAdmins() ([]Admin, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + + admins := make([]Admin, 0, len(p.dbHandle.admins)) + if p.dbHandle.isClosed { + return admins, errMemoryProviderClosed + } + for _, admin := range p.dbHandle.admins { + admins = append(admins, admin) + } + return admins, nil +} + +func (p *MemoryProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) { + admins := make([]Admin, 0, limit) + + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + + if p.dbHandle.isClosed { + return admins, errMemoryProviderClosed + } + if limit <= 0 { + return admins, nil + } + itNum := 0 + if order == OrderASC { + for _, username := range p.dbHandle.adminsUsernames { + itNum++ + if itNum <= offset { + continue + } + a := p.dbHandle.admins[username] + admin := a.getACopy() + admin.HideConfidentialData() + admins = append(admins, admin) + if len(admins) >= limit { + break + } + } + } else { + for i := len(p.dbHandle.adminsUsernames) - 1; i >= 0; i-- { + itNum++ + if itNum <= offset { + continue + } + username := p.dbHandle.adminsUsernames[i] + a := p.dbHandle.admins[username] + 
admin := a.getACopy()
			admin.HideConfidentialData()
			admins = append(admins, admin)
			if len(admins) >= limit {
				break
			}
		}
	}

	return admins, nil
}

// updateFolderQuota sets (reset == true) or increments (reset == false) the
// used quota counters for the named folder and refreshes LastQuotaUpdate.
func (p *MemoryProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	folder, err := p.folderExistsInternal(name)
	if err != nil {
		providerLog(logger.LevelError, "unable to update quota for folder %q error: %v", name, err)
		return err
	}
	if reset {
		folder.UsedQuotaSize = sizeAdd
		folder.UsedQuotaFiles = filesAdd
	} else {
		folder.UsedQuotaSize += sizeAdd
		folder.UsedQuotaFiles += filesAdd
	}
	folder.LastQuotaUpdate = util.GetTimeAsMsSinceEpoch(time.Now())
	p.dbHandle.vfolders[name] = folder
	return nil
}

// getGroups returns up to limit groups starting at offset, ordered by name
// according to order, with virtual folders resolved.
func (p *MemoryProvider) getGroups(limit, offset int, order string, _ bool) ([]Group, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return nil, errMemoryProviderClosed
	}
	if limit <= 0 {
		return nil, nil
	}
	groups := make([]Group, 0, limit)
	itNum := 0
	if order == OrderASC {
		for _, name := range p.dbHandle.groupnames {
			itNum++
			if itNum <= offset {
				continue
			}
			g := p.dbHandle.groups[name]
			group := g.getACopy()
			p.addVirtualFoldersToGroup(&group)
			group.PrepareForRendering()
			groups = append(groups, group)
			if len(groups) >= limit {
				break
			}
		}
	} else {
		// descending order: walk the sorted name index backwards
		for i := len(p.dbHandle.groupnames) - 1; i >= 0; i-- {
			itNum++
			if itNum <= offset {
				continue
			}
			name := p.dbHandle.groupnames[i]
			g := p.dbHandle.groups[name]
			group := g.getACopy()
			p.addVirtualFoldersToGroup(&group)
			group.PrepareForRendering()
			groups = append(groups, group)
			if len(groups) >= limit {
				break
			}
		}
	}
	return groups, nil
}

// getGroupsWithNames returns copies of the groups matching the given names;
// unknown names are silently skipped.
func (p *MemoryProvider) getGroupsWithNames(names []string) ([]Group, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return nil, errMemoryProviderClosed
	}
	groups := make([]Group, 0, len(names))
	for _, name := range names {
		if val, ok := p.dbHandle.groups[name]; ok {
			group := val.getACopy()
			p.addVirtualFoldersToGroup(&group)
			groups = append(groups, group)
		}
	}

	return groups, nil
}

// getUsersInGroups returns the usernames referenced by the given groups;
// unknown group names are silently skipped.
func (p *MemoryProvider) getUsersInGroups(names []string) ([]string, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return nil, errMemoryProviderClosed
	}
	var users []string
	for _, name := range names {
		if val, ok := p.dbHandle.groups[name]; ok {
			group := val.getACopy()
			users = append(users, group.Users...)
		}
	}

	return users, nil
}

// groupExists returns a copy of the named group with virtual folders resolved.
func (p *MemoryProvider) groupExists(name string) (Group, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return Group{}, errMemoryProviderClosed
	}
	group, err := p.groupExistsInternal(name)
	if err != nil {
		return group, err
	}
	p.addVirtualFoldersToGroup(&group)
	return group, nil
}

// addGroup validates and stores a new group and registers it in the folder
// mappings of its virtual folders, rolling the mappings back on failure.
func (p *MemoryProvider) addGroup(group *Group) error {
	if err := group.validate(); err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}

	_, err := p.groupExistsInternal(group.Name)
	if err == nil {
		return util.NewI18nError(
			fmt.Errorf("%w: group %q already exists", ErrDuplicatedKey, group.Name),
			util.I18nErrorDuplicatedUsername,
		)
	}
	group.ID = p.getNextGroupID()
	group.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	// membership lists are derived from the other side of the relation,
	// never taken from the input
	group.Users = nil
	group.Admins = nil
	var mappedFolders []string
	for idx := range group.VirtualFolders {
		if err = p.addGroupToFolderMapping(group.Name, group.VirtualFolders[idx].Name); err != nil {
			// try to remove folder mapping
			for _, f := range mappedFolders {
				p.removeRelationFromFolderMapping(f, "", group.Name)
			}
			return err
		}
		mappedFolders =
append(mappedFolders, group.VirtualFolders[idx].Name)
	}
	p.dbHandle.groups[group.Name] = group.getACopy()
	p.dbHandle.groupnames = append(p.dbHandle.groupnames, group.Name)
	sort.Strings(p.dbHandle.groupnames)
	return nil
}

// updateGroup replaces an existing group, preserving its ID, creation time and
// user/admin membership, and re-registers folder mappings, rolling back to the
// old mappings on failure.
func (p *MemoryProvider) updateGroup(group *Group) error {
	if err := group.validate(); err != nil {
		return err
	}
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	g, err := p.groupExistsInternal(group.Name)
	if err != nil {
		return err
	}
	for _, oldFolder := range g.VirtualFolders {
		p.removeRelationFromFolderMapping(oldFolder.Name, "", g.Name)
	}
	for idx := range group.VirtualFolders {
		if err = p.addGroupToFolderMapping(group.Name, group.VirtualFolders[idx].Name); err != nil {
			// try to add old mapping
			for _, f := range g.VirtualFolders {
				if errRollback := p.addGroupToFolderMapping(group.Name, f.Name); errRollback != nil {
					providerLog(logger.LevelError, "unable to rollback old folder mapping %q for group %q, error: %v",
						f.Name, group.Name, errRollback)
				}
			}
			return err
		}
	}
	group.CreatedAt = g.CreatedAt
	group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	group.ID = g.ID
	// membership is owned by the stored record, not the update payload
	group.Users = g.Users
	group.Admins = g.Admins
	p.dbHandle.groups[group.Name] = group.getACopy()
	return nil
}

// deleteGroup removes a group that has no user members, detaching it from its
// folder mappings and from any admins referencing it.
func (p *MemoryProvider) deleteGroup(group Group) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	g, err := p.groupExistsInternal(group.Name)
	if err != nil {
		return err
	}
	// refuse deletion while users still reference the group
	if len(g.Users) > 0 {
		return util.NewValidationError(fmt.Sprintf("the group %q is referenced, it cannot be removed", group.Name))
	}
	for _, oldFolder := range g.VirtualFolders {
		p.removeRelationFromFolderMapping(oldFolder.Name, "", g.Name)
	}
	for _, a := range g.Admins {
		p.removeGroupFromAdminMapping(g.Name, a)
	}
	delete(p.dbHandle.groups, group.Name)
	// this could be more efficient
	p.dbHandle.groupnames = make([]string, 0, len(p.dbHandle.groups))
	for name := range p.dbHandle.groups {
		p.dbHandle.groupnames = append(p.dbHandle.groupnames, name)
	}
	sort.Strings(p.dbHandle.groupnames)
	return nil
}

// dumpGroups returns all stored groups, ordered by name, with virtual folders
// resolved.
func (p *MemoryProvider) dumpGroups() ([]Group, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	groups := make([]Group, 0, len(p.dbHandle.groups))
	var err error
	if p.dbHandle.isClosed {
		return groups, errMemoryProviderClosed
	}
	for _, name := range p.dbHandle.groupnames {
		g := p.dbHandle.groups[name]
		group := g.getACopy()
		p.addVirtualFoldersToGroup(&group)
		groups = append(groups, group)
	}
	return groups, err
}

// getUsedFolderQuota returns the used quota files and size for the named folder.
func (p *MemoryProvider) getUsedFolderQuota(name string) (int, int64, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return 0, 0, errMemoryProviderClosed
	}
	folder, err := p.folderExistsInternal(name)
	if err != nil {
		providerLog(logger.LevelError, "unable to get quota for folder %q error: %v", name, err)
		return 0, 0, err
	}
	return folder.UsedQuotaFiles, folder.UsedQuotaSize, err
}

// addVirtualFoldersToGroup resolves each virtual folder reference of the group
// against the stored base folders; references to missing folders are dropped.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) addVirtualFoldersToGroup(group *Group) {
	if len(group.VirtualFolders) > 0 {
		var folders []vfs.VirtualFolder
		for idx := range group.VirtualFolders {
			folder := &group.VirtualFolders[idx]
			baseFolder, err := p.folderExistsInternal(folder.Name)
			if err != nil {
				continue
			}
			folder.BaseVirtualFolder = baseFolder.GetACopy()
			folders = append(folders, *folder)
		}
		group.VirtualFolders = folders
	}
}

// addActionsToRule resolves each action reference of the rule against the
// stored base actions; references to missing actions are dropped.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) addActionsToRule(rule *EventRule) {
	var actions []EventAction
	for idx := range rule.Actions {
		action := &rule.Actions[idx]
		baseAction, err := p.actionExistsInternal(action.Name)
		if err != nil {
			continue
		}
		baseAction.Options.SetEmptySecretsIfNil()
		action.BaseEventAction = baseAction
		actions = append(actions, *action)
	}
	rule.Actions = actions
}

func (p *MemoryProvider)
addRuleToActionMapping(ruleName, actionName string) error {
	a, err := p.actionExistsInternal(actionName)
	if err != nil {
		return util.NewGenericError(fmt.Sprintf("action %q does not exist", actionName))
	}
	if !slices.Contains(a.Rules, ruleName) {
		a.Rules = append(a.Rules, ruleName)
		p.dbHandle.actions[actionName] = a
	}
	return nil
}

// removeRuleFromActionMapping drops ruleName from the action's rule list.
// A missing action is only logged. The caller must hold the dbHandle lock.
func (p *MemoryProvider) removeRuleFromActionMapping(ruleName, actionName string) {
	a, err := p.actionExistsInternal(actionName)
	if err != nil {
		providerLog(logger.LevelWarn, "action %q does not exist, cannot remove from mapping", actionName)
		return
	}
	if slices.Contains(a.Rules, ruleName) {
		var rules []string
		for _, r := range a.Rules {
			if r != ruleName {
				rules = append(rules, r)
			}
		}
		a.Rules = rules
		p.dbHandle.actions[actionName] = a
	}
}

// addAdminToGroupMapping records the admin in the group's admin list.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) addAdminToGroupMapping(username, groupname string) error {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
		return err
	}
	if !slices.Contains(g.Admins, username) {
		g.Admins = append(g.Admins, username)
		p.dbHandle.groups[groupname] = g
	}
	return nil
}

// removeAdminFromGroupMapping drops the admin from the group's admin list.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) removeAdminFromGroupMapping(username, groupname string) {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
		return
	}
	var admins []string
	for _, a := range g.Admins {
		if a != username {
			admins = append(admins, a)
		}
	}
	g.Admins = admins
	p.dbHandle.groups[groupname] = g
}

// removeGroupFromAdminMapping drops the group from the admin's group list.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) removeGroupFromAdminMapping(groupname, username string) {
	admin, err := p.adminExistsInternal(username)
	if err != nil {
		// the admin does not exist so there is no associated group
		return
	}
	var newGroups []AdminGroupMapping
	for _, g := range admin.Groups {
		if g.Name != groupname {
			newGroups = append(newGroups, g)
		}
	}
	admin.Groups = newGroups
	p.dbHandle.admins[admin.Username] = admin
}

// addUserToGroupMapping records the user in the group's user list.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) addUserToGroupMapping(username, groupname string) error {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
		return err
	}
	if !slices.Contains(g.Users, username) {
		g.Users = append(g.Users, username)
		p.dbHandle.groups[groupname] = g
	}
	return nil
}

// removeUserFromGroupMapping drops the user from the group's user list.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) removeUserFromGroupMapping(username, groupname string) {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
		return
	}
	var users []string
	for _, u := range g.Users {
		if u != username {
			users = append(users, u)
		}
	}
	g.Users = users
	p.dbHandle.groups[groupname] = g
}

// addAdminToRole records the admin in the role's admin list; an empty role is
// a no-op. The caller must hold the dbHandle lock.
func (p *MemoryProvider) addAdminToRole(username, role string) error {
	if role == "" {
		return nil
	}
	r, err := p.roleExistsInternal(role)
	if err != nil {
		return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, role)
	}
	if !slices.Contains(r.Admins, username) {
		r.Admins = append(r.Admins, username)
		p.dbHandle.roles[role] = r
	}
	return nil
}

// removeAdminFromRole drops the admin from the role's admin list; an empty
// role is a no-op. The caller must hold the dbHandle lock.
func (p *MemoryProvider) removeAdminFromRole(username, role string) {
	if role == "" {
		return
	}
	r, err := p.roleExistsInternal(role)
	if err != nil {
		providerLog(logger.LevelWarn, "role %q does not exist, cannot remove admin %q", role, username)
		return
	}
	var admins []string
	for _, a := range r.Admins {
		if a != username {
			admins = append(admins, a)
		}
	}
	r.Admins = admins
	p.dbHandle.roles[role] = r
}

// addUserToRole records the user in the role's user list; an empty role is a
// no-op. The caller must hold the dbHandle lock.
func (p *MemoryProvider) addUserToRole(username, role string) error {
	if role == "" {
		return nil
	}
	r, err := p.roleExistsInternal(role)
	if err != nil {
		return fmt.Errorf("%w: role %q does not exist", ErrForeignKeyViolated, role)
	}
	if !slices.Contains(r.Users, username) {
		r.Users = append(r.Users, username)
		p.dbHandle.roles[role] = r
	}
	return nil
}

// removeUserFromRole drops the user from the role's user list; an empty role
// is a no-op. The caller must hold the dbHandle lock.
func (p *MemoryProvider) removeUserFromRole(username, role string) {
	if role == "" {
		return
	}
	r, err := p.roleExistsInternal(role)
	if err != nil {
		providerLog(logger.LevelWarn, "role %q does not exist, cannot remove user %q", role, username)
		return
	}
	var users []string
	for _, u := range r.Users {
		if u != username {
			users = append(users, u)
		}
	}
	r.Users = users
	p.dbHandle.roles[role] = r
}

// addUserToFolderMapping records the user in the folder's user list.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) addUserToFolderMapping(username, foldername string) error {
	f, err := p.folderExistsInternal(foldername)
	if err != nil {
		return util.NewGenericError(fmt.Sprintf("unable to get folder %q: %v", foldername, err))
	}
	if !slices.Contains(f.Users, username) {
		f.Users = append(f.Users, username)
		p.dbHandle.vfolders[foldername] = f
	}
	return nil
}

// addGroupToFolderMapping records the group in the folder's group list.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) addGroupToFolderMapping(name, foldername string) error {
	f, err := p.folderExistsInternal(foldername)
	if err != nil {
		return util.NewGenericError(fmt.Sprintf("unable to get folder %q: %v", foldername, err))
	}
	if !slices.Contains(f.Groups, name) {
		f.Groups = append(f.Groups, name)
		p.dbHandle.vfolders[foldername] = f
	}
	return nil
}

// addVirtualFoldersToUser resolves each virtual folder reference of the user
// against the stored base folders; references to missing folders are dropped.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) addVirtualFoldersToUser(user *User) {
	if len(user.VirtualFolders) > 0 {
		var folders []vfs.VirtualFolder
		for idx := range user.VirtualFolders {
			folder := &user.VirtualFolders[idx]
			baseFolder, err := p.folderExistsInternal(folder.Name)
			if err != nil {
				continue
			}
			folder.BaseVirtualFolder = baseFolder.GetACopy()
			folders = append(folders, *folder)
		}
		user.VirtualFolders = folders
	}
}

// removeRelationFromFolderMapping drops the given user and/or group reference
// from the folder; empty arguments are skipped. The caller must hold the
// dbHandle lock.
func (p *MemoryProvider) removeRelationFromFolderMapping(folderName, username, groupname string) {
	folder, err := p.folderExistsInternal(folderName)
	if err != nil {
		return
	}
	if username != "" {
		var usernames []string
		for _, user := range folder.Users {
			if user != username {
				usernames = append(usernames, user)
			}
		}
		folder.Users = usernames
	}
	if groupname != "" {
		var groups []string
		for _, group := range folder.Groups {
			if group != groupname {
				groups = append(groups, group)
			}
		}
		folder.Groups = groups
	}
	p.dbHandle.vfolders[folder.Name] = folder
}

func (p *MemoryProvider)
folderExistsInternal(name string) (vfs.BaseVirtualFolder, error) {
	// NOTE: unlike the other *ExistsInternal helpers this returns the stored
	// value as-is, not a copy; callers call GetACopy where needed.
	if val, ok := p.dbHandle.vfolders[name]; ok {
		return val, nil
	}
	return vfs.BaseVirtualFolder{}, util.NewRecordNotFoundError(fmt.Sprintf("folder %q does not exist", name))
}

// getFolders returns up to limit folders starting at offset, ordered by name
// according to order.
func (p *MemoryProvider) getFolders(limit, offset int, order string, _ bool) ([]vfs.BaseVirtualFolder, error) {
	folders := make([]vfs.BaseVirtualFolder, 0, limit)
	var err error
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return folders, errMemoryProviderClosed
	}
	if limit <= 0 {
		return folders, err
	}
	itNum := 0
	if order == OrderASC {
		for _, name := range p.dbHandle.vfoldersNames {
			itNum++
			if itNum <= offset {
				continue
			}
			f := p.dbHandle.vfolders[name]
			folder := f.GetACopy()
			folder.PrepareForRendering()
			folders = append(folders, folder)
			if len(folders) >= limit {
				break
			}
		}
	} else {
		// descending order: walk the sorted name index backwards
		for i := len(p.dbHandle.vfoldersNames) - 1; i >= 0; i-- {
			itNum++
			if itNum <= offset {
				continue
			}
			name := p.dbHandle.vfoldersNames[i]
			f := p.dbHandle.vfolders[name]
			folder := f.GetACopy()
			folder.PrepareForRendering()
			folders = append(folders, folder)
			if len(folders) >= limit {
				break
			}
		}
	}
	return folders, err
}

// getFolderByName returns a copy of the named base virtual folder.
func (p *MemoryProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return vfs.BaseVirtualFolder{}, errMemoryProviderClosed
	}
	folder, err := p.folderExistsInternal(name)
	if err != nil {
		return vfs.BaseVirtualFolder{}, err
	}
	return folder.GetACopy(), nil
}

// addFolder validates and stores a new base virtual folder; user/group
// membership is always reset since it is derived from the other side of the
// relation.
func (p *MemoryProvider) addFolder(folder *vfs.BaseVirtualFolder) error {
	err := ValidateFolder(folder)
	if err != nil {
		return err
	}

	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}

	_, err = p.folderExistsInternal(folder.Name)
	if err == nil {
		return util.NewI18nError(
			fmt.Errorf("%w: folder %q already exists", ErrDuplicatedKey, folder.Name),
			util.I18nErrorDuplicatedUsername,
		)
	}
	folder.ID = p.getNextFolderID()
	folder.Users = nil
	folder.Groups = nil
	p.dbHandle.vfolders[folder.Name] = folder.GetACopy()
	p.dbHandle.vfoldersNames = append(p.dbHandle.vfoldersNames, folder.Name)
	sort.Strings(p.dbHandle.vfoldersNames)
	return nil
}

// updateFolder replaces an existing folder, preserving ID, quota state and
// membership, then refreshes the embedded copy held by every referencing user.
func (p *MemoryProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
	err := ValidateFolder(folder)
	if err != nil {
		return err
	}

	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	f, err := p.folderExistsInternal(folder.Name)
	if err != nil {
		return err
	}
	folder.ID = f.ID
	folder.LastQuotaUpdate = f.LastQuotaUpdate
	folder.UsedQuotaFiles = f.UsedQuotaFiles
	folder.UsedQuotaSize = f.UsedQuotaSize
	folder.Users = f.Users
	folder.Groups = f.Groups
	p.dbHandle.vfolders[folder.Name] = folder.GetACopy()
	// now update the related users
	for _, username := range folder.Users {
		user, err := p.userExistsInternal(username)
		if err == nil {
			var folders []vfs.VirtualFolder
			for idx := range user.VirtualFolders {
				userFolder := &user.VirtualFolders[idx]
				if folder.Name == userFolder.Name {
					userFolder.BaseVirtualFolder = folder.GetACopy()
				}
				folders = append(folders, *userFolder)
			}
			user.VirtualFolders = folders
			p.dbHandle.users[user.Username] = user
		}
	}
	return nil
}

// deleteFolder removes a base virtual folder and strips its references from
// every user and group that embeds it, then rebuilds the sorted name index.
func (p *MemoryProvider) deleteFolder(f vfs.BaseVirtualFolder) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}

	folder, err := p.folderExistsInternal(f.Name)
	if err != nil {
		return err
	}
	for _, username := range folder.Users {
		user, err := p.userExistsInternal(username)
		if err == nil {
			var folders []vfs.VirtualFolder
			for idx := range user.VirtualFolders {
				userFolder := &user.VirtualFolders[idx]
				if folder.Name != userFolder.Name {
					folders
= append(folders, *userFolder)
				}
			}
			user.VirtualFolders = folders
			p.dbHandle.users[user.Username] = user
		}
	}
	for _, groupname := range folder.Groups {
		group, err := p.groupExistsInternal(groupname)
		if err == nil {
			var folders []vfs.VirtualFolder
			for idx := range group.VirtualFolders {
				groupFolder := &group.VirtualFolders[idx]
				if folder.Name != groupFolder.Name {
					folders = append(folders, *groupFolder)
				}
			}
			group.VirtualFolders = folders
			p.dbHandle.groups[group.Name] = group
		}
	}
	delete(p.dbHandle.vfolders, folder.Name)
	// rebuild the sorted name index from scratch
	p.dbHandle.vfoldersNames = []string{}
	for name := range p.dbHandle.vfolders {
		p.dbHandle.vfoldersNames = append(p.dbHandle.vfoldersNames, name)
	}
	sort.Strings(p.dbHandle.vfoldersNames)
	return nil
}

// apiKeyExistsInternal returns a copy of the stored API key.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) apiKeyExistsInternal(keyID string) (APIKey, error) {
	if val, ok := p.dbHandle.apiKeys[keyID]; ok {
		return val.getACopy(), nil
	}
	return APIKey{}, util.NewRecordNotFoundError(fmt.Sprintf("API key %q does not exist", keyID))
}

// apiKeyExists returns a copy of the API key with the given key ID.
func (p *MemoryProvider) apiKeyExists(keyID string) (APIKey, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return APIKey{}, errMemoryProviderClosed
	}
	return p.apiKeyExistsInternal(keyID)
}

// addAPIKey validates and stores a new API key after checking that the
// referenced user/admin, if any, exists.
func (p *MemoryProvider) addAPIKey(apiKey *APIKey) error {
	err := apiKey.validate()
	if err != nil {
		return err
	}

	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}

	_, err = p.apiKeyExistsInternal(apiKey.KeyID)
	if err == nil {
		return fmt.Errorf("API key %q already exists", apiKey.KeyID)
	}
	if apiKey.User != "" {
		if _, err := p.userExistsInternal(apiKey.User); err != nil {
			return fmt.Errorf("%w: related user %q does not exists", ErrForeignKeyViolated, apiKey.User)
		}
	}
	if apiKey.Admin != "" {
		if _, err := p.adminExistsInternal(apiKey.Admin); err != nil {
			return fmt.Errorf("%w: related admin %q does not exists", ErrForeignKeyViolated, apiKey.Admin)
		}
	}
	apiKey.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	apiKey.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	apiKey.LastUseAt = 0
	p.dbHandle.apiKeys[apiKey.KeyID] = apiKey.getACopy()
	p.dbHandle.apiKeysIDs = append(p.dbHandle.apiKeysIDs, apiKey.KeyID)
	sort.Strings(p.dbHandle.apiKeysIDs)
	return nil
}

// updateAPIKey replaces an existing API key, preserving its ID, key ID, secret
// key, creation time and last use.
func (p *MemoryProvider) updateAPIKey(apiKey *APIKey) error {
	err := apiKey.validate()
	if err != nil {
		return err
	}

	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	k, err := p.apiKeyExistsInternal(apiKey.KeyID)
	if err != nil {
		return err
	}
	if apiKey.User != "" {
		if _, err := p.userExistsInternal(apiKey.User); err != nil {
			return fmt.Errorf("%w: related user %q does not exists", ErrForeignKeyViolated, apiKey.User)
		}
	}
	if apiKey.Admin != "" {
		if _, err := p.adminExistsInternal(apiKey.Admin); err != nil {
			return fmt.Errorf("%w: related admin %q does not exists", ErrForeignKeyViolated, apiKey.Admin)
		}
	}
	apiKey.ID = k.ID
	apiKey.KeyID = k.KeyID
	apiKey.Key = k.Key
	apiKey.CreatedAt = k.CreatedAt
	apiKey.LastUseAt = k.LastUseAt
	apiKey.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	p.dbHandle.apiKeys[apiKey.KeyID] = apiKey.getACopy()
	return nil
}

// deleteAPIKey removes an API key and refreshes the sorted key ID index.
func (p *MemoryProvider) deleteAPIKey(apiKey APIKey) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	_, err := p.apiKeyExistsInternal(apiKey.KeyID)
	if err != nil {
		return err
	}

	delete(p.dbHandle.apiKeys, apiKey.KeyID)
	p.updateAPIKeysOrdering()

	return nil
}

// getAPIKeys returns up to limit API keys starting at offset, ordered by key
// ID according to order, with confidential data hidden.
func (p *MemoryProvider) getAPIKeys(limit int, offset int, order string) ([]APIKey, error) {
	apiKeys := make([]APIKey, 0, limit)

	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()

	if p.dbHandle.isClosed {
		return apiKeys, errMemoryProviderClosed
	}
	if limit <= 0 {
		return apiKeys, nil
	}
	itNum := 0
	if order == OrderDESC {
		// descending order: walk the sorted key ID index backwards
		for i := len(p.dbHandle.apiKeysIDs) - 1; i >= 0; i-- {
			itNum++
			if itNum <= offset {
				continue
			}
			keyID := p.dbHandle.apiKeysIDs[i]
			k := p.dbHandle.apiKeys[keyID]
			apiKey := k.getACopy()
			apiKey.HideConfidentialData()
			apiKeys = append(apiKeys, apiKey)
			if len(apiKeys) >= limit {
				break
			}
		}
	} else {
		for _, keyID := range p.dbHandle.apiKeysIDs {
			itNum++
			if itNum <= offset {
				continue
			}
			k := p.dbHandle.apiKeys[keyID]
			apiKey := k.getACopy()
			apiKey.HideConfidentialData()
			apiKeys = append(apiKeys, apiKey)
			if len(apiKeys) >= limit {
				break
			}
		}
	}

	return apiKeys, nil
}

// dumpAPIKeys returns all stored API keys, confidential data included.
func (p *MemoryProvider) dumpAPIKeys() ([]APIKey, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()

	apiKeys := make([]APIKey, 0, len(p.dbHandle.apiKeys))
	if p.dbHandle.isClosed {
		return apiKeys, errMemoryProviderClosed
	}
	for _, k := range p.dbHandle.apiKeys {
		apiKeys = append(apiKeys, k)
	}
	return apiKeys, nil
}

// deleteAPIKeysWithUser removes every API key bound to the given user and
// refreshes the key ID index if anything was removed.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) deleteAPIKeysWithUser(username string) {
	found := false
	for k, v := range p.dbHandle.apiKeys {
		if v.User == username {
			delete(p.dbHandle.apiKeys, k)
			found = true
		}
	}
	if found {
		p.updateAPIKeysOrdering()
	}
}

// deleteAPIKeysWithAdmin removes every API key bound to the given admin and
// refreshes the key ID index if anything was removed.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) deleteAPIKeysWithAdmin(username string) {
	found := false
	for k, v := range p.dbHandle.apiKeys {
		if v.Admin == username {
			delete(p.dbHandle.apiKeys, k)
			found = true
		}
	}
	if found {
		p.updateAPIKeysOrdering()
	}
}

// deleteSharesWithUser removes every share owned by the given user and
// refreshes the share ID index if anything was removed.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) deleteSharesWithUser(username string) {
	found := false
	for k, v := range p.dbHandle.shares {
		if v.Username == username {
			delete(p.dbHandle.shares, k)
			found = true
		}
	}
	if found {
		p.updateSharesOrdering()
	}
}

// updateAPIKeysOrdering rebuilds the sorted key ID index from the key map.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) updateAPIKeysOrdering() {
	// this could be more efficient
	p.dbHandle.apiKeysIDs = make([]string, 0, len(p.dbHandle.apiKeys))
	for keyID := range p.dbHandle.apiKeys {
		p.dbHandle.apiKeysIDs =
append(p.dbHandle.apiKeysIDs, keyID)
	}
	sort.Strings(p.dbHandle.apiKeysIDs)
}

// updateSharesOrdering rebuilds the sorted share ID index from the share map.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) updateSharesOrdering() {
	// this could be more efficient
	p.dbHandle.sharesIDs = make([]string, 0, len(p.dbHandle.shares))
	for shareID := range p.dbHandle.shares {
		p.dbHandle.sharesIDs = append(p.dbHandle.sharesIDs, shareID)
	}
	sort.Strings(p.dbHandle.sharesIDs)
}

// shareExistsInternal returns a copy of the stored share. A non-empty username
// additionally requires the share to belong to that user.
// The caller must hold the dbHandle lock.
func (p *MemoryProvider) shareExistsInternal(shareID, username string) (Share, error) {
	if val, ok := p.dbHandle.shares[shareID]; ok {
		if username != "" && val.Username != username {
			return Share{}, util.NewRecordNotFoundError(fmt.Sprintf("Share %q does not exist", shareID))
		}
		return val.getACopy(), nil
	}
	return Share{}, util.NewRecordNotFoundError(fmt.Sprintf("Share %q does not exist", shareID))
}

// shareExists returns a copy of the share with the given ID, optionally
// restricted to the given owner.
func (p *MemoryProvider) shareExists(shareID, username string) (Share, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return Share{}, errMemoryProviderClosed
	}
	return p.shareExistsInternal(shareID, username)
}

// addShare validates and stores a new share after checking the owner exists.
// Unless the share is being restored from a backup, timestamps and usage
// counters are reset.
func (p *MemoryProvider) addShare(share *Share) error {
	err := share.validate()
	if err != nil {
		return err
	}

	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}

	_, err = p.shareExistsInternal(share.ShareID, share.Username)
	if err == nil {
		return fmt.Errorf("share %q already exists", share.ShareID)
	}
	if _, err := p.userExistsInternal(share.Username); err != nil {
		return util.NewValidationError(fmt.Sprintf("related user %q does not exists", share.Username))
	}
	if !share.IsRestore {
		share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
		share.UpdatedAt = share.CreatedAt
		share.LastUseAt = 0
		share.UsedTokens = 0
	}
	// restored shares may carry zero timestamps: backfill them
	if share.CreatedAt == 0 {
		share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	}
	if share.UpdatedAt == 0 {
		share.UpdatedAt = share.CreatedAt
	}
	p.dbHandle.shares[share.ShareID] = share.getACopy()
	p.dbHandle.sharesIDs = append(p.dbHandle.sharesIDs, share.ShareID)
	sort.Strings(p.dbHandle.sharesIDs)
	return nil
}

// updateShare replaces an existing share, preserving its ID and share ID and,
// unless restoring, its creation time, usage counters and last use.
func (p *MemoryProvider) updateShare(share *Share) error {
	err := share.validate()
	if err != nil {
		return err
	}

	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	s, err := p.shareExistsInternal(share.ShareID, share.Username)
	if err != nil {
		return err
	}
	if _, err := p.userExistsInternal(share.Username); err != nil {
		return util.NewValidationError(fmt.Sprintf("related user %q does not exists", share.Username))
	}
	share.ID = s.ID
	share.ShareID = s.ShareID
	if !share.IsRestore {
		share.UsedTokens = s.UsedTokens
		share.CreatedAt = s.CreatedAt
		share.LastUseAt = s.LastUseAt
		share.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	}
	if share.CreatedAt == 0 {
		share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	}
	if share.UpdatedAt == 0 {
		share.UpdatedAt = share.CreatedAt
	}
	p.dbHandle.shares[share.ShareID] = share.getACopy()
	return nil
}

// deleteShare removes a share owned by the given user and refreshes the
// sorted share ID index.
func (p *MemoryProvider) deleteShare(share Share) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	_, err := p.shareExistsInternal(share.ShareID, share.Username)
	if err != nil {
		return err
	}

	delete(p.dbHandle.shares, share.ShareID)
	p.updateSharesOrdering()

	return nil
}

// getShares returns up to limit shares owned by username starting at offset,
// ordered by share ID according to order, with confidential data hidden.
// Note: the offset is counted over the user's shares only.
func (p *MemoryProvider) getShares(limit int, offset int, order, username string) ([]Share, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()

	if p.dbHandle.isClosed {
		return []Share{}, errMemoryProviderClosed
	}
	if limit <= 0 {
		return []Share{}, nil
	}
	shares := make([]Share, 0, limit)
	itNum := 0
	if order == OrderDESC {
		for i := len(p.dbHandle.sharesIDs) - 1; i >= 0; i-- {
			shareID := p.dbHandle.sharesIDs[i]
			s := p.dbHandle.shares[shareID]
			if s.Username != username {
				continue
			}
			itNum++
			if itNum <= offset
{
				continue
			}
			share := s.getACopy()
			share.HideConfidentialData()
			shares = append(shares, share)
			if len(shares) >= limit {
				break
			}
		}
	} else {
		for _, shareID := range p.dbHandle.sharesIDs {
			s := p.dbHandle.shares[shareID]
			if s.Username != username {
				continue
			}
			itNum++
			if itNum <= offset {
				continue
			}
			share := s.getACopy()
			share.HideConfidentialData()
			shares = append(shares, share)
			if len(shares) >= limit {
				break
			}
		}
	}

	return shares, nil
}

// dumpShares returns all stored shares, confidential data included.
func (p *MemoryProvider) dumpShares() ([]Share, error) {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()

	shares := make([]Share, 0, len(p.dbHandle.shares))
	if p.dbHandle.isClosed {
		return shares, errMemoryProviderClosed
	}
	for _, s := range p.dbHandle.shares {
		shares = append(shares, s)
	}
	return shares, nil
}

// updateShareLastUse records a use of the share, bumping its last-use time and
// used token counter. The owner is not checked here.
func (p *MemoryProvider) updateShareLastUse(shareID string, numTokens int) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
	share, err := p.shareExistsInternal(shareID, "")
	if err != nil {
		return err
	}
	share.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now())
	share.UsedTokens += numTokens
	p.dbHandle.shares[share.ShareID] = share
	return nil
}

// The defender is not supported by the memory provider: all the methods below
// are stubs returning ErrNotImplemented.

func (p *MemoryProvider) getDefenderHosts(_ int64, _ int) ([]DefenderEntry, error) {
	return nil, ErrNotImplemented
}

func (p *MemoryProvider) getDefenderHostByIP(_ string, _ int64) (DefenderEntry, error) {
	return DefenderEntry{}, ErrNotImplemented
}

func (p *MemoryProvider) isDefenderHostBanned(_ string) (DefenderEntry, error) {
	return DefenderEntry{}, ErrNotImplemented
}

func (p *MemoryProvider) updateDefenderBanTime(_ string, _ int) error {
	return ErrNotImplemented
}

func (p *MemoryProvider) deleteDefenderHost(_ string) error {
	return ErrNotImplemented
}

func (p *MemoryProvider) addDefenderEvent(_ string, _ int) error {
	return ErrNotImplemented
}

func (p *MemoryProvider)
setDefenderBanTime(_ string, _ int64) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) cleanupDefender(_ int64) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) addActiveTransfer(_ ActiveTransfer) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) updateActiveTransferSizes(_, _, _ int64, _ string) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) removeActiveTransfer(_ int64, _ string) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) cleanupActiveTransfers(_ time.Time) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) getActiveTransfers(_ time.Time) ([]ActiveTransfer, error) { + return nil, ErrNotImplemented +} + +func (p *MemoryProvider) addSharedSession(_ Session) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) deleteSharedSession(_ string, _ SessionType) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) getSharedSession(_ string, _ SessionType) (Session, error) { + return Session{}, ErrNotImplemented +} + +func (p *MemoryProvider) cleanupSharedSessions(_ SessionType, _ int64) error { + return ErrNotImplemented +} + +func (p *MemoryProvider) getEventActions(limit, offset int, order string, _ bool) ([]BaseEventAction, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + if limit <= 0 { + return nil, nil + } + actions := make([]BaseEventAction, 0, limit) + itNum := 0 + if order == OrderASC { + for _, name := range p.dbHandle.actionsNames { + itNum++ + if itNum <= offset { + continue + } + a := p.dbHandle.actions[name] + action := a.getACopy() + action.PrepareForRendering() + actions = append(actions, action) + if len(actions) >= limit { + break + } + } + } else { + for i := len(p.dbHandle.actionsNames) - 1; i >= 0; i-- { + itNum++ + if itNum <= offset { + continue + } + name := p.dbHandle.actionsNames[i] + a := p.dbHandle.actions[name] + action := a.getACopy() + 
action.PrepareForRendering() + actions = append(actions, action) + if len(actions) >= limit { + break + } + } + } + return actions, nil +} + +func (p *MemoryProvider) dumpEventActions() ([]BaseEventAction, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + actions := make([]BaseEventAction, 0, len(p.dbHandle.actions)) + for _, name := range p.dbHandle.actionsNames { + a := p.dbHandle.actions[name] + action := a.getACopy() + actions = append(actions, action) + } + return actions, nil +} + +func (p *MemoryProvider) eventActionExists(name string) (BaseEventAction, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return BaseEventAction{}, errMemoryProviderClosed + } + return p.actionExistsInternal(name) +} + +func (p *MemoryProvider) addEventAction(action *BaseEventAction) error { + err := action.validate() + if err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + _, err = p.actionExistsInternal(action.Name) + if err == nil { + return util.NewI18nError( + fmt.Errorf("%w: event action %q already exists", ErrDuplicatedKey, action.Name), + util.I18nErrorDuplicatedName, + ) + } + action.ID = p.getNextActionID() + action.Rules = nil + p.dbHandle.actions[action.Name] = action.getACopy() + p.dbHandle.actionsNames = append(p.dbHandle.actionsNames, action.Name) + sort.Strings(p.dbHandle.actionsNames) + return nil +} + +func (p *MemoryProvider) updateEventAction(action *BaseEventAction) error { + err := action.validate() + if err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + oldAction, err := p.actionExistsInternal(action.Name) + if err != nil { + return fmt.Errorf("event action %s does not exist", action.Name) + } + action.ID = oldAction.ID + action.Name = oldAction.Name + 
action.Rules = nil + if len(oldAction.Rules) > 0 { + var relatedRules []string + for _, ruleName := range oldAction.Rules { + rule, err := p.ruleExistsInternal(ruleName) + if err == nil { + relatedRules = append(relatedRules, ruleName) + rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.rules[ruleName] = rule + setLastRuleUpdate() + } + } + action.Rules = relatedRules + } + p.dbHandle.actions[action.Name] = action.getACopy() + return nil +} + +func (p *MemoryProvider) deleteEventAction(action BaseEventAction) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + oldAction, err := p.actionExistsInternal(action.Name) + if err != nil { + return fmt.Errorf("event action %s does not exist", action.Name) + } + if len(oldAction.Rules) > 0 { + return util.NewValidationError(fmt.Sprintf("action %s is referenced, it cannot be removed", oldAction.Name)) + } + delete(p.dbHandle.actions, action.Name) + // this could be more efficient + p.dbHandle.actionsNames = make([]string, 0, len(p.dbHandle.actions)) + for name := range p.dbHandle.actions { + p.dbHandle.actionsNames = append(p.dbHandle.actionsNames, name) + } + sort.Strings(p.dbHandle.actionsNames) + return nil +} + +func (p *MemoryProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + if limit <= 0 { + return nil, nil + } + itNum := 0 + rules := make([]EventRule, 0, limit) + if order == OrderASC { + for _, name := range p.dbHandle.rulesNames { + itNum++ + if itNum <= offset { + continue + } + r := p.dbHandle.rules[name] + rule := r.getACopy() + p.addActionsToRule(&rule) + rule.PrepareForRendering() + rules = append(rules, rule) + if len(rules) >= limit { + break + } + } + } else { + for i := len(p.dbHandle.rulesNames) - 1; i >= 0; i-- { + itNum++ + if itNum <= offset { + continue + } + name := 
p.dbHandle.rulesNames[i] + r := p.dbHandle.rules[name] + rule := r.getACopy() + p.addActionsToRule(&rule) + rule.PrepareForRendering() + rules = append(rules, rule) + if len(rules) >= limit { + break + } + } + } + return rules, nil +} + +func (p *MemoryProvider) dumpEventRules() ([]EventRule, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + rules := make([]EventRule, 0, len(p.dbHandle.rules)) + for _, name := range p.dbHandle.rulesNames { + r := p.dbHandle.rules[name] + rule := r.getACopy() + p.addActionsToRule(&rule) + rules = append(rules, rule) + } + return rules, nil +} + +func (p *MemoryProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) { + if getLastRuleUpdate() < after { + return nil, nil + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + rules := make([]EventRule, 0, 10) + for _, name := range p.dbHandle.rulesNames { + r := p.dbHandle.rules[name] + if r.UpdatedAt < after { + continue + } + rule := r.getACopy() + p.addActionsToRule(&rule) + rules = append(rules, rule) + } + return rules, nil +} + +func (p *MemoryProvider) eventRuleExists(name string) (EventRule, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return EventRule{}, errMemoryProviderClosed + } + rule, err := p.ruleExistsInternal(name) + if err != nil { + return rule, err + } + p.addActionsToRule(&rule) + return rule, nil +} + +func (p *MemoryProvider) addEventRule(rule *EventRule) error { + if err := rule.validate(); err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + _, err := p.ruleExistsInternal(rule.Name) + if err == nil { + return util.NewI18nError( + fmt.Errorf("%w: event rule %q already exists", ErrDuplicatedKey, rule.Name), + util.I18nErrorDuplicatedName, + ) + } + rule.ID = p.getNextRuleID() + 
rule.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + rule.UpdatedAt = rule.CreatedAt + var mappedActions []string + for idx := range rule.Actions { + if err := p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name); err != nil { + // try to remove action mapping + for _, a := range mappedActions { + p.removeRuleFromActionMapping(rule.Name, a) + } + return err + } + mappedActions = append(mappedActions, rule.Actions[idx].Name) + } + sort.Slice(rule.Actions, func(i, j int) bool { + return rule.Actions[i].Order < rule.Actions[j].Order + }) + p.dbHandle.rules[rule.Name] = rule.getACopy() + p.dbHandle.rulesNames = append(p.dbHandle.rulesNames, rule.Name) + sort.Strings(p.dbHandle.rulesNames) + setLastRuleUpdate() + return nil +} + +func (p *MemoryProvider) updateEventRule(rule *EventRule) error { + if err := rule.validate(); err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + oldRule, err := p.ruleExistsInternal(rule.Name) + if err != nil { + return err + } + for idx := range oldRule.Actions { + p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name) + } + for idx := range rule.Actions { + if err = p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name); err != nil { + // try to add old mapping + for _, oldAction := range oldRule.Actions { + if errRollback := p.addRuleToActionMapping(oldRule.Name, oldAction.Name); errRollback != nil { + providerLog(logger.LevelError, "unable to rollback old action mapping %q for rule %q, error: %v", + oldAction.Name, oldRule.Name, errRollback) + } + } + return err + } + } + rule.ID = oldRule.ID + rule.CreatedAt = oldRule.CreatedAt + rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + sort.Slice(rule.Actions, func(i, j int) bool { + return rule.Actions[i].Order < rule.Actions[j].Order + }) + p.dbHandle.rules[rule.Name] = rule.getACopy() + setLastRuleUpdate() + return nil +} + +func (p *MemoryProvider) 
deleteEventRule(rule EventRule, _ bool) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + oldRule, err := p.ruleExistsInternal(rule.Name) + if err != nil { + return err + } + if len(oldRule.Actions) > 0 { + for idx := range oldRule.Actions { + p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name) + } + } + delete(p.dbHandle.rules, rule.Name) + p.dbHandle.rulesNames = make([]string, 0, len(p.dbHandle.rules)) + for name := range p.dbHandle.rules { + p.dbHandle.rulesNames = append(p.dbHandle.rulesNames, name) + } + sort.Strings(p.dbHandle.rulesNames) + setLastRuleUpdate() + return nil +} + +func (*MemoryProvider) getTaskByName(_ string) (Task, error) { + return Task{}, ErrNotImplemented +} + +func (*MemoryProvider) addTask(_ string) error { + return ErrNotImplemented +} + +func (*MemoryProvider) updateTask(_ string, _ int64) error { + return ErrNotImplemented +} + +func (*MemoryProvider) updateTaskTimestamp(_ string) error { + return ErrNotImplemented +} + +func (*MemoryProvider) addNode() error { + return ErrNotImplemented +} + +func (*MemoryProvider) getNodeByName(_ string) (Node, error) { + return Node{}, ErrNotImplemented +} + +func (*MemoryProvider) getNodes() ([]Node, error) { + return nil, ErrNotImplemented +} + +func (*MemoryProvider) updateNodeTimestamp() error { + return ErrNotImplemented +} + +func (*MemoryProvider) cleanupNodes() error { + return ErrNotImplemented +} + +func (p *MemoryProvider) roleExists(name string) (Role, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return Role{}, errMemoryProviderClosed + } + role, err := p.roleExistsInternal(name) + if err != nil { + return role, err + } + return role, nil +} + +func (p *MemoryProvider) addRole(role *Role) error { + if err := role.validate(); err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return 
errMemoryProviderClosed + } + + _, err := p.roleExistsInternal(role.Name) + if err == nil { + return util.NewI18nError( + fmt.Errorf("%w: role %q already exists", ErrDuplicatedKey, role.Name), + util.I18nErrorDuplicatedName, + ) + } + role.ID = p.getNextRoleID() + role.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + role.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + role.Users = nil + role.Admins = nil + p.dbHandle.roles[role.Name] = role.getACopy() + p.dbHandle.roleNames = append(p.dbHandle.roleNames, role.Name) + sort.Strings(p.dbHandle.roleNames) + return nil +} + +func (p *MemoryProvider) updateRole(role *Role) error { + if err := role.validate(); err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + oldRole, err := p.roleExistsInternal(role.Name) + if err != nil { + return err + } + role.ID = oldRole.ID + role.CreatedAt = oldRole.CreatedAt + role.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + role.Users = oldRole.Users + role.Admins = oldRole.Admins + p.dbHandle.roles[role.Name] = role.getACopy() + return nil +} + +func (p *MemoryProvider) deleteRole(role Role) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + oldRole, err := p.roleExistsInternal(role.Name) + if err != nil { + return err + } + if len(oldRole.Admins) > 0 { + return util.NewValidationError(fmt.Sprintf("the role %q is referenced, it cannot be removed", oldRole.Name)) + } + for _, username := range oldRole.Users { + user, err := p.userExistsInternal(username) + if err != nil { + continue + } + if user.Role == role.Name { + user.Role = "" + p.dbHandle.users[username] = user + } else { + providerLog(logger.LevelError, "user %q does not have the expected role %q, actual %q", username, role.Name, user.Role) + } + } + delete(p.dbHandle.roles, role.Name) + p.dbHandle.roleNames = make([]string, 0, len(p.dbHandle.roles)) + for 
name := range p.dbHandle.roles { + p.dbHandle.roleNames = append(p.dbHandle.roleNames, name) + } + sort.Strings(p.dbHandle.roleNames) + return nil +} + +func (p *MemoryProvider) getRoles(limit int, offset int, order string, _ bool) ([]Role, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + if limit <= 0 { + return nil, nil + } + roles := make([]Role, 0, 10) + itNum := 0 + if order == OrderASC { + for _, name := range p.dbHandle.roleNames { + itNum++ + if itNum <= offset { + continue + } + r := p.dbHandle.roles[name] + role := r.getACopy() + roles = append(roles, role) + if len(roles) >= limit { + break + } + } + } else { + for i := len(p.dbHandle.roleNames) - 1; i >= 0; i-- { + itNum++ + if itNum <= offset { + continue + } + name := p.dbHandle.roleNames[i] + r := p.dbHandle.roles[name] + role := r.getACopy() + roles = append(roles, role) + if len(roles) >= limit { + break + } + } + } + return roles, nil +} + +func (p *MemoryProvider) dumpRoles() ([]Role, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + + roles := make([]Role, 0, len(p.dbHandle.roles)) + for _, name := range p.dbHandle.roleNames { + r := p.dbHandle.roles[name] + roles = append(roles, r.getACopy()) + } + return roles, nil +} + +func (p *MemoryProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return IPListEntry{}, errMemoryProviderClosed + } + entry, err := p.ipListEntryExistsInternal(&IPListEntry{IPOrNet: ipOrNet, Type: listType}) + if err != nil { + return entry, err + } + entry.PrepareForRendering() + return entry, nil +} + +func (p *MemoryProvider) addIPListEntry(entry *IPListEntry) error { + if err := entry.validate(); err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + 
return errMemoryProviderClosed + } + _, err := p.ipListEntryExistsInternal(entry) + if err == nil { + return util.NewI18nError( + fmt.Errorf("%w: entry %q already exists", ErrDuplicatedKey, entry.IPOrNet), + util.I18nErrorDuplicatedIPNet, + ) + } + entry.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + entry.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.ipListEntries[entry.getKey()] = entry.getACopy() + p.dbHandle.ipListEntriesKeys = append(p.dbHandle.ipListEntriesKeys, entry.getKey()) + sort.Strings(p.dbHandle.ipListEntriesKeys) + return nil +} + +func (p *MemoryProvider) updateIPListEntry(entry *IPListEntry) error { + if err := entry.validate(); err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + oldEntry, err := p.ipListEntryExistsInternal(entry) + if err != nil { + return err + } + entry.CreatedAt = oldEntry.CreatedAt + entry.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.ipListEntries[entry.getKey()] = entry.getACopy() + return nil +} + +func (p *MemoryProvider) deleteIPListEntry(entry IPListEntry, _ bool) error { + if err := entry.validate(); err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + _, err := p.ipListEntryExistsInternal(&entry) + if err != nil { + return err + } + delete(p.dbHandle.ipListEntries, entry.getKey()) + p.dbHandle.ipListEntriesKeys = make([]string, 0, len(p.dbHandle.ipListEntries)) + for k := range p.dbHandle.ipListEntries { + p.dbHandle.ipListEntriesKeys = append(p.dbHandle.ipListEntriesKeys, k) + } + sort.Strings(p.dbHandle.ipListEntriesKeys) + return nil +} + +func (p *MemoryProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + entries := 
make([]IPListEntry, 0, 15) + if order == OrderASC { + for _, k := range p.dbHandle.ipListEntriesKeys { + e := p.dbHandle.ipListEntries[k] + if e.Type == listType && e.satisfySearchConstraints(filter, from, order) { + entry := e.getACopy() + entry.PrepareForRendering() + entries = append(entries, entry) + if limit > 0 && len(entries) >= limit { + break + } + } + } + } else { + for i := len(p.dbHandle.ipListEntriesKeys) - 1; i >= 0; i-- { + e := p.dbHandle.ipListEntries[p.dbHandle.ipListEntriesKeys[i]] + if e.Type == listType && e.satisfySearchConstraints(filter, from, order) { + entry := e.getACopy() + entry.PrepareForRendering() + entries = append(entries, entry) + if limit > 0 && len(entries) >= limit { + break + } + } + } + } + + return entries, nil +} + +func (p *MemoryProvider) getRecentlyUpdatedIPListEntries(_ int64) ([]IPListEntry, error) { + return nil, ErrNotImplemented +} + +func (p *MemoryProvider) dumpIPListEntries() ([]IPListEntry, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + if count := len(p.dbHandle.ipListEntriesKeys); count > ipListMemoryLimit { + providerLog(logger.LevelInfo, "IP lists excluded from dump, too many entries: %d", count) + return nil, nil + } + entries := make([]IPListEntry, 0, len(p.dbHandle.ipListEntries)) + for _, k := range p.dbHandle.ipListEntriesKeys { + e := p.dbHandle.ipListEntries[k] + entry := e.getACopy() + entry.PrepareForRendering() + entries = append(entries, entry) + } + return entries, nil +} + +func (p *MemoryProvider) countIPListEntries(listType IPListType) (int64, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + + if p.dbHandle.isClosed { + return 0, errMemoryProviderClosed + } + if listType == 0 { + return int64(len(p.dbHandle.ipListEntriesKeys)), nil + } + var count int64 + for _, k := range p.dbHandle.ipListEntriesKeys { + e := p.dbHandle.ipListEntries[k] + if e.Type == listType { + count++ + } + } + return count, nil 
+} + +func (p *MemoryProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + + if p.dbHandle.isClosed { + return nil, errMemoryProviderClosed + } + entries := make([]IPListEntry, 0, 3) + ipAddr, err := netip.ParseAddr(ip) + if err != nil { + return entries, fmt.Errorf("invalid ip address %s", ip) + } + var netType int + var ipBytes []byte + if ipAddr.Is4() || ipAddr.Is4In6() { + netType = ipTypeV4 + as4 := ipAddr.As4() + ipBytes = as4[:] + } else { + netType = ipTypeV6 + as16 := ipAddr.As16() + ipBytes = as16[:] + } + for _, k := range p.dbHandle.ipListEntriesKeys { + e := p.dbHandle.ipListEntries[k] + if e.Type == listType && e.IPType == netType && bytes.Compare(ipBytes, e.First) >= 0 && bytes.Compare(ipBytes, e.Last) <= 0 { + entry := e.getACopy() + entry.PrepareForRendering() + entries = append(entries, entry) + } + } + return entries, nil +} + +func (p *MemoryProvider) getConfigs() (Configs, error) { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return Configs{}, errMemoryProviderClosed + } + return p.dbHandle.configs.getACopy(), nil +} + +func (p *MemoryProvider) setConfigs(configs *Configs) error { + if err := configs.validate(); err != nil { + return err + } + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + p.dbHandle.configs = configs.getACopy() + return nil +} + +func (p *MemoryProvider) setFirstDownloadTimestamp(username string) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + return err + } + if user.FirstDownload > 0 { + return util.NewGenericError(fmt.Sprintf("first download already set to %s", + util.GetTimeFromMsecSinceEpoch(user.FirstDownload))) + } + user.FirstDownload = util.GetTimeAsMsSinceEpoch(time.Now()) + 
p.dbHandle.users[user.Username] = user + return nil +} + +func (p *MemoryProvider) setFirstUploadTimestamp(username string) error { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + if p.dbHandle.isClosed { + return errMemoryProviderClosed + } + user, err := p.userExistsInternal(username) + if err != nil { + return err + } + if user.FirstUpload > 0 { + return util.NewGenericError(fmt.Sprintf("first upload already set to %s", + util.GetTimeFromMsecSinceEpoch(user.FirstUpload))) + } + user.FirstUpload = util.GetTimeAsMsSinceEpoch(time.Now()) + p.dbHandle.users[user.Username] = user + return nil +} + +func (p *MemoryProvider) getNextID() int64 { + nextID := int64(1) + for _, v := range p.dbHandle.users { + if v.ID >= nextID { + nextID = v.ID + 1 + } + } + return nextID +} + +func (p *MemoryProvider) getNextFolderID() int64 { + nextID := int64(1) + for _, v := range p.dbHandle.vfolders { + if v.ID >= nextID { + nextID = v.ID + 1 + } + } + return nextID +} + +func (p *MemoryProvider) getNextAdminID() int64 { + nextID := int64(1) + for _, a := range p.dbHandle.admins { + if a.ID >= nextID { + nextID = a.ID + 1 + } + } + return nextID +} + +func (p *MemoryProvider) getNextGroupID() int64 { + nextID := int64(1) + for _, g := range p.dbHandle.groups { + if g.ID >= nextID { + nextID = g.ID + 1 + } + } + return nextID +} + +func (p *MemoryProvider) getNextActionID() int64 { + nextID := int64(1) + for _, a := range p.dbHandle.actions { + if a.ID >= nextID { + nextID = a.ID + 1 + } + } + return nextID +} + +func (p *MemoryProvider) getNextRuleID() int64 { + nextID := int64(1) + for _, r := range p.dbHandle.rules { + if r.ID >= nextID { + nextID = r.ID + 1 + } + } + return nextID +} + +func (p *MemoryProvider) getNextRoleID() int64 { + nextID := int64(1) + for _, r := range p.dbHandle.roles { + if r.ID >= nextID { + nextID = r.ID + 1 + } + } + return nextID +} + +func (p *MemoryProvider) clear() { + p.dbHandle.Lock() + defer p.dbHandle.Unlock() + + p.dbHandle.usernames = 
[]string{} + p.dbHandle.users = make(map[string]User) + p.dbHandle.groupnames = []string{} + p.dbHandle.groups = map[string]Group{} + p.dbHandle.vfoldersNames = []string{} + p.dbHandle.vfolders = make(map[string]vfs.BaseVirtualFolder) + p.dbHandle.admins = make(map[string]Admin) + p.dbHandle.adminsUsernames = []string{} + p.dbHandle.apiKeys = make(map[string]APIKey) + p.dbHandle.apiKeysIDs = []string{} + p.dbHandle.shares = make(map[string]Share) + p.dbHandle.sharesIDs = []string{} + p.dbHandle.actions = map[string]BaseEventAction{} + p.dbHandle.actionsNames = []string{} + p.dbHandle.rules = map[string]EventRule{} + p.dbHandle.rulesNames = []string{} + p.dbHandle.roles = map[string]Role{} + p.dbHandle.roleNames = []string{} + p.dbHandle.ipListEntries = map[string]IPListEntry{} + p.dbHandle.ipListEntriesKeys = []string{} + p.dbHandle.configs = Configs{} +} + +func (p *MemoryProvider) reloadConfig() error { + if p.dbHandle.configFile == "" { + providerLog(logger.LevelDebug, "no dump configuration file defined") + return nil + } + providerLog(logger.LevelDebug, "loading dump from file: %q", p.dbHandle.configFile) + fi, err := os.Stat(p.dbHandle.configFile) + if err != nil { + providerLog(logger.LevelError, "error loading dump: %v", err) + return err + } + if fi.Size() == 0 { + err = errors.New("dump configuration file is invalid, its size must be > 0") + providerLog(logger.LevelError, "error loading dump: %v", err) + return err + } + if fi.Size() > 20971520 { + err = errors.New("dump configuration file is invalid, its size must be <= 20971520 bytes") + providerLog(logger.LevelError, "error loading dump: %v", err) + return err + } + content, err := os.ReadFile(p.dbHandle.configFile) + if err != nil { + providerLog(logger.LevelError, "error loading dump: %v", err) + return err + } + dump, err := ParseDumpData(content) + if err != nil { + providerLog(logger.LevelError, "error loading dump: %v", err) + return err + } + return p.restoreDump(&dump) +} + +func (p 
*MemoryProvider) restoreDump(dump *BackupData) error { + p.clear() + + if err := p.restoreConfigs(dump); err != nil { + return err + } + + if err := p.restoreIPListEntries(dump); err != nil { + return err + } + + if err := p.restoreRoles(dump); err != nil { + return err + } + + if err := p.restoreFolders(dump); err != nil { + return err + } + + if err := p.restoreGroups(dump); err != nil { + return err + } + + if err := p.restoreUsers(dump); err != nil { + return err + } + + if err := p.restoreAdmins(dump); err != nil { + return err + } + + if err := p.restoreAPIKeys(dump); err != nil { + return err + } + + if err := p.restoreShares(dump); err != nil { + return err + } + + if err := p.restoreEventActions(dump); err != nil { + return err + } + + if err := p.restoreEventRules(dump); err != nil { + return err + } + + providerLog(logger.LevelDebug, "config loaded from file: %q", p.dbHandle.configFile) + return nil +} + +func (p *MemoryProvider) restoreEventActions(dump *BackupData) error { + for idx := range dump.EventActions { + action := dump.EventActions[idx] + a, err := p.eventActionExists(action.Name) + if err == nil { + action.ID = a.ID + err = UpdateEventAction(&action, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error updating event action %q: %v", action.Name, err) + return err + } + } else { + err = AddEventAction(&action, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error adding event action %q: %v", action.Name, err) + return err + } + } + } + return nil +} + +func (p *MemoryProvider) restoreEventRules(dump *BackupData) error { + for idx := range dump.EventRules { + rule := dump.EventRules[idx] + r, err := p.eventRuleExists(rule.Name) + if dump.Version < 15 { + rule.Status = 1 + } + if err == nil { + rule.ID = r.ID + err = UpdateEventRule(&rule, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error updating event rule %q: %v", rule.Name, err) + 
return err + } + } else { + err = AddEventRule(&rule, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error adding event rule %q: %v", rule.Name, err) + return err + } + } + } + return nil +} + +func (p *MemoryProvider) restoreShares(dump *BackupData) error { + for idx := range dump.Shares { + share := dump.Shares[idx] + s, err := p.shareExists(share.ShareID, "") + share.IsRestore = true + if err == nil { + share.ID = s.ID + err = UpdateShare(&share, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error updating share %q: %v", share.ShareID, err) + return err + } + } else { + err = AddShare(&share, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error adding share %q: %v", share.ShareID, err) + return err + } + } + } + return nil +} + +func (p *MemoryProvider) restoreAPIKeys(dump *BackupData) error { + for idx := range dump.APIKeys { + apiKey := dump.APIKeys[idx] + if apiKey.Key == "" { + return fmt.Errorf("cannot restore an empty API key: %+v", apiKey) + } + k, err := p.apiKeyExists(apiKey.KeyID) + if err == nil { + apiKey.ID = k.ID + err = UpdateAPIKey(&apiKey, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error updating API key %q: %v", apiKey.KeyID, err) + return err + } + } else { + err = AddAPIKey(&apiKey, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error adding API key %q: %v", apiKey.KeyID, err) + return err + } + } + } + return nil +} + +func (p *MemoryProvider) restoreAdmins(dump *BackupData) error { + for idx := range dump.Admins { + admin := dump.Admins[idx] + admin.Username = config.convertName(admin.Username) + a, err := p.adminExists(admin.Username) + if err == nil { + admin.ID = a.ID + err = UpdateAdmin(&admin, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error updating admin %q: %v", admin.Username, err) + return err + } + } else { + err = 
AddAdmin(&admin, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error adding admin %q: %v", admin.Username, err) + return err + } + } + } + return nil +} + +func (p *MemoryProvider) restoreConfigs(dump *BackupData) error { + if dump.Configs != nil && dump.Configs.UpdatedAt > 0 { + return UpdateConfigs(dump.Configs, ActionExecutorSystem, "", "") + } + return nil +} + +func (p *MemoryProvider) restoreIPListEntries(dump *BackupData) error { + for idx := range dump.IPLists { + entry := dump.IPLists[idx] + _, err := p.ipListEntryExists(entry.IPOrNet, entry.Type) + if err == nil { + err = UpdateIPListEntry(&entry, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error updating IP list entry %q: %v", entry.getName(), err) + return err + } + } else { + err = AddIPListEntry(&entry, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error adding IP list entry %q: %v", entry.getName(), err) + return err + } + } + } + return nil +} + +func (p *MemoryProvider) restoreRoles(dump *BackupData) error { + for idx := range dump.Roles { + role := dump.Roles[idx] + role.Name = config.convertName(role.Name) + r, err := p.roleExists(role.Name) + if err == nil { + role.ID = r.ID + err = UpdateRole(&role, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error updating role %q: %v", role.Name, err) + return err + } + } else { + role.Admins = nil + role.Users = nil + err = AddRole(&role, ActionExecutorSystem, "", "") + if err != nil { + providerLog(logger.LevelError, "error adding role %q: %v", role.Name, err) + return err + } + } + } + return nil +} + +func (p *MemoryProvider) restoreGroups(dump *BackupData) error { + for idx := range dump.Groups { + group := dump.Groups[idx] + group.Name = config.convertName(group.Name) + g, err := p.groupExists(group.Name) + if err == nil { + group.ID = g.ID + err = UpdateGroup(&group, g.Users, ActionExecutorSystem, "", "") + 
			if err != nil {
				providerLog(logger.LevelError, "error updating group %q: %v", group.Name, err)
				return err
			}
		} else {
			// New group: drop the user list, memberships are re-created
			// when the users themselves are restored.
			group.Users = nil
			err = AddGroup(&group, ActionExecutorSystem, "", "")
			if err != nil {
				providerLog(logger.LevelError, "error adding group %q: %v", group.Name, err)
				return err
			}
		}
	}
	return nil
}

// restoreFolders restores the virtual folders from the provided backup data:
// folders that already exist are updated in place (keeping their ID and
// current user/group mappings), missing ones are added.
func (p *MemoryProvider) restoreFolders(dump *BackupData) error {
	for idx := range dump.Folders {
		folder := dump.Folders[idx]
		folder.Name = config.convertName(folder.Name)
		f, err := p.getFolderByName(folder.Name)
		if err == nil {
			// Folder already exists: preserve its ID and update it.
			folder.ID = f.ID
			err = UpdateFolder(&folder, f.Users, f.Groups, ActionExecutorSystem, "", "")
			if err != nil {
				providerLog(logger.LevelError, "error updating folder %q: %v", folder.Name, err)
				return err
			}
		} else {
			// New folder: drop the user list, mappings are re-created when
			// the users referencing this folder are restored.
			folder.Users = nil
			err = AddFolder(&folder, ActionExecutorSystem, "", "")
			if err != nil {
				providerLog(logger.LevelError, "error adding folder %q: %v", folder.Name, err)
				return err
			}
		}
	}
	return nil
}

// restoreUsers restores the users from the provided backup data:
// existing users are updated in place, missing ones are added.
func (p *MemoryProvider) restoreUsers(dump *BackupData) error {
	for idx := range dump.Users {
		user := dump.Users[idx]
		user.Username = config.convertName(user.Username)
		u, err := p.userExists(user.Username, "")
		if err == nil {
			// User already exists: preserve its ID and update it.
			user.ID = u.ID
			err = UpdateUser(&user, ActionExecutorSystem, "", "")
			if err != nil {
				providerLog(logger.LevelError, "error updating user %q: %v", user.Username, err)
				return err
			}
		} else {
			err = AddUser(&user, ActionExecutorSystem, "", "")
			if err != nil {
				providerLog(logger.LevelError, "error adding user %q: %v", user.Username, err)
				return err
			}
		}
	}
	return nil
}

// initializeDatabase does nothing, no initialization is needed for memory provider
func (p *MemoryProvider) initializeDatabase() error {
	return ErrNoInitRequired
}

// migrateDatabase does nothing: the in-memory provider has no schema to migrate.
func (p *MemoryProvider) migrateDatabase() error {
	return ErrNoInitRequired
}

// revertDatabase always fails: there is no persistent schema to revert.
func (p *MemoryProvider) revertDatabase(_ int) error {
	return errors.New("memory provider does not store data, revert not possible")
}

// resetDatabase always fails: there is no persistent data to reset.
func (p *MemoryProvider) resetDatabase() error {
	return errors.New("memory provider does not store data, reset not possible")
}
diff --git a/internal/dataprovider/mysql.go b/internal/dataprovider/mysql.go
new file mode 100644
index 00000000..f6be87b0
--- /dev/null
+++ b/internal/dataprovider/mysql.go
@@ -0,0 +1,907 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//go:build !nomysql

package dataprovider

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/go-sql-driver/mysql"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/version"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

const (
	// mysqlResetSQL drops every SFTPGo table. The `{{name}}` placeholders are
	// replaced with the configured (possibly prefixed) table names at runtime.
	mysqlResetSQL = "DROP TABLE IF EXISTS `{{api_keys}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{users_folders_mapping}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{users_groups_mapping}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{admins_groups_mapping}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{groups_folders_mapping}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{shares_groups_mapping}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{admins}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{folders}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{shares}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{users}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{groups}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{defender_events}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{defender_hosts}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{active_transfers}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{shared_sessions}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{rules_actions_mapping}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{events_actions}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{events_rules}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{tasks}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{nodes}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{roles}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{ip_lists}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{configs}}` CASCADE;" +
		"DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;"
	// mysqlInitialSQL creates the full schema at version 33: tables first,
	// then foreign keys/unique constraints, then indexes, and finally seeds
	// the configs row and the schema_version row.
	mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" +
		"CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
		"`description` varchar(512) NULL, `password` varchar(255) NOT NULL, `email` varchar(255) NULL, `status` integer NOT NULL, " +
		"`permissions` longtext NOT NULL, `filters` longtext NULL, `additional_info` longtext NULL, `last_login` bigint NOT NULL, " +
		"`role_id` integer NULL, `created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
		"CREATE TABLE `{{active_transfers}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`connection_id` varchar(100) NOT NULL, `transfer_id` bigint NOT NULL, `transfer_type` integer NOT NULL, " +
		"`username` varchar(255) NOT NULL, `folder_name` varchar(255) NULL, `ip` varchar(50) NOT NULL, " +
		"`truncated_size` bigint NOT NULL, `current_ul_size` bigint NOT NULL, `current_dl_size` bigint NOT NULL, " +
		"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
		"CREATE TABLE `{{defender_hosts}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`ip` varchar(50) NOT NULL UNIQUE, `ban_time` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
		"CREATE TABLE `{{defender_events}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`date_time` bigint NOT NULL, `score` integer NOT NULL, `host_id` bigint NOT NULL);" +
		"ALTER TABLE `{{defender_events}}` ADD CONSTRAINT `{{prefix}}defender_events_host_id_fk_defender_hosts_id` " +
		"FOREIGN KEY (`host_id`) REFERENCES `{{defender_hosts}}` (`id`) ON DELETE CASCADE;" +
		"CREATE TABLE `{{folders}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " +
		"`description` varchar(512) NULL, `path` longtext NULL, `used_quota_size` bigint NOT NULL, " +
		"`used_quota_files` integer NOT NULL, `last_quota_update` bigint NOT NULL, `filesystem` longtext NULL);" +
		"CREATE TABLE `{{groups}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `created_at` bigint NOT NULL, " +
		"`updated_at` bigint NOT NULL, `user_settings` longtext NULL);" +
		"CREATE TABLE `{{shared_sessions}}` (`key` varchar(128) NOT NULL, `type` integer NOT NULL, `data` longtext NOT NULL, " +
		"`timestamp` bigint NOT NULL, PRIMARY KEY (`key`, `type`));" +
		"CREATE TABLE `{{users}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
		"`status` integer NOT NULL, `expiration_date` bigint NOT NULL, `description` varchar(512) NULL, `password` longtext NULL, " +
		"`public_keys` longtext NULL, `home_dir` longtext NOT NULL, `uid` bigint NOT NULL, `gid` bigint NOT NULL, " +
		"`max_sessions` integer NOT NULL, `quota_size` bigint NOT NULL, `quota_files` integer NOT NULL, " +
		"`permissions` longtext NOT NULL, `used_quota_size` bigint NOT NULL, `used_quota_files` integer NOT NULL, " +
		"`last_quota_update` bigint NOT NULL, `upload_bandwidth` integer NOT NULL, `download_bandwidth` integer NOT NULL, " +
		"`last_login` bigint NOT NULL, `filters` longtext NULL, `filesystem` longtext NULL, `additional_info` longtext NULL, " +
		"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `email` varchar(255) NULL, " +
		"`upload_data_transfer` integer NOT NULL, `download_data_transfer` integer NOT NULL, " +
		"`total_data_transfer` integer NOT NULL, `used_upload_data_transfer` bigint NOT NULL, " +
		"`used_download_data_transfer` bigint NOT NULL, `deleted_at` bigint NOT NULL, `first_download` bigint NOT NULL, " +
		"`first_upload` bigint NOT NULL, `last_password_change` bigint NOT NULL, `role_id` integer NULL);" +
		"CREATE TABLE `{{groups_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`group_id` integer NOT NULL, `folder_id` integer NOT NULL, " +
		"`virtual_path` longtext NOT NULL, `quota_size` bigint NOT NULL, `quota_files` integer NOT NULL, `sort_order` integer NOT NULL);" +
		"CREATE TABLE `{{users_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`user_id` integer NOT NULL, `group_id` integer NOT NULL, `group_type` integer NOT NULL, `sort_order` integer NOT NULL);" +
		"CREATE TABLE `{{users_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `virtual_path` longtext NOT NULL, " +
		"`quota_size` bigint NOT NULL, `quota_files` integer NOT NULL, `folder_id` integer NOT NULL, `user_id` integer NOT NULL, `sort_order` integer NOT NULL);" +
		"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_folder_mapping` " +
		"UNIQUE (`user_id`, `folder_id`);" +
		"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_user_id_fk_users_id` " +
		"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
		"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_folder_id_fk_folders_id` " +
		"FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
		"CREATE INDEX `{{prefix}}users_folders_mapping_sort_order_idx` ON `{{users_folders_mapping}}` (`sort_order`);" +
		"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_group_mapping` UNIQUE (`user_id`, `group_id`);" +
		"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_group_folder_mapping` UNIQUE (`group_id`, `folder_id`);" +
		"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_group_id_fk_groups_id` " +
		"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE NO ACTION;" +
		"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_user_id_fk_users_id` " +
		"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE; " +
		"CREATE INDEX `{{prefix}}users_groups_mapping_sort_order_idx` ON `{{users_groups_mapping}}` (`sort_order`);" +
		"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_folder_id_fk_folders_id` " +
		"FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
		"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_group_id_fk_groups_id` " +
		"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;" +
		"CREATE INDEX `{{prefix}}groups_folders_mapping_sort_order_idx` ON `{{groups_folders_mapping}}` (`sort_order`); " +
		"CREATE TABLE `{{shares}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`share_id` varchar(60) NOT NULL UNIQUE, `name` varchar(255) NOT NULL, `description` varchar(512) NULL, " +
		"`scope` integer NOT NULL, `paths` longtext NOT NULL, `created_at` bigint NOT NULL, " +
		"`updated_at` bigint NOT NULL, `last_use_at` bigint NOT NULL, `expires_at` bigint NOT NULL, " +
		"`password` longtext NULL, `max_tokens` integer NOT NULL, `used_tokens` integer NOT NULL, " +
		"`allow_from` longtext NULL, `options` longtext NULL, `user_id` integer NOT NULL);" +
		"ALTER TABLE `{{shares}}` ADD CONSTRAINT `{{prefix}}shares_user_id_fk_users_id` " +
		"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
		"CREATE TABLE `{{api_keys}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL, `key_id` varchar(50) NOT NULL UNIQUE," +
		"`api_key` varchar(255) NOT NULL UNIQUE, `scope` integer NOT NULL, `created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `last_use_at` bigint NOT NULL, " +
		"`expires_at` bigint NOT NULL, `description` longtext NULL, `admin_id` integer NULL, `user_id` integer NULL);" +
		"ALTER TABLE `{{api_keys}}` ADD CONSTRAINT `{{prefix}}api_keys_admin_id_fk_admins_id` FOREIGN KEY (`admin_id`) REFERENCES `{{admins}}` (`id`) ON DELETE CASCADE;" +
		"ALTER TABLE `{{api_keys}}` ADD CONSTRAINT `{{prefix}}api_keys_user_id_fk_users_id` FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
		"CREATE TABLE `{{events_rules}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`name` varchar(255) NOT NULL UNIQUE, `status` integer NOT NULL, `description` varchar(512) NULL, `created_at` bigint NOT NULL, " +
		"`updated_at` bigint NOT NULL, `trigger` integer NOT NULL, `conditions` longtext NOT NULL, `deleted_at` bigint NOT NULL);" +
		"CREATE TABLE `{{events_actions}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `type` integer NOT NULL, " +
		"`options` longtext NOT NULL);" +
		"CREATE TABLE `{{rules_actions_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`rule_id` integer NOT NULL, `action_id` integer NOT NULL, `order` integer NOT NULL, `options` longtext NOT NULL);" +
		"CREATE TABLE `{{tasks}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " +
		"`updated_at` bigint NOT NULL, `version` bigint NOT NULL);" +
		"ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}unique_rule_action_mapping` UNIQUE (`rule_id`, `action_id`);" +
		"ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}rules_actions_mapping_rule_id_fk_events_rules_id` " +
		"FOREIGN KEY (`rule_id`) REFERENCES `{{events_rules}}` (`id`) ON DELETE CASCADE;" +
		"ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}rules_actions_mapping_action_id_fk_events_targets_id` " +
		"FOREIGN KEY (`action_id`) REFERENCES `{{events_actions}}` (`id`) ON DELETE NO ACTION;" +
		"CREATE TABLE `{{admins_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		" `admin_id` integer NOT NULL, `group_id` integer NOT NULL, `options` longtext NOT NULL, `sort_order` integer NOT NULL);" +
		"ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_admin_group_mapping` " +
		"UNIQUE (`admin_id`, `group_id`);" +
		"ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}admins_groups_mapping_admin_id_fk_admins_id` " +
		"FOREIGN KEY (`admin_id`) REFERENCES `{{admins}}` (`id`) ON DELETE CASCADE;" +
		"ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}admins_groups_mapping_group_id_fk_groups_id` " +
		"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;" +
		"CREATE INDEX `{{prefix}}admins_groups_mapping_sort_order_idx` ON `{{admins_groups_mapping}}` (`sort_order`); " +
		"CREATE TABLE `{{nodes}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
		"`name` varchar(255) NOT NULL UNIQUE, `data` longtext NOT NULL, `created_at` bigint NOT NULL, " +
		"`updated_at` bigint NOT NULL);" +
		"CREATE TABLE `{{roles}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " +
		"`description` varchar(512) NULL, `created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
		"ALTER TABLE `{{admins}}` ADD CONSTRAINT `{{prefix}}admins_role_id_fk_roles_id` FOREIGN KEY (`role_id`) " +
		"REFERENCES `{{roles}}`(`id`) ON DELETE NO ACTION;" +
		"ALTER TABLE `{{users}}` ADD CONSTRAINT `{{prefix}}users_role_id_fk_roles_id` FOREIGN KEY (`role_id`) " +
		"REFERENCES `{{roles}}`(`id`) ON DELETE SET NULL;" +
		"CREATE TABLE `{{ip_lists}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, `type` integer NOT NULL, " +
		"`ipornet` varchar(50) NOT NULL, `mode` integer NOT NULL, `description` varchar(512) NULL, " +
		"`first` VARBINARY(16) NOT NULL, `last` VARBINARY(16) NOT NULL, `ip_type` integer NOT NULL, `protocols` integer NOT NULL, " +
		"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `deleted_at` bigint NOT NULL);" +
		"ALTER TABLE `{{ip_lists}}` ADD CONSTRAINT `{{prefix}}unique_ipornet_type_mapping` UNIQUE (`type`, `ipornet`);" +
		"CREATE TABLE `{{configs}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `configs` longtext NOT NULL);" +
		"INSERT INTO {{configs}} (configs) VALUES ('{}');" +
		"CREATE INDEX `{{prefix}}users_updated_at_idx` ON `{{users}}` (`updated_at`);" +
		"CREATE INDEX `{{prefix}}users_deleted_at_idx` ON `{{users}}` (`deleted_at`);" +
		"CREATE INDEX `{{prefix}}defender_hosts_updated_at_idx` ON `{{defender_hosts}}` (`updated_at`);" +
		"CREATE INDEX `{{prefix}}defender_hosts_ban_time_idx` ON `{{defender_hosts}}` (`ban_time`);" +
		"CREATE INDEX `{{prefix}}defender_events_date_time_idx` ON `{{defender_events}}` (`date_time`);" +
		"CREATE INDEX `{{prefix}}active_transfers_connection_id_idx` ON `{{active_transfers}}` (`connection_id`);" +
		"CREATE INDEX `{{prefix}}active_transfers_transfer_id_idx` ON `{{active_transfers}}` (`transfer_id`);" +
		"CREATE INDEX `{{prefix}}active_transfers_updated_at_idx` ON `{{active_transfers}}` (`updated_at`);" +
		"CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" +
		"CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);" +
		"CREATE INDEX `{{prefix}}events_rules_updated_at_idx` ON `{{events_rules}}` (`updated_at`);" +
		"CREATE INDEX `{{prefix}}events_rules_deleted_at_idx` ON `{{events_rules}}` (`deleted_at`);" +
		"CREATE INDEX `{{prefix}}events_rules_trigger_idx` ON `{{events_rules}}` (`trigger`);" +
		"CREATE INDEX `{{prefix}}rules_actions_mapping_order_idx` ON `{{rules_actions_mapping}}` (`order`);" +
		"CREATE INDEX `{{prefix}}ip_lists_type_idx` ON `{{ip_lists}}` (`type`);" +
		"CREATE INDEX `{{prefix}}ip_lists_ipornet_idx` ON `{{ip_lists}}` (`ipornet`);" +
		"CREATE INDEX `{{prefix}}ip_lists_ip_type_idx` ON `{{ip_lists}}` (`ip_type`);" +
		"CREATE INDEX `{{prefix}}ip_lists_updated_at_idx` ON `{{ip_lists}}` (`updated_at`);" +
		"CREATE INDEX `{{prefix}}ip_lists_deleted_at_idx` ON `{{ip_lists}}` (`deleted_at`);" +
		"CREATE INDEX `{{prefix}}ip_lists_first_last_idx` ON `{{ip_lists}}` (`first`, `last`);" +
		"INSERT INTO {{schema_version}} (version) VALUES (33);"
	// mysqlV34SQL migrates from schema version 33 to 34: it adds the
	// shares/groups mapping table with its constraints and indexes.
	mysqlV34SQL = "CREATE TABLE `{{shares_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY," +
		"`share_id` integer NOT NULL, `group_id` integer NOT NULL, `permissions` integer NOT NULL," +
		"`sort_order` integer NOT NULL," +
		"CONSTRAINT `{{prefix}}unique_share_group_mapping` UNIQUE (`share_id`, `group_id`)," +
		"CONSTRAINT `{{prefix}}shares_groups_mapping_share_id_fk` FOREIGN KEY (`share_id`) REFERENCES `{{shares}}` (`id`) ON DELETE CASCADE," +
		"CONSTRAINT `{{prefix}}shares_groups_mapping_group_id_fk` FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE); " +
		"CREATE INDEX `{{prefix}}shares_groups_mapping_sort_order_idx` ON `{{shares_groups_mapping}}` (`sort_order`); " +
		"CREATE INDEX `{{prefix}}shares_groups_mapping_share_id_idx` ON `{{shares_groups_mapping}}` (`share_id`); " +
		"CREATE INDEX `{{prefix}}shares_groups_mapping_group_id_idx` ON `{{shares_groups_mapping}}` (`group_id`);"
	// mysqlV34DownSQL reverts the version 34 migration.
	mysqlV34DownSQL = "DROP TABLE IF EXISTS `{{shares_groups_mapping}}`;"
)

// MySQLProvider defines the auth provider for MySQL/MariaDB database
type MySQLProvider struct {
	dbHandle *sql.DB // shared connection pool, configured in initializeMySQLProvider
}

func init() {
	version.AddFeature("+mysql")
}

// initializeMySQLProvider opens the connection pool, stores it in the
// package-level provider and verifies connectivity with a ping.
func initializeMySQLProvider() error {
	connString, err := getMySQLConnectionString(false)
	if err != nil {
		return err
	}
	// Build a second DSN with the password masked, used only for logging.
	redactedConnString, err := getMySQLConnectionString(true)
	if err != nil {
		return err
	}
	dbHandle, err := sql.Open("mysql", connString)
	if err != nil {
		providerLog(logger.LevelError, "error creating mysql database handler, connection string: %q, error: %v",
			redactedConnString, err)
		return err
	}
	providerLog(logger.LevelDebug, "mysql database handle created, connection string: %q, pool size: %v",
		redactedConnString, config.PoolSize)
	dbHandle.SetMaxOpenConns(config.PoolSize)
	if config.PoolSize > 0 {
		dbHandle.SetMaxIdleConns(config.PoolSize)
	} else {
		dbHandle.SetMaxIdleConns(2)
	}
	// Recycle connections before typical server-side timeouts kick in.
	dbHandle.SetConnMaxLifetime(240 * time.Second)
	dbHandle.SetConnMaxIdleTime(120 * time.Second)
	provider = &MySQLProvider{dbHandle: dbHandle}

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return dbHandle.PingContext(ctx)
}

// getMySQLConnectionString returns the DSN for the configured database.
// If redactedPwd is true the password is masked for safe logging and the
// custom TLS config is not registered.
func getMySQLConnectionString(redactedPwd bool) (string, error) {
	var connectionString string
	if config.ConnectionString == "" {
		password := config.Password
		if redactedPwd && password != "" {
			password = "[redacted]"
		}
		sslMode := getSSLMode()
		if sslMode == "custom" && !redactedPwd {
			// tls=custom in the DSN requires the named config to be
			// registered with the driver before connecting.
			if err := registerMySQLCustomTLSConfig(); err != nil {
				return "", err
			}
		}
		// The host is bracketed so IPv6 literals work in the tcp address.
		connectionString = fmt.Sprintf("%s:%s@tcp([%s]:%d)/%s?collation=utf8mb4_unicode_ci&interpolateParams=true&timeout=10s&parseTime=true&clientFoundRows=true&tls=%s&writeTimeout=60s&readTimeout=60s",
			config.Username, password, config.Host, config.Port, config.Name, sslMode)
	} else {
		connectionString = config.ConnectionString
	}
	return connectionString, nil
}

// registerMySQLCustomTLSConfig registers the "custom" TLS configuration with
// the MySQL driver, built from the configured root CA, client certificate and
// key, and SSL mode.
func registerMySQLCustomTLSConfig() error {
	tlsConfig := &tls.Config{}
	if config.RootCert != "" {
		rootCAs, err := x509.SystemCertPool()
		if err != nil {
			// Fall back to an empty pool if the system one is unavailable.
			rootCAs = x509.NewCertPool()
		}
		rootCrt, err := os.ReadFile(config.RootCert)
		if err != nil {
			return fmt.Errorf("unable to load root certificate %q: %v", config.RootCert, err)
		}
		if !rootCAs.AppendCertsFromPEM(rootCrt) {
			return fmt.Errorf("unable to parse root certificate %q", config.RootCert)
		}
		tlsConfig.RootCAs = rootCAs
	}
	if config.ClientCert != "" && config.ClientKey != "" {
		clientCert := make([]tls.Certificate, 0, 1)
		tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey)
		if err != nil {
			return fmt.Errorf("unable to load key pair %q, %q: %v", config.ClientCert, config.ClientKey, err)
		}
		clientCert = append(clientCert, tlsCert)
		tlsConfig.Certificates = clientCert
	}
	// SSL modes 2 and 3 skip server certificate verification.
	if config.SSLMode == 2 || config.SSLMode == 3 {
		tlsConfig.InsecureSkipVerify = true
	}
	// Set SNI only for network hosts (a filepath-like host means a socket).
	if !filepath.IsAbs(config.Host) && !config.DisableSNI {
		tlsConfig.ServerName = config.Host
	}
	providerLog(logger.LevelInfo, "registering custom TLS config, root cert %q, client cert %q, client key %q, disable SNI? 
%v",
		config.RootCert, config.ClientCert, config.ClientKey, config.DisableSNI)
	if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
		return fmt.Errorf("unable to register tls config: %v", err)
	}
	return nil
}

// The methods below implement the Provider interface for MySQL/MariaDB.
// Each one delegates to the shared sqlCommon* helpers, passing the pool
// handle; only provider-specific error normalization is done here.

func (p *MySQLProvider) checkAvailability() error {
	return sqlCommonCheckAvailability(p.dbHandle)
}

// --- authentication ---

func (p *MySQLProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) {
	return sqlCommonValidateUserAndPass(username, password, ip, protocol, p.dbHandle)
}

func (p *MySQLProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) {
	return sqlCommonValidateUserAndTLSCertificate(username, protocol, tlsCert, p.dbHandle)
}

func (p *MySQLProvider) validateUserAndPubKey(username string, publicKey []byte, isSSHCert bool) (User, string, error) {
	return sqlCommonValidateUserAndPubKey(username, publicKey, isSSHCert, p.dbHandle)
}

// --- quota and transfer accounting ---

func (p *MySQLProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error {
	return sqlCommonUpdateTransferQuota(username, uploadSize, downloadSize, reset, p.dbHandle)
}

func (p *MySQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error {
	return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle)
}

func (p *MySQLProvider) getUsedQuota(username string) (int, int64, int64, int64, error) {
	return sqlCommonGetUsedQuota(username, p.dbHandle)
}

func (p *MySQLProvider) getAdminSignature(username string) (string, error) {
	return sqlCommonGetAdminSignature(username, p.dbHandle)
}

func (p *MySQLProvider) getUserSignature(username string) (string, error) {
	return sqlCommonGetUserSignature(username, p.dbHandle)
}

func (p *MySQLProvider) setUpdatedAt(username string) {
	sqlCommonSetUpdatedAt(username, p.dbHandle)
}

func (p *MySQLProvider) updateLastLogin(username string) error {
	return sqlCommonUpdateLastLogin(username, p.dbHandle)
}

func (p *MySQLProvider) updateAdminLastLogin(username string) error {
	return sqlCommonUpdateAdminLastLogin(username, p.dbHandle)
}

// --- users ---

func (p *MySQLProvider) userExists(username, role string) (User, error) {
	return sqlCommonGetUserByUsername(username, role, p.dbHandle)
}

func (p *MySQLProvider) addUser(user *User) error {
	// normalizeError maps duplicate-key errors to a friendly message for the
	// conflicting field.
	return p.normalizeError(sqlCommonAddUser(user, p.dbHandle), fieldUsername)
}

func (p *MySQLProvider) updateUser(user *User) error {
	return p.normalizeError(sqlCommonUpdateUser(user, p.dbHandle), -1)
}

func (p *MySQLProvider) deleteUser(user User, softDelete bool) error {
	return sqlCommonDeleteUser(user, softDelete, p.dbHandle)
}

func (p *MySQLProvider) updateUserPassword(username, password string) error {
	return sqlCommonUpdateUserPassword(username, password, p.dbHandle)
}

func (p *MySQLProvider) dumpUsers() ([]User, error) {
	return sqlCommonDumpUsers(p.dbHandle)
}

func (p *MySQLProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
	return sqlCommonGetRecentlyUpdatedUsers(after, p.dbHandle)
}

func (p *MySQLProvider) getUsers(limit int, offset int, order, role string) ([]User, error) {
	return sqlCommonGetUsers(limit, offset, order, role, p.dbHandle)
}

func (p *MySQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
	return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}

// --- virtual folders ---

func (p *MySQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
	return sqlCommonDumpFolders(p.dbHandle)
}

func (p *MySQLProvider) getFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) {
	return sqlCommonGetFolders(limit, offset, order, minimal, p.dbHandle)
}

func (p *MySQLProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) {
	// The shared helper takes a context here; bound it with the default
	// query timeout.
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()
	return sqlCommonGetFolderByName(ctx, name, p.dbHandle)
}

func (p *MySQLProvider) addFolder(folder *vfs.BaseVirtualFolder) error {
	return p.normalizeError(sqlCommonAddFolder(folder, p.dbHandle), fieldName)
}

func (p *MySQLProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
	return sqlCommonUpdateFolder(folder, p.dbHandle)
}

func (p *MySQLProvider) deleteFolder(folder vfs.BaseVirtualFolder) error {
	return sqlCommonDeleteFolder(folder, p.dbHandle)
}

func (p *MySQLProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error {
	return sqlCommonUpdateFolderQuota(name, filesAdd, sizeAdd, reset, p.dbHandle)
}

func (p *MySQLProvider) getUsedFolderQuota(name string) (int, int64, error) {
	return sqlCommonGetFolderUsedQuota(name, p.dbHandle)
}

// --- groups ---

func (p *MySQLProvider) getGroups(limit, offset int, order string, minimal bool) ([]Group, error) {
	return sqlCommonGetGroups(limit, offset, order, minimal, p.dbHandle)
}

func (p *MySQLProvider) getGroupsWithNames(names []string) ([]Group, error) {
	return sqlCommonGetGroupsWithNames(names, p.dbHandle)
}

func (p *MySQLProvider) getUsersInGroups(names []string) ([]string, error) {
	return sqlCommonGetUsersInGroups(names, p.dbHandle)
}

func (p *MySQLProvider) groupExists(name string) (Group, error) {
	return sqlCommonGetGroupByName(name, p.dbHandle)
}

func (p *MySQLProvider) addGroup(group *Group) error {
	return p.normalizeError(sqlCommonAddGroup(group, p.dbHandle), fieldName)
}

func (p *MySQLProvider) updateGroup(group *Group) error {
	return sqlCommonUpdateGroup(group, p.dbHandle)
}

func (p *MySQLProvider) deleteGroup(group Group) error {
	return sqlCommonDeleteGroup(group, p.dbHandle)
}

func (p *MySQLProvider) dumpGroups() ([]Group, error) {
	return sqlCommonDumpGroups(p.dbHandle)
}

// --- admins ---

func (p *MySQLProvider) adminExists(username string) (Admin, error) {
	return sqlCommonGetAdminByUsername(username, p.dbHandle)
}

func (p *MySQLProvider) addAdmin(admin *Admin) error {
	return p.normalizeError(sqlCommonAddAdmin(admin, p.dbHandle), fieldUsername)
}

func (p *MySQLProvider) updateAdmin(admin *Admin) error {
	return p.normalizeError(sqlCommonUpdateAdmin(admin, p.dbHandle), -1)
}

func (p *MySQLProvider) deleteAdmin(admin Admin) error {
	return sqlCommonDeleteAdmin(admin, p.dbHandle)
}

func (p *MySQLProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) {
	return sqlCommonGetAdmins(limit, offset, order, p.dbHandle)
}

func (p *MySQLProvider) dumpAdmins() ([]Admin, error) {
	return sqlCommonDumpAdmins(p.dbHandle)
}

func (p *MySQLProvider) validateAdminAndPass(username, password, ip string) (Admin, error) {
	return sqlCommonValidateAdminAndPass(username, password, ip, p.dbHandle)
}

// --- API keys ---

func (p *MySQLProvider) apiKeyExists(keyID string) (APIKey, error) {
	return sqlCommonGetAPIKeyByID(keyID, p.dbHandle)
}

func (p *MySQLProvider) addAPIKey(apiKey *APIKey) error {
	return p.normalizeError(sqlCommonAddAPIKey(apiKey, p.dbHandle), -1)
}

func (p *MySQLProvider) updateAPIKey(apiKey *APIKey) error {
	return p.normalizeError(sqlCommonUpdateAPIKey(apiKey, p.dbHandle), -1)
}

func (p *MySQLProvider) deleteAPIKey(apiKey APIKey) error {
	return sqlCommonDeleteAPIKey(apiKey, p.dbHandle)
}

func (p *MySQLProvider) getAPIKeys(limit int, offset int, order string) ([]APIKey, error) {
	return sqlCommonGetAPIKeys(limit, offset, order, p.dbHandle)
}

func (p *MySQLProvider) dumpAPIKeys() ([]APIKey, error) {
	return sqlCommonDumpAPIKeys(p.dbHandle)
}

func (p *MySQLProvider) updateAPIKeyLastUse(keyID string) error {
	return sqlCommonUpdateAPIKeyLastUse(keyID, p.dbHandle)
}

// --- shares ---

func (p *MySQLProvider) shareExists(shareID, username string) (Share, error) {
	return sqlCommonGetShareByID(shareID, username, p.dbHandle)
}

func (p *MySQLProvider) addShare(share *Share) error {
	return p.normalizeError(sqlCommonAddShare(share, p.dbHandle), fieldName)
}

func (p *MySQLProvider) updateShare(share *Share) error {
	return p.normalizeError(sqlCommonUpdateShare(share, p.dbHandle), -1)
}

func (p *MySQLProvider) deleteShare(share Share) error {
	return sqlCommonDeleteShare(share, p.dbHandle)
}

func (p *MySQLProvider) getShares(limit int, offset int, order, username string) ([]Share, error) {
	return sqlCommonGetShares(limit, offset, order, username, p.dbHandle)
}

func (p *MySQLProvider) dumpShares() ([]Share, error) {
	return sqlCommonDumpShares(p.dbHandle)
}

func (p *MySQLProvider) updateShareLastUse(shareID string, numTokens int) error {
	return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle)
}

// --- defender (brute force protection) ---

func (p *MySQLProvider) getDefenderHosts(from int64, limit int) ([]DefenderEntry, error) {
	return sqlCommonGetDefenderHosts(from, limit, p.dbHandle)
}

func (p *MySQLProvider) getDefenderHostByIP(ip string, from int64) (DefenderEntry, error) {
	return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle)
}

func (p *MySQLProvider) isDefenderHostBanned(ip string) (DefenderEntry, error) {
	return sqlCommonIsDefenderHostBanned(ip, p.dbHandle)
}

func (p *MySQLProvider) updateDefenderBanTime(ip string, minutes int) error {
	return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle)
}

func (p *MySQLProvider) deleteDefenderHost(ip string) error {
	return sqlCommonDeleteDefenderHost(ip, p.dbHandle)
}

func (p *MySQLProvider) addDefenderEvent(ip string, score int) error {
	return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle)
}

func (p *MySQLProvider) setDefenderBanTime(ip string, banTime int64) error {
	return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle)
}

func (p *MySQLProvider) cleanupDefender(from int64) error {
	return sqlCommonDefenderCleanup(from, p.dbHandle)
}

// --- active transfers ---

func (p *MySQLProvider) addActiveTransfer(transfer ActiveTransfer) error {
	return sqlCommonAddActiveTransfer(transfer, p.dbHandle)
}

func (p *MySQLProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error {
	return sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID, connectionID, p.dbHandle)
}

func (p *MySQLProvider) removeActiveTransfer(transferID int64, connectionID string) error {
	return sqlCommonRemoveActiveTransfer(transferID, connectionID, p.dbHandle)
}

func (p *MySQLProvider) cleanupActiveTransfers(before time.Time) error {
	return sqlCommonCleanupActiveTransfers(before, p.dbHandle)
}

func (p *MySQLProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) {
	return sqlCommonGetActiveTransfers(from, p.dbHandle)
}

// --- shared sessions ---

func (p *MySQLProvider) addSharedSession(session Session) error {
	return sqlCommonAddSession(session, p.dbHandle)
}

func (p *MySQLProvider) deleteSharedSession(key string, sessionType SessionType) error {
	return sqlCommonDeleteSession(key, sessionType, p.dbHandle)
}

func (p *MySQLProvider) getSharedSession(key string, sessionType SessionType) (Session, error) {
	return sqlCommonGetSession(key, sessionType, p.dbHandle)
}

func (p *MySQLProvider) cleanupSharedSessions(sessionType SessionType, before int64) error {
	return sqlCommonCleanupSessions(sessionType, before, p.dbHandle)
}

// --- event actions and rules ---

func (p *MySQLProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
	return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle)
}

func (p *MySQLProvider) dumpEventActions() ([]BaseEventAction, error) {
	return sqlCommonDumpEventActions(p.dbHandle)
}

func (p *MySQLProvider) eventActionExists(name string) (BaseEventAction, error) {
	return sqlCommonGetEventActionByName(name, p.dbHandle)
}

func (p *MySQLProvider) addEventAction(action *BaseEventAction) error {
	return p.normalizeError(sqlCommonAddEventAction(action, p.dbHandle), fieldName)
}

func (p *MySQLProvider) updateEventAction(action *BaseEventAction) error {
	return sqlCommonUpdateEventAction(action, p.dbHandle)
}

func (p *MySQLProvider) deleteEventAction(action BaseEventAction) error {
	return sqlCommonDeleteEventAction(action, p.dbHandle)
}

func (p *MySQLProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
	return sqlCommonGetEventRules(limit, offset, order, p.dbHandle)
}

func (p *MySQLProvider) dumpEventRules() ([]EventRule, error) {
	return sqlCommonDumpEventRules(p.dbHandle)
}

func (p *MySQLProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
	return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle)
}

func (p *MySQLProvider) eventRuleExists(name string) (EventRule, error) {
	return sqlCommonGetEventRuleByName(name, p.dbHandle)
}

func (p *MySQLProvider) addEventRule(rule *EventRule) error {
	return p.normalizeError(sqlCommonAddEventRule(rule, p.dbHandle), fieldName)
}

func (p *MySQLProvider) updateEventRule(rule *EventRule) error {
	return sqlCommonUpdateEventRule(rule, p.dbHandle)
}

func (p *MySQLProvider) deleteEventRule(rule EventRule, softDelete bool) error {
	return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle)
}

// --- tasks and cluster nodes ---

func (p *MySQLProvider) getTaskByName(name string) (Task, error) {
	return sqlCommonGetTaskByName(name, p.dbHandle)
}

func (p *MySQLProvider) addTask(name string) error {
	return sqlCommonAddTask(name, p.dbHandle)
}

func (p *MySQLProvider) updateTask(name string, version int64) error {
	return sqlCommonUpdateTask(name, version, p.dbHandle)
}

func (p *MySQLProvider) updateTaskTimestamp(name string) error {
	return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
}

func (p *MySQLProvider) addNode() error {
	return sqlCommonAddNode(p.dbHandle)
}

func (p *MySQLProvider) getNodeByName(name string) (Node, error) {
	return sqlCommonGetNodeByName(name, p.dbHandle)
}

func (p *MySQLProvider) getNodes() ([]Node, error) {
	return sqlCommonGetNodes(p.dbHandle)
}

func (p *MySQLProvider) updateNodeTimestamp() error {
	return sqlCommonUpdateNodeTimestamp(p.dbHandle)
}
+ +func (p *MySQLProvider) cleanupNodes() error { + return sqlCommonCleanupNodes(p.dbHandle) +} + +func (p *MySQLProvider) roleExists(name string) (Role, error) { + return sqlCommonGetRoleByName(name, p.dbHandle) +} + +func (p *MySQLProvider) addRole(role *Role) error { + return p.normalizeError(sqlCommonAddRole(role, p.dbHandle), fieldName) +} + +func (p *MySQLProvider) updateRole(role *Role) error { + return sqlCommonUpdateRole(role, p.dbHandle) +} + +func (p *MySQLProvider) deleteRole(role Role) error { + return sqlCommonDeleteRole(role, p.dbHandle) +} + +func (p *MySQLProvider) getRoles(limit int, offset int, order string, minimal bool) ([]Role, error) { + return sqlCommonGetRoles(limit, offset, order, minimal, p.dbHandle) +} + +func (p *MySQLProvider) dumpRoles() ([]Role, error) { + return sqlCommonDumpRoles(p.dbHandle) +} + +func (p *MySQLProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { + return sqlCommonGetIPListEntry(ipOrNet, listType, p.dbHandle) +} + +func (p *MySQLProvider) addIPListEntry(entry *IPListEntry) error { + return p.normalizeError(sqlCommonAddIPListEntry(entry, p.dbHandle), fieldIPNet) +} + +func (p *MySQLProvider) updateIPListEntry(entry *IPListEntry) error { + return sqlCommonUpdateIPListEntry(entry, p.dbHandle) +} + +func (p *MySQLProvider) deleteIPListEntry(entry IPListEntry, softDelete bool) error { + return sqlCommonDeleteIPListEntry(entry, softDelete, p.dbHandle) +} + +func (p *MySQLProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { + return sqlCommonGetIPListEntries(listType, filter, from, order, limit, p.dbHandle) +} + +func (p *MySQLProvider) getRecentlyUpdatedIPListEntries(after int64) ([]IPListEntry, error) { + return sqlCommonGetRecentlyUpdatedIPListEntries(after, p.dbHandle) +} + +func (p *MySQLProvider) dumpIPListEntries() ([]IPListEntry, error) { + return sqlCommonDumpIPListEntries(p.dbHandle) +} + +func (p *MySQLProvider) 
countIPListEntries(listType IPListType) (int64, error) { + return sqlCommonCountIPListEntries(listType, p.dbHandle) +} + +func (p *MySQLProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) { + return sqlCommonGetListEntriesForIP(ip, listType, p.dbHandle) +} + +func (p *MySQLProvider) getConfigs() (Configs, error) { + return sqlCommonGetConfigs(p.dbHandle) +} + +func (p *MySQLProvider) setConfigs(configs *Configs) error { + return sqlCommonSetConfigs(configs, p.dbHandle) +} + +func (p *MySQLProvider) setFirstDownloadTimestamp(username string) error { + return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle) +} + +func (p *MySQLProvider) setFirstUploadTimestamp(username string) error { + return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle) +} + +func (p *MySQLProvider) close() error { + return p.dbHandle.Close() +} + +func (p *MySQLProvider) reloadConfig() error { + return nil +} + +// initializeDatabase creates the initial database structure +func (p *MySQLProvider) initializeDatabase() error { + dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, false) + if err == nil && dbVersion.Version > 0 { + return ErrNoInitRequired + } + if errors.Is(err, sql.ErrNoRows) { + return errSchemaVersionEmpty + } + logger.InfoToConsole("creating initial database schema, version 33") + providerLog(logger.LevelInfo, "creating initial database schema, version 33") + initialSQL := sqlReplaceAll(mysqlInitialSQL) + + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 33, true) +} + +func (p *MySQLProvider) migrateDatabase() error { + dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) + if err != nil { + return err + } + + switch version := dbVersion.Version; { + case version == sqlDatabaseVersion: + providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version) + return ErrNoInitRequired + case version < 33: + err = errSchemaVersionTooOld(version) + 
providerLog(logger.LevelError, "%v", err) + logger.ErrorToConsole("%v", err) + return err + case version == 33: + return updateMySQLDatabaseFromV33(p.dbHandle) + default: + if version > sqlDatabaseVersion { + providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, + sqlDatabaseVersion) + logger.WarnToConsole("database schema version %d is newer than the supported one: %d", version, + sqlDatabaseVersion) + return nil + } + return fmt.Errorf("database schema version not handled: %d", version) + } +} + +func (p *MySQLProvider) revertDatabase(targetVersion int) error { + dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) + if err != nil { + return err + } + if dbVersion.Version == targetVersion { + return errors.New("current version match target version, nothing to do") + } + + switch dbVersion.Version { + case 34: + return downgradeMySQLDatabaseFromV34(p.dbHandle) + default: + return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) + } +} + +func (p *MySQLProvider) resetDatabase() error { + sql := sqlReplaceAll(mysqlResetSQL) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0, false) +} + +func (p *MySQLProvider) normalizeError(err error, fieldType int) error { + if err == nil { + return nil + } + var mysqlErr *mysql.MySQLError + if errors.As(err, &mysqlErr) { + switch mysqlErr.Number { + case 1062: + var message string + switch fieldType { + case fieldUsername: + message = util.I18nErrorDuplicatedUsername + case fieldIPNet: + message = util.I18nErrorDuplicatedIPNet + default: + message = util.I18nErrorDuplicatedName + } + return util.NewI18nError( + fmt.Errorf("%w: %s", ErrDuplicatedKey, err.Error()), + message, + ) + case 1452: + return fmt.Errorf("%w: %s", ErrForeignKeyViolated, err.Error()) + } + } + return err +} + +func updateMySQLDatabaseFromV33(dbHandle *sql.DB) error { + return updateMySQLDatabaseFrom33To34(dbHandle) +} + +func 
downgradeMySQLDatabaseFromV34(dbHandle *sql.DB) error { + return downgradeMySQLDatabaseFrom34To33(dbHandle) +} + +func updateMySQLDatabaseFrom33To34(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database schema version: 33 -> 34") + providerLog(logger.LevelInfo, "updating database schema version: 33 -> 34") + + sql := strings.ReplaceAll(mysqlV34SQL, "{{prefix}}", config.SQLTablesPrefix) + sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) + sql = strings.ReplaceAll(sql, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) + sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 34, true) +} + +func downgradeMySQLDatabaseFrom34To33(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database schema version: 34 -> 33") + providerLog(logger.LevelInfo, "downgrading database schema version: 34 -> 33") + + sql := strings.ReplaceAll(mysqlV34DownSQL, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 33, false) +} diff --git a/internal/dataprovider/mysql_disabled.go b/internal/dataprovider/mysql_disabled.go new file mode 100644 index 00000000..203092b2 --- /dev/null +++ b/internal/dataprovider/mysql_disabled.go @@ -0,0 +1,31 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+
+//go:build nomysql
+
+package dataprovider
+
+import (
+	"errors"
+
+	"github.com/drakkan/sftpgo/v2/internal/version"
+)
+
+func init() {
+	// Record in the version/feature list that MySQL support is compiled out.
+	version.AddFeature("-mysql")
+}
+
+// initializeMySQLProvider is the stub used when the binary is built with the
+// "nomysql" tag: it always returns an error so the MySQL provider cannot be
+// initialized in this build.
+func initializeMySQLProvider() error {
+	return errors.New("MySQL disabled at build time")
+}
diff --git a/internal/dataprovider/node.go b/internal/dataprovider/node.go
new file mode 100644
index 00000000..4b8fc803
--- /dev/null
+++ b/internal/dataprovider/node.go
@@ -0,0 +1,280 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package dataprovider
+
+import (
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/go-jose/go-jose/v4"
+
+	"github.com/drakkan/sftpgo/v2/internal/httpclient"
+	"github.com/drakkan/sftpgo/v2/internal/jwt"
+	"github.com/drakkan/sftpgo/v2/internal/kms"
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+	"github.com/drakkan/sftpgo/v2/internal/util"
+)
+
+// Supported protocols for connecting to other nodes
+const (
+	NodeProtoHTTP  = "http"
+	NodeProtoHTTPS = "https"
+)
+
+const (
+	// NodeTokenHeader defines the header to use for the node auth token
+	NodeTokenHeader   = "X-SFTPGO-Node"
+	// nodeTokenAudience is the JWT audience claim set on, and required from,
+	// node-to-node auth tokens.
+	nodeTokenAudience = "node"
+)
+
+var (
+	// currentNode is this instance's cluster node definition; it is set by
+	// NodeConfig.validate and stays nil when clustering is not configured.
+	currentNode *Node
+	// errNoClusterNodes is returned when node auth is attempted with no
+	// cluster node defined.
+	errNoClusterNodes = errors.New("no cluster node defined")
+	// activeNodeTimeDiff is presumably the staleness window used to filter
+	// active nodes — not referenced in this file portion, verify callers.
+	activeNodeTimeDiff = -2 * time.Minute
+	// nodeReqTimeout bounds each HTTP request sent to another node.
+	nodeReqTimeout = 8 * time.Second
+)
+
+// NodeConfig defines the node configuration
+type NodeConfig struct {
+	Host  string `json:"host" mapstructure:"host"`
+	Port  int    `json:"port" mapstructure:"port"`
+	Proto string `json:"proto" mapstructure:"proto"`
+}
+
+// validate registers this instance as a cluster node: it resets currentNode,
+// does nothing unless the data provider is shared (config.IsShared == 1) and
+// a host is configured, otherwise it sets currentNode and persists it via
+// provider.addNode.
+func (n *NodeConfig) validate() error {
+	currentNode = nil
+	if config.IsShared != 1 {
+		return nil
+	}
+	if n.Host == "" {
+		return nil
+	}
+	currentNode = &Node{
+		Data: NodeData{
+			Host:  n.Host,
+			Port:  n.Port,
+			Proto: n.Proto,
+		},
+	}
+	return provider.addNode()
+}
+
+// NodeData defines the details to connect to a cluster node
+type NodeData struct {
+	Host  string      `json:"host"`
+	Port  int         `json:"port"`
+	Proto string      `json:"proto"`
+	Key   *kms.Secret `json:"api_key"`
+}
+
+// validate checks host, port range and protocol, then generates a fresh
+// random key (bound to the host via additional data) and stores it encrypted.
+// Note: the key is regenerated on every successful call.
+func (n *NodeData) validate() error {
+	if n.Host == "" {
+		return util.NewValidationError("node host is mandatory")
+	}
+	if n.Port < 0 || n.Port > 65535 {
+		return util.NewValidationError(fmt.Sprintf("invalid node port: %d", n.Port))
+	}
+	if n.Proto != NodeProtoHTTP && n.Proto != NodeProtoHTTPS {
+		return util.NewValidationError(fmt.Sprintf("invalid node proto: %s", n.Proto))
+	}
+	n.Key = kms.NewPlainSecret(util.GenerateOpaqueString())
+	n.Key.SetAdditionalData(n.Host)
+	if err := n.Key.Encrypt(); err != nil {
+		return fmt.Errorf("unable to encrypt node key: %w", err)
+	}
+	return nil
+}
+
+// getNodeName derives a stable node name: the hex-encoded SHA-256 of
+// "host:port".
+func (n *NodeData) getNodeName() string {
+	h := sha256.New()
+	var b bytes.Buffer
+
+	fmt.Fprintf(&b, "%s:%d", n.Host, n.Port)
+	h.Write(b.Bytes())
+	return hex.EncodeToString(h.Sum(nil))
+}
+
+// Node defines a cluster node
+type Node struct {
+	Name      string   `json:"name"`
+	Data      NodeData `json:"data"`
+	CreatedAt int64    `json:"created_at"`
+	UpdatedAt int64    `json:"updated_at"`
+}
+
+// validate fills in a derived name when none is set, then validates the
+// connection data (which also regenerates the node key).
+func (n *Node) validate() error {
+	if n.Name == "" {
+		n.Name = n.Data.getNodeName()
+	}
+	return n.Data.validate()
+}
+
+// authenticate verifies an HS256-signed token against this node's key and
+// returns its claims. It requires a non-empty token, an admin username in the
+// claims and the "node" audience.
+func (n *Node) authenticate(token string) (*jwt.Claims, error) {
+	if err := n.Data.Key.TryDecrypt(); err != nil {
+		providerLog(logger.LevelError, "unable to decrypt node key: %v", err)
+		return nil, err
+	}
+	if token == "" {
+		return nil, ErrInvalidCredentials
+	}
+	claims, err := jwt.VerifyTokenWithKey(token, []jose.SignatureAlgorithm{jose.HS256}, []byte(n.Data.Key.GetPayload()))
+	if err != nil {
+		return nil, fmt.Errorf("unable to parse and validate token: %v", err)
+	}
+	if claims.Username == "" {
+		return nil, errors.New("no admin username associated with node token")
+	}
+	if !claims.Audience.Contains(nodeTokenAudience) {
+		return nil, errors.New("invalid node token audience")
+	}
+
+	return claims, nil
+}
+
+// getBaseURL returns the base URL for this node
+func (n *Node) getBaseURL() string {
+	var sb strings.Builder
+	sb.WriteString(n.Data.Proto)
+	sb.WriteString("://")
+	sb.WriteString(n.Data.Host)
+	if n.Data.Port > 0 {
+		// port 0 (or negative) means "use the protocol default port"
+		sb.WriteString(":")
+		sb.WriteString(strconv.Itoa(n.Data.Port))
+	}
+	return sb.String()
+}
+
+// generateAuthToken generates a new auth token
+// (HS256, "node" audience, 1 minute expiry) signed with this node's key.
+func (n *Node) generateAuthToken(username, role string, permissions []string) (string, error) {
+	if err := n.Data.Key.TryDecrypt(); err != nil {
+		return "", fmt.Errorf("unable to decrypt node key: %w", err)
+	}
+	signer, err := jwt.NewSigner(jose.HS256, []byte(n.Data.Key.GetPayload()))
+	if err != nil {
+		return "", fmt.Errorf("unable to create signer: %w", err)
+	}
+	claims := &jwt.Claims{
+		Username:    username,
+		Role:        role,
+		Permissions: permissions,
+	}
+	claims.Audience = []string{nodeTokenAudience}
+	claims.SetExpiry(time.Now().Add(1 * time.Minute))
+	payload, err := signer.Sign(claims)
+	if err != nil {
+		return "", fmt.Errorf("unable to sign authentication token: %w", err)
+	}
+	return payload, nil
+}
+
+// prepareRequest builds an HTTP request for this node with the auth token set
+// in the NodeTokenHeader ("Bearer <token>").
+func (n *Node) prepareRequest(ctx context.Context, username, role, relativeURL, method string,
+	permissions []string, body io.Reader,
+) (*http.Request, error) {
+	url := fmt.Sprintf("%s%s", n.getBaseURL(), relativeURL)
+	req, err := http.NewRequestWithContext(ctx, method, url, body)
+	if err != nil {
+		return nil, err
+	}
+	token, err := n.generateAuthToken(username, role, permissions)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set(NodeTokenHeader, fmt.Sprintf("Bearer %s", token))
+	return req, nil
+}
+
+// SendGetRequest sends an HTTP GET request to this node.
+// The responseHolder must be a pointer.
+// The response body is read with a 10 MiB limit and decoded as JSON; any
+// status outside the 200-204 range is an error.
+func (n *Node) SendGetRequest(username, role, relativeURL string, permissions []string, responseHolder any) error {
+	ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout)
+	defer cancel()
+
+	req, err := n.prepareRequest(ctx, username, role, relativeURL, http.MethodGet, permissions, nil)
+	if err != nil {
+		return err
+	}
+	client := httpclient.GetHTTPClient()
+	defer client.CloseIdleConnections()
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return fmt.Errorf("unable to send HTTP GET to node %s: %w", n.Name, err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
+		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+	}
+	respBody, err := io.ReadAll(io.LimitReader(resp.Body, 10485760))
+	if err != nil {
+		return fmt.Errorf("unable to read response body: %w", err)
+	}
+	err = json.Unmarshal(respBody, responseHolder)
+	if err != nil {
+		return errors.New("unable to decode response as json")
+	}
+	return nil
+}
+
+// SendDeleteRequest sends an HTTP DELETE request to this node.
+// Any status outside the 200-204 range is an error; the body is ignored.
+func (n *Node) SendDeleteRequest(username, role, relativeURL string, permissions []string) error {
+	ctx, cancel := context.WithTimeout(context.Background(), nodeReqTimeout)
+	defer cancel()
+
+	req, err := n.prepareRequest(ctx, username, role, relativeURL, http.MethodDelete, permissions, nil)
+	if err != nil {
+		return err
+	}
+	client := httpclient.GetHTTPClient()
+	defer client.CloseIdleConnections()
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return fmt.Errorf("unable to send HTTP DELETE to node %s: %w", n.Name, err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusNoContent {
+		return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+	}
+	return nil
+}
+
+// AuthenticateNodeToken check the validity of the provided token
+// against the current node's key; it fails if no cluster node is defined.
+func AuthenticateNodeToken(token string) (*jwt.Claims, error) {
+	if currentNode == nil {
+		return nil, errNoClusterNodes
+	}
+	return currentNode.authenticate(token)
+}
+
+// GetNodeName returns the node name or an empty string
+// when no cluster node is defined.
+func GetNodeName() string {
+	if currentNode == nil {
+		return ""
+	}
+	return currentNode.Name
+}
diff --git a/internal/dataprovider/pgsql.go b/internal/dataprovider/pgsql.go
new file mode 100644
index 00000000..c4f7c93a
--- /dev/null
+++ b/internal/dataprovider/pgsql.go
@@ -0,0 +1,935 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+ +//go:build !nopgsql + +package dataprovider + +import ( + "context" + "crypto/x509" + "database/sql" + "errors" + "fmt" + "net" + "slices" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/stdlib" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + pgsqlResetSQL = `DROP TABLE IF EXISTS "{{api_keys}}" CASCADE; +DROP TABLE IF EXISTS "{{users_folders_mapping}}" CASCADE; +DROP TABLE IF EXISTS "{{users_groups_mapping}}" CASCADE; +DROP TABLE IF EXISTS "{{admins_groups_mapping}}" CASCADE; +DROP TABLE IF EXISTS "{{groups_folders_mapping}}" CASCADE; +DROP TABLE IF EXISTS "{{shares_groups_mapping}}" CASCADE; +DROP TABLE IF EXISTS "{{admins}}" CASCADE; +DROP TABLE IF EXISTS "{{folders}}" CASCADE; +DROP TABLE IF EXISTS "{{shares}}" CASCADE; +DROP TABLE IF EXISTS "{{users}}" CASCADE; +DROP TABLE IF EXISTS "{{groups}}" CASCADE; +DROP TABLE IF EXISTS "{{defender_events}}" CASCADE; +DROP TABLE IF EXISTS "{{defender_hosts}}" CASCADE; +DROP TABLE IF EXISTS "{{active_transfers}}" CASCADE; +DROP TABLE IF EXISTS "{{shared_sessions}}" CASCADE; +DROP TABLE IF EXISTS "{{rules_actions_mapping}}" CASCADE; +DROP TABLE IF EXISTS "{{events_actions}}" CASCADE; +DROP TABLE IF EXISTS "{{events_rules}}" CASCADE; +DROP TABLE IF EXISTS "{{tasks}}" CASCADE; +DROP TABLE IF EXISTS "{{nodes}}" CASCADE; +DROP TABLE IF EXISTS "{{roles}}" CASCADE; +DROP TABLE IF EXISTS "{{ip_lists}}" CASCADE; +DROP TABLE IF EXISTS "{{configs}}" CASCADE; +DROP TABLE IF EXISTS "{{schema_version}}" CASCADE; +` + pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "version" integer NOT NULL); +CREATE TABLE "{{admins}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "username" varchar(255) NOT NULL UNIQUE, 
+"description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL, +"permissions" text NOT NULL, "filters" text NULL, "additional_info" text NULL, "last_login" bigint NOT NULL, +"role_id" integer NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); +CREATE TABLE "{{active_transfers}}" ("id" bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "connection_id" varchar(100) NOT NULL, +"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL, +"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL, +"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL, +"updated_at" bigint NOT NULL); +CREATE TABLE "{{defender_hosts}}" ("id" bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "ip" varchar(50) NOT NULL UNIQUE, +"ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL); +CREATE TABLE "{{defender_events}}" ("id" bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "date_time" bigint NOT NULL, "score" integer NOT NULL, +"host_id" bigint NOT NULL); +ALTER TABLE "{{defender_events}}" ADD CONSTRAINT "{{prefix}}defender_events_host_id_fk_defender_hosts_id" FOREIGN KEY +("host_id") REFERENCES "{{defender_hosts}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +CREATE TABLE "{{folders}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, +"path" text NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, +"filesystem" text NULL); +CREATE TABLE "{{groups}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, +"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL); +CREATE TABLE "{{shared_sessions}}" ("key" 
varchar(128) NOT NULL, "type" integer NOT NULL, +"data" text NOT NULL, "timestamp" bigint NOT NULL, PRIMARY KEY ("key", "type")); +CREATE TABLE "{{users}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "username" varchar(255) NOT NULL UNIQUE, "status" integer NOT NULL, +"expiration_date" bigint NOT NULL, "description" varchar(512) NULL, "password" text NULL, "public_keys" text NULL, +"home_dir" text NOT NULL, "uid" bigint NOT NULL, "gid" bigint NOT NULL, "max_sessions" integer NOT NULL, +"quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, +"used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "upload_bandwidth" integer NOT NULL, +"download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL, "filesystem" text NULL, +"additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "email" varchar(255) NULL, +"upload_data_transfer" integer NOT NULL, "download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL, +"used_upload_data_transfer" bigint NOT NULL, "used_download_data_transfer" bigint NOT NULL, "deleted_at" bigint NOT NULL, +"first_download" bigint NOT NULL, "first_upload" bigint NOT NULL, "last_password_change" bigint NOT NULL, "role_id" integer NULL); +CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "group_id" integer NOT NULL, +"folder_id" integer NOT NULL, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "sort_order" integer NOT NULL); +CREATE TABLE "{{users_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "user_id" integer NOT NULL, +"group_id" integer NOT NULL, "group_type" integer NOT NULL, "sort_order" integer NOT NULL); +CREATE TABLE "{{users_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, 
"virtual_path" text NOT NULL, +"quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "sort_order" integer NOT NULL, "folder_id" integer NOT NULL, "user_id" integer NOT NULL); +ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id"); +ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}users_folders_mapping_folder_id_fk_folders_id" +FOREIGN KEY ("folder_id") REFERENCES "{{folders}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}users_folders_mapping_user_id_fk_users_id" +FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +CREATE TABLE "{{shares}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, +"share_id" varchar(60) NOT NULL UNIQUE, "name" varchar(255) NOT NULL, "description" varchar(512) NULL, +"scope" integer NOT NULL, "paths" text NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, +"last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL, "password" text NULL, +"max_tokens" integer NOT NULL, "used_tokens" integer NOT NULL, "allow_from" text NULL, "options" text NULL, +"user_id" integer NOT NULL); +ALTER TABLE "{{shares}}" ADD CONSTRAINT "{{prefix}}shares_user_id_fk_users_id" FOREIGN KEY ("user_id") +REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +CREATE TABLE "{{api_keys}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL, +"key_id" varchar(50) NOT NULL UNIQUE, "api_key" varchar(255) NOT NULL UNIQUE, "scope" integer NOT NULL, +"created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL,"expires_at" bigint NOT NULL, +"description" text NULL, "admin_id" integer NULL, "user_id" integer NULL); +ALTER TABLE "{{api_keys}}" ADD CONSTRAINT "{{prefix}}api_keys_admin_id_fk_admins_id" FOREIGN KEY 
("admin_id") +REFERENCES "{{admins}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +ALTER TABLE "{{api_keys}}" ADD CONSTRAINT "{{prefix}}api_keys_user_id_fk_users_id" FOREIGN KEY ("user_id") +REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +ALTER TABLE "{{users_groups_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id"); +ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id"); +CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id"); +ALTER TABLE "{{users_groups_mapping}}" ADD CONSTRAINT "{{prefix}}users_groups_mapping_group_id_fk_groups_id" +FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE NO ACTION; +CREATE INDEX "{{prefix}}users_groups_mapping_user_id_idx" ON "{{users_groups_mapping}}" ("user_id"); +ALTER TABLE "{{users_groups_mapping}}" ADD CONSTRAINT "{{prefix}}users_groups_mapping_user_id_fk_users_id" +FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +CREATE INDEX "{{prefix}}users_groups_mapping_sort_order_idx" ON "{{users_groups_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}groups_folders_mapping_folder_id_idx" ON "{{groups_folders_mapping}}" ("folder_id"); +ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}groups_folders_mapping_folder_id_fk_folders_id" +FOREIGN KEY ("folder_id") REFERENCES "{{folders}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folders_mapping}}" ("group_id"); +ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}groups_folders_mapping_group_id_fk_groups_id" +FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +CREATE INDEX 
"{{prefix}}groups_folders_mapping_sort_order_idx" ON "{{groups_folders_mapping}}" ("sort_order"); +CREATE TABLE "{{events_rules}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, +"status" integer NOT NULL, "description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, +"trigger" integer NOT NULL, "conditions" text NOT NULL, "deleted_at" bigint NOT NULL); +CREATE TABLE "{{events_actions}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, +"description" varchar(512) NULL, "type" integer NOT NULL, "options" text NOT NULL); +CREATE TABLE "{{rules_actions_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "rule_id" integer NOT NULL, +"action_id" integer NOT NULL, "order" integer NOT NULL, "options" text NOT NULL); +CREATE TABLE "{{tasks}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, "updated_at" bigint NOT NULL, +"version" bigint NOT NULL); +ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}unique_rule_action_mapping" UNIQUE ("rule_id", "action_id"); +ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}rules_actions_mapping_rule_id_fk_events_rules_id" +FOREIGN KEY ("rule_id") REFERENCES "{{events_rules}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}rules_actions_mapping_action_id_fk_events_targets_id" +FOREIGN KEY ("action_id") REFERENCES "{{events_actions}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE NO ACTION; +CREATE TABLE "{{admins_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, +"admin_id" integer NOT NULL, "group_id" integer NOT NULL, "options" text NOT NULL, "sort_order" integer NOT NULL); +ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}unique_admin_group_mapping" UNIQUE 
("admin_id", "group_id"); +ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}admins_groups_mapping_admin_id_fk_admins_id" +FOREIGN KEY ("admin_id") REFERENCES "{{admins}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}admins_groups_mapping_group_id_fk_groups_id" +FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; +CREATE TABLE "{{nodes}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, +"data" text NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); +CREATE TABLE "{{roles}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "name" varchar(255) NOT NULL UNIQUE, +"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); +ALTER TABLE "{{admins}}" ADD CONSTRAINT "{{prefix}}admins_role_id_fk_roles_id" FOREIGN KEY ("role_id") +REFERENCES "{{roles}}" ("id") ON DELETE NO ACTION; +ALTER TABLE "{{users}}" ADD CONSTRAINT "{{prefix}}users_role_id_fk_roles_id" FOREIGN KEY ("role_id") +REFERENCES "{{roles}}" ("id") ON DELETE SET NULL; +CREATE TABLE "{{ip_lists}}" ("id" bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "type" integer NOT NULL, +"ipornet" varchar(50) NOT NULL, "mode" integer NOT NULL, "description" varchar(512) NULL, "first" inet NOT NULL, +"last" inet NOT NULL, "ip_type" integer NOT NULL, "protocols" integer NOT NULL, "created_at" bigint NOT NULL, +"updated_at" bigint NOT NULL, "deleted_at" bigint NOT NULL); +ALTER TABLE "{{ip_lists}}" ADD CONSTRAINT "{{prefix}}unique_ipornet_type_mapping" UNIQUE ("type", "ipornet"); +CREATE TABLE "{{configs}}" ("id" integer NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY, "configs" text NOT NULL); +INSERT INTO {{configs}} (configs) VALUES ('{}'); +CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" 
("folder_id"); +CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id"); +CREATE INDEX "{{prefix}}users_folders_mapping_sort_order_idx" ON "{{users_folders_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id"); +CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id"); +CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at"); +CREATE INDEX "{{prefix}}users_deleted_at_idx" ON "{{users}}" ("deleted_at"); +CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id"); +CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at"); +CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time"); +CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time"); +CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id"); +CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id"); +CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id"); +CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at"); +CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); +CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); +CREATE INDEX "{{prefix}}events_rules_updated_at_idx" ON "{{events_rules}}" ("updated_at"); +CREATE INDEX "{{prefix}}events_rules_deleted_at_idx" ON "{{events_rules}}" ("deleted_at"); +CREATE INDEX "{{prefix}}events_rules_trigger_idx" ON "{{events_rules}}" ("trigger"); +CREATE INDEX "{{prefix}}rules_actions_mapping_rule_id_idx" ON "{{rules_actions_mapping}}" ("rule_id"); +CREATE INDEX "{{prefix}}rules_actions_mapping_action_id_idx" ON "{{rules_actions_mapping}}" ("action_id"); +CREATE INDEX 
"{{prefix}}rules_actions_mapping_order_idx" ON "{{rules_actions_mapping}}" ("order"); +CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id"); +CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_mapping}}" ("group_id"); +CREATE INDEX "{{prefix}}admins_groups_mapping_sort_order_idx" ON "{{admins_groups_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}admins_role_id_idx" ON "{{admins}}" ("role_id"); +CREATE INDEX "{{prefix}}users_role_id_idx" ON "{{users}}" ("role_id"); +CREATE INDEX "{{prefix}}ip_lists_type_idx" ON "{{ip_lists}}" ("type"); +CREATE INDEX "{{prefix}}ip_lists_ipornet_idx" ON "{{ip_lists}}" ("ipornet"); +CREATE INDEX "{{prefix}}ip_lists_updated_at_idx" ON "{{ip_lists}}" ("updated_at"); +CREATE INDEX "{{prefix}}ip_lists_deleted_at_idx" ON "{{ip_lists}}" ("deleted_at"); +CREATE INDEX "{{prefix}}ip_lists_first_last_idx" ON "{{ip_lists}}" ("first", "last"); +INSERT INTO {{schema_version}} (version) VALUES (33); +` + // not supported in CockroachDB + ipListsLikeIndex = `CREATE INDEX "{{prefix}}ip_lists_ipornet_like_idx" ON "{{ip_lists}}" ("ipornet" varchar_pattern_ops);` + pgsqlV34SQL = `CREATE TABLE "{{shares_groups_mapping}}" ( +"id" integer NOT NULL PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, +"share_id" integer NOT NULL, +"group_id" integer NOT NULL, +"permissions" integer NOT NULL, +"sort_order" integer NOT NULL, +CONSTRAINT "{{prefix}}unique_share_group_mapping" UNIQUE ("share_id", "group_id"), +CONSTRAINT "{{prefix}}shares_groups_mapping_share_id_fk" FOREIGN KEY ("share_id") REFERENCES "{{shares}}"("id") ON DELETE CASCADE, +CONSTRAINT "{{prefix}}shares_groups_mapping_group_id_fk" FOREIGN KEY ("group_id") REFERENCES "{{groups}}"("id") ON DELETE CASCADE); +CREATE INDEX "{{prefix}}shares_groups_mapping_sort_order_idx" ON "{{shares_groups_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}shares_groups_mapping_share_id_idx" ON "{{shares_groups_mapping}}" ("share_id"); 
+CREATE INDEX "{{prefix}}shares_groups_mapping_group_id_idx" ON "{{shares_groups_mapping}}" ("group_id"); +` + pgsqlV34DownSQL = `DROP TABLE IF EXISTS "{{shares_groups_mapping}}";` +) + +var ( + pgSQLTargetSessionAttrs = []string{"any", "read-write", "read-only", "primary", "standby", "prefer-standby"} +) + +// PGSQLProvider defines the auth provider for PostgreSQL database +type PGSQLProvider struct { + dbHandle *sql.DB +} + +func init() { + version.AddFeature("+pgsql") +} + +func initializePGSQLProvider() error { + var dbHandle *sql.DB + if config.TargetSessionAttrs == "any" { + pgxConfig, err := pgx.ParseConfig(getPGSQLConnectionString(false)) + if err != nil { + providerLog(logger.LevelError, "error parsing postgres configuration, connection string: %q, error: %v", + getPGSQLConnectionString(true), err) + return err + } + dbHandle = stdlib.OpenDB(*pgxConfig, stdlib.OptionBeforeConnect(stdlib.RandomizeHostOrderFunc)) + } else { + var err error + dbHandle, err = sql.Open("pgx", getPGSQLConnectionString(false)) + if err != nil { + providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v", + getPGSQLConnectionString(true), err) + return err + } + } + providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d", + getPGSQLConnectionString(true), config.PoolSize) + dbHandle.SetMaxOpenConns(config.PoolSize) + if config.PoolSize > 0 { + dbHandle.SetMaxIdleConns(config.PoolSize) + } else { + dbHandle.SetMaxIdleConns(2) + } + dbHandle.SetConnMaxLifetime(240 * time.Second) + dbHandle.SetConnMaxIdleTime(120 * time.Second) + provider = &PGSQLProvider{dbHandle: dbHandle} + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + return dbHandle.PingContext(ctx) +} + +func getPGSQLHostsAndPorts(configHost string, configPort int) (string, string) { + var hosts, ports []string + defaultPort := strconv.Itoa(configPort) + if defaultPort == 
"0" { + defaultPort = "5432" + } + + for hostport := range strings.SplitSeq(configHost, ",") { + hostport = strings.TrimSpace(hostport) + if hostport == "" { + continue + } + host, port, err := net.SplitHostPort(hostport) + if err == nil { + hosts = append(hosts, host) + ports = append(ports, port) + } else { + hosts = append(hosts, hostport) + ports = append(ports, defaultPort) + } + } + + return strings.Join(hosts, ","), strings.Join(ports, ",") +} + +func getPGSQLConnectionString(redactedPwd bool) string { + var connectionString string + if config.ConnectionString == "" { + password := config.Password + if redactedPwd && password != "" { + password = "[redacted]" + } + host, port := getPGSQLHostsAndPorts(config.Host, config.Port) + connectionString = fmt.Sprintf("host='%s' port='%s' dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10", + host, port, config.Name, config.Username, password, getSSLMode()) + if config.RootCert != "" { + connectionString += fmt.Sprintf(" sslrootcert='%s'", config.RootCert) + } + if config.ClientCert != "" && config.ClientKey != "" { + connectionString += fmt.Sprintf(" sslcert='%s' sslkey='%s'", config.ClientCert, config.ClientKey) + } + if config.DisableSNI { + connectionString += " sslsni=0" + } + if slices.Contains(pgSQLTargetSessionAttrs, config.TargetSessionAttrs) { + connectionString += fmt.Sprintf(" target_session_attrs='%s'", config.TargetSessionAttrs) + } + } else { + connectionString = config.ConnectionString + } + return connectionString +} + +func (p *PGSQLProvider) checkAvailability() error { + return sqlCommonCheckAvailability(p.dbHandle) +} + +func (p *PGSQLProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) { + return sqlCommonValidateUserAndPass(username, password, ip, protocol, p.dbHandle) +} + +func (p *PGSQLProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) { + return sqlCommonValidateUserAndTLSCertificate(username, 
protocol, tlsCert, p.dbHandle) +} + +func (p *PGSQLProvider) validateUserAndPubKey(username string, publicKey []byte, isSSHCert bool) (User, string, error) { + return sqlCommonValidateUserAndPubKey(username, publicKey, isSSHCert, p.dbHandle) +} + +func (p *PGSQLProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { + return sqlCommonUpdateTransferQuota(username, uploadSize, downloadSize, reset, p.dbHandle) +} + +func (p *PGSQLProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { + return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) +} + +func (p *PGSQLProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { + return sqlCommonGetUsedQuota(username, p.dbHandle) +} + +func (p *PGSQLProvider) getAdminSignature(username string) (string, error) { + return sqlCommonGetAdminSignature(username, p.dbHandle) +} + +func (p *PGSQLProvider) getUserSignature(username string) (string, error) { + return sqlCommonGetUserSignature(username, p.dbHandle) +} + +func (p *PGSQLProvider) setUpdatedAt(username string) { + sqlCommonSetUpdatedAt(username, p.dbHandle) +} + +func (p *PGSQLProvider) updateLastLogin(username string) error { + return sqlCommonUpdateLastLogin(username, p.dbHandle) +} + +func (p *PGSQLProvider) updateAdminLastLogin(username string) error { + return sqlCommonUpdateAdminLastLogin(username, p.dbHandle) +} + +func (p *PGSQLProvider) userExists(username, role string) (User, error) { + return sqlCommonGetUserByUsername(username, role, p.dbHandle) +} + +func (p *PGSQLProvider) addUser(user *User) error { + return p.normalizeError(sqlCommonAddUser(user, p.dbHandle), fieldUsername) +} + +func (p *PGSQLProvider) updateUser(user *User) error { + return p.normalizeError(sqlCommonUpdateUser(user, p.dbHandle), -1) +} + +func (p *PGSQLProvider) deleteUser(user User, softDelete bool) error { + return sqlCommonDeleteUser(user, softDelete, p.dbHandle) +} + +func (p 
*PGSQLProvider) updateUserPassword(username, password string) error { + return sqlCommonUpdateUserPassword(username, password, p.dbHandle) +} + +func (p *PGSQLProvider) dumpUsers() ([]User, error) { + return sqlCommonDumpUsers(p.dbHandle) +} + +func (p *PGSQLProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) { + return sqlCommonGetRecentlyUpdatedUsers(after, p.dbHandle) +} + +func (p *PGSQLProvider) getUsers(limit int, offset int, order, role string) ([]User, error) { + return sqlCommonGetUsers(limit, offset, order, role, p.dbHandle) +} + +func (p *PGSQLProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) { + return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle) +} + +func (p *PGSQLProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) { + return sqlCommonDumpFolders(p.dbHandle) +} + +func (p *PGSQLProvider) getFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) { + return sqlCommonGetFolders(limit, offset, order, minimal, p.dbHandle) +} + +func (p *PGSQLProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + return sqlCommonGetFolderByName(ctx, name, p.dbHandle) +} + +func (p *PGSQLProvider) addFolder(folder *vfs.BaseVirtualFolder) error { + return p.normalizeError(sqlCommonAddFolder(folder, p.dbHandle), fieldName) +} + +func (p *PGSQLProvider) updateFolder(folder *vfs.BaseVirtualFolder) error { + return sqlCommonUpdateFolder(folder, p.dbHandle) +} + +func (p *PGSQLProvider) deleteFolder(folder vfs.BaseVirtualFolder) error { + return sqlCommonDeleteFolder(folder, p.dbHandle) +} + +func (p *PGSQLProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error { + return sqlCommonUpdateFolderQuota(name, filesAdd, sizeAdd, reset, p.dbHandle) +} + +func (p *PGSQLProvider) getUsedFolderQuota(name string) (int, int64, error) { + return 
sqlCommonGetFolderUsedQuota(name, p.dbHandle) +} + +func (p *PGSQLProvider) getGroups(limit, offset int, order string, minimal bool) ([]Group, error) { + return sqlCommonGetGroups(limit, offset, order, minimal, p.dbHandle) +} + +func (p *PGSQLProvider) getGroupsWithNames(names []string) ([]Group, error) { + return sqlCommonGetGroupsWithNames(names, p.dbHandle) +} + +func (p *PGSQLProvider) getUsersInGroups(names []string) ([]string, error) { + return sqlCommonGetUsersInGroups(names, p.dbHandle) +} + +func (p *PGSQLProvider) groupExists(name string) (Group, error) { + return sqlCommonGetGroupByName(name, p.dbHandle) +} + +func (p *PGSQLProvider) addGroup(group *Group) error { + return p.normalizeError(sqlCommonAddGroup(group, p.dbHandle), fieldName) +} + +func (p *PGSQLProvider) updateGroup(group *Group) error { + return sqlCommonUpdateGroup(group, p.dbHandle) +} + +func (p *PGSQLProvider) deleteGroup(group Group) error { + return sqlCommonDeleteGroup(group, p.dbHandle) +} + +func (p *PGSQLProvider) dumpGroups() ([]Group, error) { + return sqlCommonDumpGroups(p.dbHandle) +} + +func (p *PGSQLProvider) adminExists(username string) (Admin, error) { + return sqlCommonGetAdminByUsername(username, p.dbHandle) +} + +func (p *PGSQLProvider) addAdmin(admin *Admin) error { + return p.normalizeError(sqlCommonAddAdmin(admin, p.dbHandle), fieldUsername) +} + +func (p *PGSQLProvider) updateAdmin(admin *Admin) error { + return p.normalizeError(sqlCommonUpdateAdmin(admin, p.dbHandle), -1) +} + +func (p *PGSQLProvider) deleteAdmin(admin Admin) error { + return sqlCommonDeleteAdmin(admin, p.dbHandle) +} + +func (p *PGSQLProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) { + return sqlCommonGetAdmins(limit, offset, order, p.dbHandle) +} + +func (p *PGSQLProvider) dumpAdmins() ([]Admin, error) { + return sqlCommonDumpAdmins(p.dbHandle) +} + +func (p *PGSQLProvider) validateAdminAndPass(username, password, ip string) (Admin, error) { + return 
sqlCommonValidateAdminAndPass(username, password, ip, p.dbHandle) +} + +func (p *PGSQLProvider) apiKeyExists(keyID string) (APIKey, error) { + return sqlCommonGetAPIKeyByID(keyID, p.dbHandle) +} + +func (p *PGSQLProvider) addAPIKey(apiKey *APIKey) error { + return p.normalizeError(sqlCommonAddAPIKey(apiKey, p.dbHandle), -1) +} + +func (p *PGSQLProvider) updateAPIKey(apiKey *APIKey) error { + return p.normalizeError(sqlCommonUpdateAPIKey(apiKey, p.dbHandle), -1) +} + +func (p *PGSQLProvider) deleteAPIKey(apiKey APIKey) error { + return sqlCommonDeleteAPIKey(apiKey, p.dbHandle) +} + +func (p *PGSQLProvider) getAPIKeys(limit int, offset int, order string) ([]APIKey, error) { + return sqlCommonGetAPIKeys(limit, offset, order, p.dbHandle) +} + +func (p *PGSQLProvider) dumpAPIKeys() ([]APIKey, error) { + return sqlCommonDumpAPIKeys(p.dbHandle) +} + +func (p *PGSQLProvider) updateAPIKeyLastUse(keyID string) error { + return sqlCommonUpdateAPIKeyLastUse(keyID, p.dbHandle) +} + +func (p *PGSQLProvider) shareExists(shareID, username string) (Share, error) { + return sqlCommonGetShareByID(shareID, username, p.dbHandle) +} + +func (p *PGSQLProvider) addShare(share *Share) error { + return p.normalizeError(sqlCommonAddShare(share, p.dbHandle), fieldName) +} + +func (p *PGSQLProvider) updateShare(share *Share) error { + return p.normalizeError(sqlCommonUpdateShare(share, p.dbHandle), -1) +} + +func (p *PGSQLProvider) deleteShare(share Share) error { + return sqlCommonDeleteShare(share, p.dbHandle) +} + +func (p *PGSQLProvider) getShares(limit int, offset int, order, username string) ([]Share, error) { + return sqlCommonGetShares(limit, offset, order, username, p.dbHandle) +} + +func (p *PGSQLProvider) dumpShares() ([]Share, error) { + return sqlCommonDumpShares(p.dbHandle) +} + +func (p *PGSQLProvider) updateShareLastUse(shareID string, numTokens int) error { + return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle) +} + +func (p *PGSQLProvider) getDefenderHosts(from 
int64, limit int) ([]DefenderEntry, error) { + return sqlCommonGetDefenderHosts(from, limit, p.dbHandle) +} + +func (p *PGSQLProvider) getDefenderHostByIP(ip string, from int64) (DefenderEntry, error) { + return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle) +} + +func (p *PGSQLProvider) isDefenderHostBanned(ip string) (DefenderEntry, error) { + return sqlCommonIsDefenderHostBanned(ip, p.dbHandle) +} + +func (p *PGSQLProvider) updateDefenderBanTime(ip string, minutes int) error { + return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle) +} + +func (p *PGSQLProvider) deleteDefenderHost(ip string) error { + return sqlCommonDeleteDefenderHost(ip, p.dbHandle) +} + +func (p *PGSQLProvider) addDefenderEvent(ip string, score int) error { + return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle) +} + +func (p *PGSQLProvider) setDefenderBanTime(ip string, banTime int64) error { + return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle) +} + +func (p *PGSQLProvider) cleanupDefender(from int64) error { + return sqlCommonDefenderCleanup(from, p.dbHandle) +} + +func (p *PGSQLProvider) addActiveTransfer(transfer ActiveTransfer) error { + return sqlCommonAddActiveTransfer(transfer, p.dbHandle) +} + +func (p *PGSQLProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error { + return sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID, connectionID, p.dbHandle) +} + +func (p *PGSQLProvider) removeActiveTransfer(transferID int64, connectionID string) error { + return sqlCommonRemoveActiveTransfer(transferID, connectionID, p.dbHandle) +} + +func (p *PGSQLProvider) cleanupActiveTransfers(before time.Time) error { + return sqlCommonCleanupActiveTransfers(before, p.dbHandle) +} + +func (p *PGSQLProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) { + return sqlCommonGetActiveTransfers(from, p.dbHandle) +} + +func (p *PGSQLProvider) addSharedSession(session Session) error { + return 
sqlCommonAddSession(session, p.dbHandle) +} + +func (p *PGSQLProvider) deleteSharedSession(key string, sessionType SessionType) error { + return sqlCommonDeleteSession(key, sessionType, p.dbHandle) +} + +func (p *PGSQLProvider) getSharedSession(key string, sessionType SessionType) (Session, error) { + return sqlCommonGetSession(key, sessionType, p.dbHandle) +} + +func (p *PGSQLProvider) cleanupSharedSessions(sessionType SessionType, before int64) error { + return sqlCommonCleanupSessions(sessionType, before, p.dbHandle) +} + +func (p *PGSQLProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) { + return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle) +} + +func (p *PGSQLProvider) dumpEventActions() ([]BaseEventAction, error) { + return sqlCommonDumpEventActions(p.dbHandle) +} + +func (p *PGSQLProvider) eventActionExists(name string) (BaseEventAction, error) { + return sqlCommonGetEventActionByName(name, p.dbHandle) +} + +func (p *PGSQLProvider) addEventAction(action *BaseEventAction) error { + return p.normalizeError(sqlCommonAddEventAction(action, p.dbHandle), fieldName) +} + +func (p *PGSQLProvider) updateEventAction(action *BaseEventAction) error { + return sqlCommonUpdateEventAction(action, p.dbHandle) +} + +func (p *PGSQLProvider) deleteEventAction(action BaseEventAction) error { + return sqlCommonDeleteEventAction(action, p.dbHandle) +} + +func (p *PGSQLProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) { + return sqlCommonGetEventRules(limit, offset, order, p.dbHandle) +} + +func (p *PGSQLProvider) dumpEventRules() ([]EventRule, error) { + return sqlCommonDumpEventRules(p.dbHandle) +} + +func (p *PGSQLProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) { + return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle) +} + +func (p *PGSQLProvider) eventRuleExists(name string) (EventRule, error) { + return sqlCommonGetEventRuleByName(name, p.dbHandle) +} + 
+func (p *PGSQLProvider) addEventRule(rule *EventRule) error { + return p.normalizeError(sqlCommonAddEventRule(rule, p.dbHandle), fieldName) +} + +func (p *PGSQLProvider) updateEventRule(rule *EventRule) error { + return sqlCommonUpdateEventRule(rule, p.dbHandle) +} + +func (p *PGSQLProvider) deleteEventRule(rule EventRule, softDelete bool) error { + return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle) +} + +func (p *PGSQLProvider) getTaskByName(name string) (Task, error) { + return sqlCommonGetTaskByName(name, p.dbHandle) +} + +func (p *PGSQLProvider) addTask(name string) error { + return sqlCommonAddTask(name, p.dbHandle) +} + +func (p *PGSQLProvider) updateTask(name string, version int64) error { + return sqlCommonUpdateTask(name, version, p.dbHandle) +} + +func (p *PGSQLProvider) updateTaskTimestamp(name string) error { + return sqlCommonUpdateTaskTimestamp(name, p.dbHandle) +} + +func (p *PGSQLProvider) addNode() error { + return sqlCommonAddNode(p.dbHandle) +} + +func (p *PGSQLProvider) getNodeByName(name string) (Node, error) { + return sqlCommonGetNodeByName(name, p.dbHandle) +} + +func (p *PGSQLProvider) getNodes() ([]Node, error) { + return sqlCommonGetNodes(p.dbHandle) +} + +func (p *PGSQLProvider) updateNodeTimestamp() error { + return sqlCommonUpdateNodeTimestamp(p.dbHandle) +} + +func (p *PGSQLProvider) cleanupNodes() error { + return sqlCommonCleanupNodes(p.dbHandle) +} + +func (p *PGSQLProvider) roleExists(name string) (Role, error) { + return sqlCommonGetRoleByName(name, p.dbHandle) +} + +func (p *PGSQLProvider) addRole(role *Role) error { + return p.normalizeError(sqlCommonAddRole(role, p.dbHandle), fieldName) +} + +func (p *PGSQLProvider) updateRole(role *Role) error { + return sqlCommonUpdateRole(role, p.dbHandle) +} + +func (p *PGSQLProvider) deleteRole(role Role) error { + return sqlCommonDeleteRole(role, p.dbHandle) +} + +func (p *PGSQLProvider) getRoles(limit int, offset int, order string, minimal bool) ([]Role, error) { + return 
sqlCommonGetRoles(limit, offset, order, minimal, p.dbHandle) +} + +func (p *PGSQLProvider) dumpRoles() ([]Role, error) { + return sqlCommonDumpRoles(p.dbHandle) +} + +func (p *PGSQLProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) { + return sqlCommonGetIPListEntry(ipOrNet, listType, p.dbHandle) +} + +func (p *PGSQLProvider) addIPListEntry(entry *IPListEntry) error { + return p.normalizeError(sqlCommonAddIPListEntry(entry, p.dbHandle), fieldIPNet) +} + +func (p *PGSQLProvider) updateIPListEntry(entry *IPListEntry) error { + return sqlCommonUpdateIPListEntry(entry, p.dbHandle) +} + +func (p *PGSQLProvider) deleteIPListEntry(entry IPListEntry, softDelete bool) error { + return sqlCommonDeleteIPListEntry(entry, softDelete, p.dbHandle) +} + +func (p *PGSQLProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) { + return sqlCommonGetIPListEntries(listType, filter, from, order, limit, p.dbHandle) +} + +func (p *PGSQLProvider) getRecentlyUpdatedIPListEntries(after int64) ([]IPListEntry, error) { + return sqlCommonGetRecentlyUpdatedIPListEntries(after, p.dbHandle) +} + +func (p *PGSQLProvider) dumpIPListEntries() ([]IPListEntry, error) { + return sqlCommonDumpIPListEntries(p.dbHandle) +} + +func (p *PGSQLProvider) countIPListEntries(listType IPListType) (int64, error) { + return sqlCommonCountIPListEntries(listType, p.dbHandle) +} + +func (p *PGSQLProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) { + return sqlCommonGetListEntriesForIP(ip, listType, p.dbHandle) +} + +func (p *PGSQLProvider) getConfigs() (Configs, error) { + return sqlCommonGetConfigs(p.dbHandle) +} + +func (p *PGSQLProvider) setConfigs(configs *Configs) error { + return sqlCommonSetConfigs(configs, p.dbHandle) +} + +func (p *PGSQLProvider) setFirstDownloadTimestamp(username string) error { + return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle) +} + +func (p 
*PGSQLProvider) setFirstUploadTimestamp(username string) error { + return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle) +} + +func (p *PGSQLProvider) close() error { + return p.dbHandle.Close() +} + +func (p *PGSQLProvider) reloadConfig() error { + return nil +} + +// initializeDatabase creates the initial database structure +func (p *PGSQLProvider) initializeDatabase() error { + dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, false) + if err == nil && dbVersion.Version > 0 { + return ErrNoInitRequired + } + if errors.Is(err, sql.ErrNoRows) { + return errSchemaVersionEmpty + } + logger.InfoToConsole("creating initial database schema, version 33") + providerLog(logger.LevelInfo, "creating initial database schema, version 33") + var initialSQL string + if config.Driver == CockroachDataProviderName { + initialSQL = sqlReplaceAll(pgsqlInitial) + initialSQL = strings.ReplaceAll(initialSQL, "GENERATED ALWAYS AS IDENTITY", "DEFAULT unordered_unique_rowid()") + } else { + initialSQL = sqlReplaceAll(pgsqlInitial + ipListsLikeIndex) + } + + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 33, true) +} + +func (p *PGSQLProvider) migrateDatabase() error { + dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) + if err != nil { + return err + } + + switch version := dbVersion.Version; { + case version == sqlDatabaseVersion: + providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version) + return ErrNoInitRequired + case version < 33: + err = errSchemaVersionTooOld(version) + providerLog(logger.LevelError, "%v", err) + logger.ErrorToConsole("%v", err) + return err + case version == 33: + return updatePGSQLDatabaseFromV33(p.dbHandle) + default: + if version > sqlDatabaseVersion { + providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version, + sqlDatabaseVersion) + logger.WarnToConsole("database schema version %d is newer than the supported one: %d", 
version, + sqlDatabaseVersion) + return nil + } + return fmt.Errorf("database schema version not handled: %d", version) + } +} + +func (p *PGSQLProvider) revertDatabase(targetVersion int) error { + dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true) + if err != nil { + return err + } + if dbVersion.Version == targetVersion { + return errors.New("current version match target version, nothing to do") + } + + switch dbVersion.Version { + case 34: + return downgradePGSQLDatabaseFromV34(p.dbHandle) + default: + return fmt.Errorf("database schema version not handled: %d", dbVersion.Version) + } +} + +func (p *PGSQLProvider) resetDatabase() error { + sql := sqlReplaceAll(pgsqlResetSQL) + return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0, false) +} + +func (p *PGSQLProvider) normalizeError(err error, fieldType int) error { + if err == nil { + return nil + } + var pgsqlErr *pgconn.PgError + if errors.As(err, &pgsqlErr) { + switch pgsqlErr.Code { + case "23505": + var message string + switch fieldType { + case fieldUsername: + message = util.I18nErrorDuplicatedUsername + case fieldIPNet: + message = util.I18nErrorDuplicatedIPNet + default: + message = util.I18nErrorDuplicatedName + } + return util.NewI18nError( + fmt.Errorf("%w: %s", ErrDuplicatedKey, err.Error()), + message, + ) + case "23503": + return fmt.Errorf("%w: %s", ErrForeignKeyViolated, err.Error()) + } + } + return err +} + +func updatePGSQLDatabaseFromV33(dbHandle *sql.DB) error { + return updatePGSQLDatabaseFrom33To34(dbHandle) +} + +func downgradePGSQLDatabaseFromV34(dbHandle *sql.DB) error { + return downgradePGSQLDatabaseFrom34To33(dbHandle) +} + +func updatePGSQLDatabaseFrom33To34(dbHandle *sql.DB) error { + logger.InfoToConsole("updating database schema version: 33 -> 34") + providerLog(logger.LevelInfo, "updating database schema version: 33 -> 34") + + sql := strings.ReplaceAll(pgsqlV34SQL, "{{prefix}}", config.SQLTablesPrefix) + sql = strings.ReplaceAll(sql, "{{shares}}", 
sqlTableShares) + sql = strings.ReplaceAll(sql, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) + sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 34, true) +} + +func downgradePGSQLDatabaseFrom34To33(dbHandle *sql.DB) error { + logger.InfoToConsole("downgrading database schema version: 34 -> 33") + providerLog(logger.LevelInfo, "downgrading database schema version: 34 -> 33") + + sql := strings.ReplaceAll(pgsqlV34DownSQL, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) + return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 33, false) +} diff --git a/internal/dataprovider/pgsql_disabled.go b/internal/dataprovider/pgsql_disabled.go new file mode 100644 index 00000000..899b5380 --- /dev/null +++ b/internal/dataprovider/pgsql_disabled.go @@ -0,0 +1,31 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +//go:build nopgsql + +package dataprovider + +import ( + "errors" + + "github.com/drakkan/sftpgo/v2/internal/version" +) + +func init() { + version.AddFeature("-pgsql") +} + +func initializePGSQLProvider() error { + return errors.New("PostgreSQL disabled at build time") +} diff --git a/internal/dataprovider/quotaupdater.go b/internal/dataprovider/quotaupdater.go new file mode 100644 index 00000000..e486bf85 --- /dev/null +++ b/internal/dataprovider/quotaupdater.go @@ -0,0 +1,261 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
package dataprovider

import (
	"sync"
	"time"

	"github.com/drakkan/sftpgo/v2/internal/logger"
)

// delayedQuotaUpdater is the package-level singleton that batches quota
// deltas in memory and flushes them to the provider at a configurable
// interval, reducing write pressure when many transfers update quotas.
var delayedQuotaUpdater quotaUpdater

func init() {
	delayedQuotaUpdater = newQuotaUpdater()
}

// quotaObject accumulates a pending size/file-count delta for a user or folder.
type quotaObject struct {
	size  int64
	files int
}

// transferQuotaObject accumulates pending upload/download byte deltas for a user.
type transferQuotaObject struct {
	ulSize int64
	dlSize int64
}

// quotaUpdater holds pending quota deltas keyed by user/folder name.
// The embedded RWMutex guards the three maps; paramsMutex guards waitTime
// separately so the flush interval can be read without blocking updates.
type quotaUpdater struct {
	paramsMutex sync.RWMutex
	waitTime    time.Duration
	sync.RWMutex
	pendingUserQuotaUpdates     map[string]quotaObject
	pendingFolderQuotaUpdates   map[string]quotaObject
	pendingTransferQuotaUpdates map[string]transferQuotaObject
}

// newQuotaUpdater returns a quotaUpdater with initialized (empty) maps.
func newQuotaUpdater() quotaUpdater {
	return quotaUpdater{
		pendingUserQuotaUpdates:     make(map[string]quotaObject),
		pendingFolderQuotaUpdates:   make(map[string]quotaObject),
		pendingTransferQuotaUpdates: make(map[string]transferQuotaObject),
	}
}

// start reads the configured interval and launches the background flush loop.
func (q *quotaUpdater) start() {
	q.setWaitTime(config.DelayedQuotaUpdate)

	go q.loop()
}

// loop periodically flushes pending quota updates. It exits when the wait
// time is set to zero or less (delayed updates disabled).
func (q *quotaUpdater) loop() {
	waitTime := q.getWaitTime()
	providerLog(logger.LevelDebug, "delayed quota update loop started, wait time: %v", waitTime)
	for waitTime > 0 {
		// We do this with a time.Sleep instead of a time.Ticker because we don't know
		// how long each quota processing cycle will take, and we want to make
		// sure we wait the configured seconds between each iteration
		time.Sleep(waitTime)
		providerLog(logger.LevelDebug, "delayed quota update check start")
		q.storeUsersQuota()
		q.storeFoldersQuota()
		q.storeUsersTransferQuota()
		providerLog(logger.LevelDebug, "delayed quota update check end")
		waitTime = q.getWaitTime()
	}
	providerLog(logger.LevelDebug, "delayed quota update loop ended, wait time: %v", waitTime)
}

// setWaitTime sets the flush interval from a value expressed in seconds.
func (q *quotaUpdater) setWaitTime(secs int) {
	q.paramsMutex.Lock()
	defer q.paramsMutex.Unlock()

	q.waitTime = time.Duration(secs) * time.Second
}

// getWaitTime returns the current flush interval.
func (q *quotaUpdater) getWaitTime() time.Duration {
	q.paramsMutex.RLock()
	defer q.paramsMutex.RUnlock()

	return q.waitTime
}

// resetUserQuota discards any pending quota delta for the given user.
func (q *quotaUpdater) resetUserQuota(username string) {
	q.Lock()
	defer q.Unlock()

	delete(q.pendingUserQuotaUpdates, username)
}

// updateUserQuota adds the given delta to the user's pending quota update.
// If the accumulated delta becomes zero the map entry is removed.
func (q *quotaUpdater) updateUserQuota(username string, files int, size int64) {
	q.Lock()
	defer q.Unlock()

	obj := q.pendingUserQuotaUpdates[username]
	obj.size += size
	obj.files += files
	if obj.files == 0 && obj.size == 0 {
		delete(q.pendingUserQuotaUpdates, username)
		return
	}
	q.pendingUserQuotaUpdates[username] = obj
}

// getUserPendingQuota returns the pending (files, size) delta for the user.
func (q *quotaUpdater) getUserPendingQuota(username string) (int, int64) {
	q.RLock()
	defer q.RUnlock()

	obj := q.pendingUserQuotaUpdates[username]

	return obj.files, obj.size
}

// resetFolderQuota discards any pending quota delta for the given folder.
func (q *quotaUpdater) resetFolderQuota(name string) {
	q.Lock()
	defer q.Unlock()

	delete(q.pendingFolderQuotaUpdates, name)
}

// updateFolderQuota adds the given delta to the folder's pending quota update.
// If the accumulated delta becomes zero the map entry is removed.
func (q *quotaUpdater) updateFolderQuota(name string, files int, size int64) {
	q.Lock()
	defer q.Unlock()

	obj := q.pendingFolderQuotaUpdates[name]
	obj.size += size
	obj.files += files
	if obj.files == 0 && obj.size == 0 {
		delete(q.pendingFolderQuotaUpdates, name)
		return
	}
	q.pendingFolderQuotaUpdates[name] = obj
}

// getFolderPendingQuota returns the pending (files, size) delta for the folder.
func (q *quotaUpdater) getFolderPendingQuota(name string) (int, int64) {
	q.RLock()
	defer q.RUnlock()

	obj := q.pendingFolderQuotaUpdates[name]

	return obj.files, obj.size
}

// resetUserTransferQuota discards any pending transfer quota delta for the user.
func (q *quotaUpdater) resetUserTransferQuota(username string) {
	q.Lock()
	defer q.Unlock()

	delete(q.pendingTransferQuotaUpdates, username)
}

// updateUserTransferQuota adds the given upload/download byte deltas to the
// user's pending transfer quota update, removing the entry when it zeroes out.
func (q *quotaUpdater) updateUserTransferQuota(username string, ulSize, dlSize int64) {
	q.Lock()
	defer q.Unlock()

	obj := q.pendingTransferQuotaUpdates[username]
	obj.ulSize += ulSize
	obj.dlSize += dlSize
	if obj.ulSize == 0 && obj.dlSize == 0 {
		delete(q.pendingTransferQuotaUpdates, username)
		return
	}
	q.pendingTransferQuotaUpdates[username] = obj
}

// getUserPendingTransferQuota returns the pending (upload, download) byte
// deltas for the user.
func (q *quotaUpdater) getUserPendingTransferQuota(username string) (int64, int64) {
	q.RLock()
	defer q.RUnlock()

	obj := q.pendingTransferQuotaUpdates[username]

	return obj.ulSize, obj.dlSize
}

// getUsernames returns a snapshot of the users with pending quota updates.
func (q *quotaUpdater) getUsernames() []string {
	q.RLock()
	defer q.RUnlock()

	result := make([]string, 0, len(q.pendingUserQuotaUpdates))
	for username := range q.pendingUserQuotaUpdates {
		result = append(result, username)
	}

	return result
}

// getFoldernames returns a snapshot of the folders with pending quota updates.
func (q *quotaUpdater) getFoldernames() []string {
	q.RLock()
	defer q.RUnlock()

	result := make([]string, 0, len(q.pendingFolderQuotaUpdates))
	for name := range q.pendingFolderQuotaUpdates {
		result = append(result, name)
	}

	return result
}

// getTransferQuotaUsernames returns a snapshot of the users with pending
// transfer quota updates.
func (q *quotaUpdater) getTransferQuotaUsernames() []string {
	q.RLock()
	defer q.RUnlock()

	result := make([]string, 0, len(q.pendingTransferQuotaUpdates))
	for username := range q.pendingTransferQuotaUpdates {
		result = append(result, username)
	}

	return result
}

// storeUsersQuota flushes pending user quota deltas to the provider.
// On success the flushed amount is subtracted from the pending delta, so any
// updates that raced in during the flush are preserved for the next cycle.
func (q *quotaUpdater) storeUsersQuota() {
	for _, username := range q.getUsernames() {
		files, size := q.getUserPendingQuota(username)
		if size != 0 || files != 0 {
			err := provider.updateQuota(username, files, size, false)
			if err != nil {
				providerLog(logger.LevelWarn, "unable to update quota delayed for user %q: %v", username, err)
				continue
			}
			q.updateUserQuota(username, -files, -size)
		}
	}
}

// storeFoldersQuota flushes pending folder quota deltas to the provider,
// using the same subtract-on-success scheme as storeUsersQuota.
func (q *quotaUpdater) storeFoldersQuota() {
	for _, name := range q.getFoldernames() {
		files, size := q.getFolderPendingQuota(name)
		if size != 0 || files != 0 {
			err := provider.updateFolderQuota(name, files, size, false)
			if err != nil {
				providerLog(logger.LevelWarn, "unable to update quota delayed for folder %q: %v", name, err)
				continue
			}
			q.updateFolderQuota(name, -files, -size)
		}
	}
}

// storeUsersTransferQuota flushes pending transfer quota deltas to the
// provider, using the same subtract-on-success scheme as storeUsersQuota.
func (q *quotaUpdater) storeUsersTransferQuota() {
	for _, username := range q.getTransferQuotaUsernames() {
		ulSize, dlSize := q.getUserPendingTransferQuota(username)
		if ulSize != 0 || dlSize != 0 {
			err := provider.updateTransferQuota(username, ulSize, dlSize, false)
			if err != nil {
				providerLog(logger.LevelWarn, "unable to update transfer quota delayed for user %q: %v", username, err)
				continue
			}
			q.updateUserTransferQuota(username, -ulSize, -dlSize)
		}
	}
}

// --- internal/dataprovider/role.go ---

// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package dataprovider

import (
	"encoding/json"
	"fmt"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)
+type Role struct { + // Data provider unique identifier + ID int64 `json:"id"` + // Role name + Name string `json:"name"` + // optional description + Description string `json:"description,omitempty"` + // Creation time as unix timestamp in milliseconds + CreatedAt int64 `json:"created_at"` + // last update time as unix timestamp in milliseconds + UpdatedAt int64 `json:"updated_at"` + // list of admins associated with this role + Admins []string `json:"admins,omitempty"` + // list of usernames associated with this role + Users []string `json:"users,omitempty"` +} + +// RenderAsJSON implements the renderer interface used within plugins +func (r *Role) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + role, err := provider.roleExists(r.Name) + if err != nil { + providerLog(logger.LevelError, "unable to reload role before rendering as json: %v", err) + return nil, err + } + return json.Marshal(role) + } + return json.Marshal(r) +} + +func (r *Role) validate() error { + if r.Name == "" { + return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired) + } + if !util.IsNameValid(r.Name) { + return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput) + } + if len(r.Name) > 255 { + return util.NewValidationError("name is too long, 255 is the maximum length allowed") + } + if config.NamingRules&1 == 0 && !usernameRegex.MatchString(r.Name) { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("name %q is not valid, the following characters are allowed: a-zA-Z0-9-_.~", r.Name)), + util.I18nErrorInvalidName, + ) + } + return nil +} + +func (r *Role) getACopy() Role { + users := make([]string, len(r.Users)) + copy(users, r.Users) + admins := make([]string, len(r.Admins)) + copy(admins, r.Admins) + + return Role{ + ID: r.ID, + Name: r.Name, + Description: r.Description, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + Users: users, + Admins: admins, + } +} diff --git a/internal/dataprovider/scheduler.go 
// --- internal/dataprovider/scheduler.go ---

// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package dataprovider

import (
	"fmt"
	"sync/atomic"
	"time"

	"github.com/robfig/cron/v3"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/metric"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

var (
	// scheduler runs the periodic provider maintenance jobs (availability
	// check, cache refresh, nodes cleanup). Guarded by start/stop ordering.
	scheduler *cron.Cron
	// timestamps (ms since epoch) of the last cache refresh cycles
	lastUserCacheUpdate    atomic.Int64
	lastIPListsCacheUpdate atomic.Int64
	// used for bolt and memory providers, so we avoid iterating all users/rules
	// to find recently modified ones
	lastUserUpdate atomic.Int64
	lastRuleUpdate atomic.Int64
)

// stopScheduler stops the cron scheduler, if running, and clears the global.
func stopScheduler() {
	if scheduler != nil {
		scheduler.Stop()
		scheduler = nil
	}
}

// startScheduler (re)creates the cron scheduler and registers the periodic
// jobs: provider availability check every 55s, cache updates every 10m and,
// in multi-node setups, stale node cleanup every 30m.
func startScheduler() error {
	stopScheduler()

	scheduler = cron.New(cron.WithLocation(time.UTC), cron.WithLogger(cron.DiscardLogger))
	_, err := scheduler.AddFunc("@every 55s", checkDataprovider)
	if err != nil {
		return fmt.Errorf("unable to schedule dataprovider availability check: %w", err)
	}
	err = addScheduledCacheUpdates()
	if err != nil {
		return err
	}
	if currentNode != nil {
		_, err = scheduler.AddFunc("@every 30m", func() {
			err := provider.cleanupNodes()
			if err != nil {
				providerLog(logger.LevelError, "unable to cleanup nodes: %v", err)
			} else {
				providerLog(logger.LevelDebug, "cleanup nodes ok")
			}
		})
	}
	// NOTE(review): this check sits outside the currentNode guard; err is nil
	// here unless the AddFunc just above failed, so behavior is equivalent to
	// checking inside the guard — confirm this placement is intentional.
	if err != nil {
		return fmt.Errorf("unable to schedule nodes cleanup: %w", err)
	}
	scheduler.Start()
	return nil
}

// addScheduledCacheUpdates initializes the cache-update timestamps and
// schedules the periodic cache refresh job.
func addScheduledCacheUpdates() error {
	lastUserCacheUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
	lastIPListsCacheUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
	_, err := scheduler.AddFunc("@every 10m", checkCacheUpdates)
	if err != nil {
		return fmt.Errorf("unable to schedule cache updates: %w", err)
	}
	return nil
}

// checkDataprovider verifies provider availability and exports the result as
// a metric. In multi-node setups it also refreshes this node's heartbeat
// timestamp instead of a plain availability check.
func checkDataprovider() {
	if currentNode != nil {
		err := provider.updateNodeTimestamp()
		if err != nil {
			providerLog(logger.LevelError, "unable to update node timestamp: %v", err)
		} else {
			providerLog(logger.LevelDebug, "node timestamp updated")
		}
		metric.UpdateDataProviderAvailability(err)
		return
	}
	err := provider.checkAvailability()
	if err != nil {
		providerLog(logger.LevelError, "check availability error: %v", err)
	}
	metric.UpdateDataProviderAvailability(err)
}

// checkCacheUpdates refreshes the in-memory caches and expires cached
// password/API-key verifications.
func checkCacheUpdates() {
	checkUserCache()
	checkIPListEntryCache()
	cachedUserPasswords.cleanup()
	cachedAdminPasswords.cleanup()
	cachedAPIKeys.cleanup()
}

// checkUserCache invalidates/refreshes cached data for users modified since
// the last check. Soft-deleted users older than 30 minutes are purged from
// the provider in a background goroutine.
func checkUserCache() {
	lastCheck := lastUserCacheUpdate.Load()
	providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastCheck))
	checkTime := util.GetTimeAsMsSinceEpoch(time.Now())
	if config.IsShared == 1 {
		// shared provider: look back an extra 5s to tolerate clock skew
		// between nodes
		lastCheck -= 5000
	}
	users, err := provider.getRecentlyUpdatedUsers(lastCheck)
	if err != nil {
		providerLog(logger.LevelError, "unable to get recently updated users: %v", err)
		return
	}
	for idx := range users {
		user := users[idx]
		providerLog(logger.LevelDebug, "invalidate caches for user %q", user.Username)
		if user.DeletedAt > 0 {
			deletedAt := util.GetTimeFromMsecSinceEpoch(user.DeletedAt)
			if deletedAt.Add(30 * time.Minute).Before(time.Now()) {
				providerLog(logger.LevelDebug, "removing user %q deleted at %s", user.Username, deletedAt)
				go provider.deleteUser(user, false) //nolint:errcheck
			}
			webDAVUsersCache.remove(user.Username)
			cachedUserPasswords.Remove(user.Username)
			delayedQuotaUpdater.resetUserQuota(user.Username)
		} else {
			webDAVUsersCache.swap(&user, "")
		}
	}
	lastUserCacheUpdate.Store(checkTime)
	providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load()))
}

// checkIPListEntryCache synchronizes the in-memory IP lists with entries
// recently changed by other nodes. It is a no-op for single-node setups or
// when no list is kept in memory.
func checkIPListEntryCache() {
	if config.IsShared != 1 {
		return
	}
	hasMemoryLists := false
	for _, l := range inMemoryLists {
		if l.isInMemory.Load() {
			hasMemoryLists = true
			break
		}
	}
	if !hasMemoryLists {
		return
	}
	providerLog(logger.LevelDebug, "start IP list cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastIPListsCacheUpdate.Load()))
	checkTime := util.GetTimeAsMsSinceEpoch(time.Now())
	// look back an extra 5s to tolerate clock skew between nodes
	entries, err := provider.getRecentlyUpdatedIPListEntries(lastIPListsCacheUpdate.Load() - 5000)
	if err != nil {
		providerLog(logger.LevelError, "unable to get recently updated IP list entries: %v", err)
		return
	}
	for idx := range entries {
		e := entries[idx]
		providerLog(logger.LevelDebug, "update cache for IP list entry %q", e.getName())
		if e.DeletedAt > 0 {
			deletedAt := util.GetTimeFromMsecSinceEpoch(e.DeletedAt)
			if deletedAt.Add(30 * time.Minute).Before(time.Now()) {
				providerLog(logger.LevelDebug, "removing IP list entry %q deleted at %s", e.getName(), deletedAt)
				go provider.deleteIPListEntry(e, false) //nolint:errcheck
			}
			for _, l := range inMemoryLists {
				l.removeEntry(&e)
			}
		} else {
			for _, l := range inMemoryLists {
				l.updateEntry(&e)
			}
		}
	}
	lastIPListsCacheUpdate.Store(checkTime)
	providerLog(logger.LevelDebug, "end IP list entries cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastIPListsCacheUpdate.Load()))
}

// setLastUserUpdate records "now" as the last user modification time.
func setLastUserUpdate() {
	lastUserUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
}
// getLastUserUpdate returns the last user modification time in ms since epoch.
func getLastUserUpdate() int64 {
	return lastUserUpdate.Load()
}

// setLastRuleUpdate records "now" as the last event-rule modification time.
func setLastRuleUpdate() {
	lastRuleUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
}

// getLastRuleUpdate returns the last event-rule modification time in ms since epoch.
func getLastRuleUpdate() int64 {
	return lastRuleUpdate.Load()
}

// --- internal/dataprovider/session.go ---

// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +package dataprovider + +import ( + "errors" + "fmt" +) + +// SessionType defines the supported session types +type SessionType int + +// Supported session types +const ( + SessionTypeOIDCAuth SessionType = iota + 1 + SessionTypeOIDCToken + SessionTypeResetCode + SessionTypeOAuth2Auth + SessionTypeInvalidToken + SessionTypeWebTask +) + +// Session defines a shared session persisted in the data provider +type Session struct { + Key string + Data any + Type SessionType + Timestamp int64 +} + +func (s *Session) validate() error { + if s.Key == "" { + return errors.New("unable to save a session with an empty key") + } + if s.Type < SessionTypeOIDCAuth || s.Type > SessionTypeWebTask { + return fmt.Errorf("invalid session type: %v", s.Type) + } + return nil +} diff --git a/internal/dataprovider/share.go b/internal/dataprovider/share.go new file mode 100644 index 00000000..6d2881b0 --- /dev/null +++ b/internal/dataprovider/share.go @@ -0,0 +1,313 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
package dataprovider

import (
	"encoding/json"
	"fmt"
	"net"
	"strings"
	"time"

	"github.com/alexedwards/argon2id"
	passwordvalidator "github.com/wagslane/go-password-validator"
	"golang.org/x/crypto/bcrypt"

	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// ShareScope defines the supported share scopes
type ShareScope int

// Supported share scopes
const (
	ShareScopeRead ShareScope = iota + 1
	ShareScopeWrite
	ShareScopeReadWrite
)

const (
	// placeholder returned instead of the real password hash when a share is
	// rendered for external consumers
	redactedPassword = "[**redacted**]"
)

// Share defines files and or directories shared with external users
type Share struct {
	// Database unique identifier
	ID int64 `json:"-"`
	// Unique ID used to access this object
	ShareID     string     `json:"id"`
	Name        string     `json:"name"`
	Description string     `json:"description,omitempty"`
	Scope       ShareScope `json:"scope"`
	// Paths to files or directories, for ShareScopeWrite it must be exactly one directory
	Paths []string `json:"paths"`
	// Username who shared this object
	Username  string `json:"username"`
	CreatedAt int64  `json:"created_at"`
	UpdatedAt int64  `json:"updated_at"`
	// 0 means never used
	LastUseAt int64 `json:"last_use_at,omitempty"`
	// ExpiresAt expiration date/time as unix timestamp in milliseconds, 0 means no expiration
	ExpiresAt int64 `json:"expires_at,omitempty"`
	// Optional password to protect the share
	Password string `json:"password"`
	// Limit the available access tokens, 0 means no limit
	MaxTokens int `json:"max_tokens,omitempty"`
	// Used tokens
	UsedTokens int `json:"used_tokens,omitempty"`
	// Limit the share availability to these IPs/CIDR networks
	AllowFrom []string `json:"allow_from,omitempty"`
	// set for restores, we don't have to validate the expiration date
	// otherwise we fail to restore existing shares and we have to insert
	// all the previous values with no modifications
	IsRestore bool `json:"-"`
}
// IsExpired returns true if the share is expired
func (s *Share) IsExpired() bool {
	if s.ExpiresAt > 0 {
		return s.ExpiresAt < util.GetTimeAsMsSinceEpoch(time.Now())
	}
	return false
}

// GetAllowedFromAsString returns the allowed IP as comma separated string
func (s *Share) GetAllowedFromAsString() string {
	return strings.Join(s.AllowFrom, ",")
}

// IsPasswordHashed returns true if the password is hashed
func (s *Share) IsPasswordHashed() bool {
	return util.IsStringPrefixInSlice(s.Password, hashPwdPrefixes)
}

// getACopy returns a copy of the share. AllowFrom is duplicated so the copy
// cannot alias the original's backing array.
// NOTE(review): Paths is copied by reference; presumably callers never mutate
// it in place — confirm.
func (s *Share) getACopy() Share {
	allowFrom := make([]string, len(s.AllowFrom))
	copy(allowFrom, s.AllowFrom)

	return Share{
		ID:          s.ID,
		ShareID:     s.ShareID,
		Name:        s.Name,
		Description: s.Description,
		Scope:       s.Scope,
		Paths:       s.Paths,
		Username:    s.Username,
		CreatedAt:   s.CreatedAt,
		UpdatedAt:   s.UpdatedAt,
		LastUseAt:   s.LastUseAt,
		ExpiresAt:   s.ExpiresAt,
		Password:    s.Password,
		MaxTokens:   s.MaxTokens,
		UsedTokens:  s.UsedTokens,
		AllowFrom:   allowFrom,
	}
}

// RenderAsJSON implements the renderer interface used within plugins.
// The password is always redacted before marshaling.
func (s *Share) RenderAsJSON(reload bool) ([]byte, error) {
	if reload {
		share, err := provider.shareExists(s.ShareID, s.Username)
		if err != nil {
			providerLog(logger.LevelError, "unable to reload share before rendering as json: %v", err)
			return nil, err
		}
		share.HideConfidentialData()
		return json.Marshal(share)
	}
	s.HideConfidentialData()
	return json.Marshal(s)
}

// HideConfidentialData hides share confidential data
func (s *Share) HideConfidentialData() {
	if s.Password != "" {
		s.Password = redactedPassword
	}
}

// HasRedactedPassword returns true if this share has a redacted password
func (s *Share) HasRedactedPassword() bool {
	return s.Password == redactedPassword
}

// hashPassword replaces a plaintext share password with its bcrypt or
// argon2id hash, per the configured algorithm, after checking the sharing
// user's minimum password entropy requirement. Already-hashed passwords and
// empty passwords are left untouched.
func (s *Share) hashPassword() error {
	if s.Password != "" && !util.IsStringPrefixInSlice(s.Password, internalHashPwdPrefixes) {
		user, err := GetUserWithGroupSettings(s.Username, "")
		if err != nil {
			return util.NewGenericError(fmt.Sprintf("unable to validate user: %v", err))
		}
		if minEntropy := user.getMinPasswordEntropy(); minEntropy > 0 {
			if err := passwordvalidator.Validate(s.Password, minEntropy); err != nil {
				return util.NewI18nError(util.NewValidationError(err.Error()), util.I18nErrorPasswordComplexity)
			}
		}
		if config.PasswordHashing.Algo == HashingAlgoBcrypt {
			hashed, err := bcrypt.GenerateFromPassword([]byte(s.Password), config.PasswordHashing.BcryptOptions.Cost)
			if err != nil {
				return err
			}
			s.Password = util.BytesToString(hashed)
		} else {
			hashed, err := argon2id.CreateHash(s.Password, argon2Params)
			if err != nil {
				return err
			}
			s.Password = hashed
		}
	}
	return nil
}

// validatePaths normalizes the shared paths (trim, clean, dedupe) and
// enforces scope constraints: write scopes need exactly one path and paths
// cannot be nested within each other.
func (s *Share) validatePaths() error {
	var paths []string
	for _, p := range s.Paths {
		if strings.TrimSpace(p) != "" {
			paths = append(paths, p)
		}
	}
	s.Paths = paths
	if len(s.Paths) == 0 {
		return util.NewI18nError(util.NewValidationError("at least a shared path is required"), util.I18nErrorSharePathRequired)
	}
	for idx := range s.Paths {
		s.Paths[idx] = util.CleanPath(s.Paths[idx])
	}
	s.Paths = util.RemoveDuplicates(s.Paths, false)
	if s.Scope >= ShareScopeWrite && len(s.Paths) != 1 {
		return util.NewI18nError(util.NewValidationError("the write share scope requires exactly one path"), util.I18nErrorShareWriteScope)
	}
	// check nested paths
	if len(s.Paths) > 1 {
		for idx := range s.Paths {
			for innerIdx := range s.Paths {
				if idx == innerIdx {
					continue
				}
				if s.Paths[idx] == "/" || s.Paths[innerIdx] == "/" || util.IsDirOverlapped(s.Paths[idx], s.Paths[innerIdx], true, "/") {
					return util.NewI18nError(util.NewGenericError("shared paths cannot be nested"), util.I18nErrorShareNestedPaths)
				}
			}
		}
	}
	return nil
}

// validate checks all share invariants (id, name, scope, paths, expiration,
// tokens, owner, allow-list) and hashes the password as a side effect.
// Expiration-in-the-past is tolerated for restores (IsRestore).
func (s *Share) validate() error { //nolint:gocyclo
	if s.ShareID == "" {
		return util.NewValidationError("share_id is mandatory")
	}
	if s.Name == "" {
		return util.NewI18nError(util.NewValidationError("name is mandatory"), util.I18nErrorNameRequired)
	}
	if !util.IsNameValid(s.Name) {
		return util.NewI18nError(errInvalidInput, util.I18nErrorInvalidInput)
	}
	if s.Scope < ShareScopeRead || s.Scope > ShareScopeReadWrite {
		return util.NewI18nError(util.NewValidationError(fmt.Sprintf("invalid scope: %v", s.Scope)), util.I18nErrorShareScope)
	}
	if err := s.validatePaths(); err != nil {
		return err
	}
	if s.ExpiresAt > 0 {
		if !s.IsRestore && s.ExpiresAt < util.GetTimeAsMsSinceEpoch(time.Now()) {
			return util.NewI18nError(util.NewValidationError("expiration must be in the future"), util.I18nErrorShareExpirationPast)
		}
	} else {
		s.ExpiresAt = 0
	}
	if s.MaxTokens < 0 {
		return util.NewI18nError(util.NewValidationError("invalid max tokens"), util.I18nErrorShareMaxTokens)
	}
	if s.Username == "" {
		return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired)
	}
	if s.HasRedactedPassword() {
		return util.NewValidationError("cannot save a share with a redacted password")
	}
	if err := s.hashPassword(); err != nil {
		return err
	}
	s.AllowFrom = util.RemoveDuplicates(s.AllowFrom, false)
	for _, IPMask := range s.AllowFrom {
		_, _, err := net.ParseCIDR(IPMask)
		if err != nil {
			return util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("could not parse allow from entry %q : %v", IPMask, err)),
				util.I18nErrorInvalidIPMask,
			)
		}
	}
	return nil
}

// CheckCredentials verifies the share credentials if a password if set
func (s *Share) CheckCredentials(password string) (bool, error) {
	if s.Password == "" {
		return true, nil
	}
	if password == "" {
		return false, ErrInvalidCredentials
	}
	if strings.HasPrefix(s.Password, bcryptPwdPrefix) {
		if err := bcrypt.CompareHashAndPassword([]byte(s.Password), []byte(password)); err != nil {
			return false, ErrInvalidCredentials
		}
		return true, nil
	}
	// fall back to argon2id for non-bcrypt hashes
	match, err := argon2id.ComparePasswordAndHash(password, s.Password)
	if !match || err != nil {
		return false, ErrInvalidCredentials
	}
	return match, err
}

// GetRelativePath returns the specified absolute path as relative to the share base path
func (s *Share) GetRelativePath(name string) string {
	if len(s.Paths) == 0 {
		return ""
	}
	return util.CleanPath(strings.TrimPrefix(name, s.Paths[0]))
}

// IsUsable checks if the share is usable from the specified IP
func (s *Share) IsUsable(ip string) (bool, error) {
	if s.MaxTokens > 0 && s.UsedTokens >= s.MaxTokens {
		return false, util.NewI18nError(util.NewRecordNotFoundError("max share usage exceeded"), util.I18nErrorShareUsage)
	}
	if s.ExpiresAt > 0 {
		if s.ExpiresAt < util.GetTimeAsMsSinceEpoch(time.Now()) {
			return false, util.NewI18nError(util.NewRecordNotFoundError("share expired"), util.I18nErrorShareExpired)
		}
	}
	if len(s.AllowFrom) == 0 {
		return true, nil
	}
	parsedIP := net.ParseIP(ip)
	if parsedIP == nil {
		return false, util.NewI18nError(ErrLoginNotAllowedFromIP, util.I18nErrorLoginFromIPDenied)
	}
	for _, ipMask := range s.AllowFrom {
		_, network, err := net.ParseCIDR(ipMask)
		if err != nil {
			// invalid entries are skipped here; validate() rejects them at save time
			continue
		}
		if network.Contains(parsedIP) {
			return true, nil
		}
	}
	return false, util.NewI18nError(ErrLoginNotAllowedFromIP, util.I18nErrorLoginFromIPDenied)
}

// --- internal/dataprovider/sqlcommon.go ---

// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package dataprovider + +import ( + "context" + "crypto/x509" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/netip" + "runtime/debug" + "strconv" + "strings" + "time" + + "github.com/cockroachdb/cockroach-go/v2/crdb" + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + sqlDatabaseVersion = 34 + defaultSQLQueryTimeout = 10 * time.Second + longSQLQueryTimeout = 60 * time.Second +) + +var ( + errSQLFoldersAssociation = errors.New("unable to associate virtual folders to user") + errSQLGroupsAssociation = errors.New("unable to associate groups to user") + errSQLUsersAssociation = errors.New("unable to associate users to group") + errSchemaVersionEmpty = errors.New("we can't determine schema version because the schema_migration table is empty. The SFTPGo database might be corrupted. 
Consider using the \"resetprovider\" sub-command") +) + +type sqlQuerier interface { + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) +} + +type sqlScanner interface { + Scan(dest ...any) error +} + +func sqlReplaceAll(sql string) string { + sql = strings.ReplaceAll(sql, "{{schema_version}}", sqlTableSchemaVersion) + sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins) + sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders) + sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers) + sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups) + sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping) + sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping) + sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping) + sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping) + sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys) + sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares) + sql = strings.ReplaceAll(sql, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping) + sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents) + sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts) + sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers) + sql = strings.ReplaceAll(sql, "{{shared_sessions}}", sqlTableSharedSessions) + sql = strings.ReplaceAll(sql, "{{events_actions}}", sqlTableEventsActions) + sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules) + sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping) + sql = strings.ReplaceAll(sql, 
"{{tasks}}", sqlTableTasks) + sql = strings.ReplaceAll(sql, "{{nodes}}", sqlTableNodes) + sql = strings.ReplaceAll(sql, "{{roles}}", sqlTableRoles) + sql = strings.ReplaceAll(sql, "{{ip_lists}}", sqlTableIPLists) + sql = strings.ReplaceAll(sql, "{{configs}}", sqlTableConfigs) + sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix) + return sql +} + +func sqlCommonGetShareByID(shareID, username string, dbHandle sqlQuerier) (Share, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + filterUser := username != "" + q := getShareByIDQuery(filterUser) + + var row *sql.Row + if filterUser { + row = dbHandle.QueryRowContext(ctx, q, shareID, username) + } else { + row = dbHandle.QueryRowContext(ctx, q, shareID) + } + + return getShareFromDbRow(row) +} + +func sqlCommonAddShare(share *Share, dbHandle *sql.DB) error { + err := share.validate() + if err != nil { + return err + } + + user, err := provider.userExists(share.Username, "") + if err != nil { + return util.NewGenericError(fmt.Sprintf("unable to validate user %q", share.Username)) + } + + paths, err := json.Marshal(share.Paths) + if err != nil { + return err + } + var allowFrom []byte + if len(share.AllowFrom) > 0 { + res, err := json.Marshal(share.AllowFrom) + if err == nil { + allowFrom = res + } + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAddShareQuery() + usedTokens := 0 + createdAt := util.GetTimeAsMsSinceEpoch(time.Now()) + updatedAt := createdAt + lastUseAt := int64(0) + if share.IsRestore { + usedTokens = share.UsedTokens + if share.CreatedAt > 0 { + createdAt = share.CreatedAt + } + if share.UpdatedAt > 0 { + updatedAt = share.UpdatedAt + } + lastUseAt = share.LastUseAt + } + _, err = dbHandle.ExecContext(ctx, q, share.ShareID, share.Name, share.Description, share.Scope, + paths, createdAt, updatedAt, lastUseAt, share.ExpiresAt, share.Password, + share.MaxTokens, 
usedTokens, allowFrom, user.ID) + return err +} + +func sqlCommonUpdateShare(share *Share, dbHandle *sql.DB) error { + err := share.validate() + if err != nil { + return err + } + + paths, err := json.Marshal(share.Paths) + if err != nil { + return err + } + + var allowFrom []byte + if len(share.AllowFrom) > 0 { + res, err := json.Marshal(share.AllowFrom) + if err == nil { + allowFrom = res + } + } + + user, err := provider.userExists(share.Username, "") + if err != nil { + return util.NewGenericError(fmt.Sprintf("unable to validate user %q", share.Username)) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + var q string + if share.IsRestore { + q = getUpdateShareRestoreQuery() + } else { + q = getUpdateShareQuery() + } + + var res sql.Result + if share.IsRestore { + if share.CreatedAt == 0 { + share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + } + if share.UpdatedAt == 0 { + share.UpdatedAt = share.CreatedAt + } + res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths, + share.CreatedAt, share.UpdatedAt, share.LastUseAt, share.ExpiresAt, share.Password, share.MaxTokens, + share.UsedTokens, allowFrom, user.ID, share.ShareID) + } else { + res, err = dbHandle.ExecContext(ctx, q, share.Name, share.Description, share.Scope, paths, + util.GetTimeAsMsSinceEpoch(time.Now()), share.ExpiresAt, share.Password, share.MaxTokens, + allowFrom, user.ID, share.ShareID) + } + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonDeleteShare(share Share, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteShareQuery() + res, err := dbHandle.ExecContext(ctx, q, share.ShareID) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonGetShares(limit, offset int, order, username string, dbHandle sqlQuerier) 
([]Share, error) { + shares := make([]Share, 0, limit) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getSharesQuery(order) + rows, err := dbHandle.QueryContext(ctx, q, username, limit, offset) + if err != nil { + return shares, err + } + defer rows.Close() + + for rows.Next() { + s, err := getShareFromDbRow(rows) + if err != nil { + return shares, err + } + s.HideConfidentialData() + shares = append(shares, s) + } + + return shares, rows.Err() +} + +func sqlCommonDumpShares(dbHandle sqlQuerier) ([]Share, error) { + shares := make([]Share, 0, 30) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDumpSharesQuery() + rows, err := dbHandle.QueryContext(ctx, q) + if err != nil { + return shares, err + } + defer rows.Close() + + for rows.Next() { + s, err := getShareFromDbRow(rows) + if err != nil { + return shares, err + } + shares = append(shares, s) + } + + return shares, rows.Err() +} + +func sqlCommonGetAPIKeyByID(keyID string, dbHandle sqlQuerier) (APIKey, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAPIKeyByIDQuery() + row := dbHandle.QueryRowContext(ctx, q, keyID) + + apiKey, err := getAPIKeyFromDbRow(row) + if err != nil { + return apiKey, err + } + return getAPIKeyWithRelatedFields(ctx, apiKey, dbHandle) +} + +func sqlCommonAddAPIKey(apiKey *APIKey, dbHandle *sql.DB) error { + err := apiKey.validate() + if err != nil { + return err + } + + userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAddAPIKeyQuery() + _, err = dbHandle.ExecContext(ctx, q, apiKey.KeyID, apiKey.Name, apiKey.Key, apiKey.Scope, + util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.LastUseAt, + 
apiKey.ExpiresAt, apiKey.Description, userID, adminID) + return err +} + +func sqlCommonUpdateAPIKey(apiKey *APIKey, dbHandle *sql.DB) error { + err := apiKey.validate() + if err != nil { + return err + } + + userID, adminID, err := sqlCommonGetAPIKeyRelatedIDs(apiKey) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUpdateAPIKeyQuery() + res, err := dbHandle.ExecContext(ctx, q, apiKey.Name, apiKey.Scope, apiKey.ExpiresAt, userID, adminID, + apiKey.Description, util.GetTimeAsMsSinceEpoch(time.Now()), apiKey.KeyID) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonDeleteAPIKey(apiKey APIKey, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteAPIKeyQuery() + res, err := dbHandle.ExecContext(ctx, q, apiKey.KeyID) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonGetAPIKeys(limit, offset int, order string, dbHandle sqlQuerier) ([]APIKey, error) { + apiKeys := make([]APIKey, 0, limit) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAPIKeysQuery(order) + rows, err := dbHandle.QueryContext(ctx, q, limit, offset) + if err != nil { + return apiKeys, err + } + defer rows.Close() + + for rows.Next() { + k, err := getAPIKeyFromDbRow(rows) + if err != nil { + return apiKeys, err + } + k.HideConfidentialData() + apiKeys = append(apiKeys, k) + } + err = rows.Err() + if err != nil { + return apiKeys, err + } + apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin) + if err != nil { + return apiKeys, err + } + + return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser) +} + +func sqlCommonDumpAPIKeys(dbHandle sqlQuerier) ([]APIKey, error) { + apiKeys := make([]APIKey, 0, 30) + ctx, cancel := 
context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDumpAPIKeysQuery() + rows, err := dbHandle.QueryContext(ctx, q) + if err != nil { + return apiKeys, err + } + defer rows.Close() + + for rows.Next() { + k, err := getAPIKeyFromDbRow(rows) + if err != nil { + return apiKeys, err + } + apiKeys = append(apiKeys, k) + } + err = rows.Err() + if err != nil { + return apiKeys, err + } + apiKeys, err = getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeAdmin) + if err != nil { + return apiKeys, err + } + + return getRelatedValuesForAPIKeys(ctx, apiKeys, dbHandle, APIKeyScopeUser) +} + +func sqlCommonGetAdminByUsername(username string, dbHandle sqlQuerier) (Admin, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAdminByUsernameQuery() + row := dbHandle.QueryRowContext(ctx, q, username) + + admin, err := getAdminFromDbRow(row) + if err != nil { + return admin, err + } + return getAdminWithGroups(ctx, admin, dbHandle) +} + +func sqlCommonValidateAdminAndPass(username, password, ip string, dbHandle *sql.DB) (Admin, error) { + admin, err := sqlCommonGetAdminByUsername(username, dbHandle) + if err != nil { + providerLog(logger.LevelWarn, "error authenticating admin %q: %v", username, err) + return admin, err + } + err = admin.checkUserAndPass(password, ip) + return admin, err +} + +func sqlCommonAddAdmin(admin *Admin, dbHandle *sql.DB) error { + err := admin.validate() + if err != nil { + return err + } + + perms, err := json.Marshal(admin.Permissions) + if err != nil { + return err + } + + filters, err := json.Marshal(admin.Filters) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { + q := getAddAdminQuery(admin.Role) + _, err = tx.ExecContext(ctx, q, admin.Username, admin.Password, admin.Status, 
admin.Email, perms, + filters, admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), + util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role) + if err != nil { + return err + } + return generateAdminGroupMapping(ctx, admin, tx) + }) +} + +func sqlCommonUpdateAdmin(admin *Admin, dbHandle *sql.DB) error { + err := admin.validate() + if err != nil { + return err + } + + perms, err := json.Marshal(admin.Permissions) + if err != nil { + return err + } + + filters, err := json.Marshal(admin.Filters) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { + q := getUpdateAdminQuery(admin.Role) + _, err = tx.ExecContext(ctx, q, admin.Password, admin.Status, admin.Email, perms, filters, + admin.AdditionalInfo, admin.Description, util.GetTimeAsMsSinceEpoch(time.Now()), admin.Role, admin.Username) + if err != nil { + return err + } + return generateAdminGroupMapping(ctx, admin, tx) + }) +} + +func sqlCommonDeleteAdmin(admin Admin, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteAdminQuery() + res, err := dbHandle.ExecContext(ctx, q, admin.Username) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonGetAdmins(limit, offset int, order string, dbHandle sqlQuerier) ([]Admin, error) { + admins := make([]Admin, 0, limit) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAdminsQuery(order) + rows, err := dbHandle.QueryContext(ctx, q, limit, offset) + if err != nil { + return admins, err + } + defer rows.Close() + + for rows.Next() { + a, err := getAdminFromDbRow(rows) + if err != nil { + return admins, err + } + a.HideConfidentialData() + admins = append(admins, a) + } + err = rows.Err() + if err != nil { + 
return admins, err + } + return getAdminsWithGroups(ctx, admins, dbHandle) +} + +func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) { + admins := make([]Admin, 0, 30) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDumpAdminsQuery() + rows, err := dbHandle.QueryContext(ctx, q) + if err != nil { + return admins, err + } + defer rows.Close() + + for rows.Next() { + a, err := getAdminFromDbRow(rows) + if err != nil { + return admins, err + } + admins = append(admins, a) + } + err = rows.Err() + if err != nil { + return admins, err + } + return getAdminsWithGroups(ctx, admins, dbHandle) +} + +func sqlCommonGetIPListEntry(ipOrNet string, listType IPListType, dbHandle sqlQuerier) (IPListEntry, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getIPListEntryQuery() + row := dbHandle.QueryRowContext(ctx, q, listType, ipOrNet) + return getIPListEntryFromDbRow(row) +} + +func sqlCommonDumpIPListEntries(dbHandle *sql.DB) ([]IPListEntry, error) { + count, err := sqlCommonCountIPListEntries(0, dbHandle) + if err != nil { + return nil, err + } + if count > ipListMemoryLimit { + providerLog(logger.LevelInfo, "IP lists excluded from dump, too many entries: %d", count) + return nil, nil + } + entries := make([]IPListEntry, 0, 100) + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + q := getDumpListEntriesQuery() + + rows, err := dbHandle.QueryContext(ctx, q) + if err != nil { + return entries, err + } + defer rows.Close() + + for rows.Next() { + entry, err := getIPListEntryFromDbRow(rows) + if err != nil { + return entries, err + } + entries = append(entries, entry) + } + return entries, rows.Err() +} + +func sqlCommonCountIPListEntries(listType IPListType, dbHandle *sql.DB) (int64, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + 
+ var q string + var args []any + if listType == 0 { + q = getCountAllIPListEntriesQuery() + } else { + q = getCountIPListEntriesQuery() + args = append(args, listType) + } + var count int64 + err := dbHandle.QueryRowContext(ctx, q, args...).Scan(&count) + return count, err +} + +func sqlCommonGetIPListEntries(listType IPListType, filter, from, order string, limit int, dbHandle sqlQuerier) ([]IPListEntry, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getIPListEntriesQuery(filter, from, order, limit) + args := []any{listType} + if from != "" { + args = append(args, from) + } + if filter != "" { + args = append(args, filter+"%") + } + if limit > 0 { + args = append(args, limit) + } + entries := make([]IPListEntry, 0, limit) + rows, err := dbHandle.QueryContext(ctx, q, args...) + if err != nil { + return entries, err + } + defer rows.Close() + + for rows.Next() { + entry, err := getIPListEntryFromDbRow(rows) + if err != nil { + return entries, err + } + entries = append(entries, entry) + } + return entries, rows.Err() +} + +func sqlCommonGetRecentlyUpdatedIPListEntries(after int64, dbHandle sqlQuerier) ([]IPListEntry, error) { + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + q := getRecentlyUpdatedIPListQuery() + entries := make([]IPListEntry, 0, 5) + rows, err := dbHandle.QueryContext(ctx, q, after) + if err != nil { + return entries, err + } + defer rows.Close() + + for rows.Next() { + entry, err := getIPListEntryFromDbRow(rows) + if err != nil { + return entries, err + } + entries = append(entries, entry) + } + return entries, rows.Err() +} + +func sqlCommonGetListEntriesForIP(ip string, listType IPListType, dbHandle sqlQuerier) ([]IPListEntry, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + var rows *sql.Rows + var err error + + entries := make([]IPListEntry, 0, 2) + if 
config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName { + rows, err = dbHandle.QueryContext(ctx, getIPListEntriesForIPQueryPg(), listType, ip) + if err != nil { + return entries, err + } + } else { + ipAddr, err := netip.ParseAddr(ip) + if err != nil { + return entries, fmt.Errorf("invalid ip address %s", ip) + } + var netType int + var ipBytes []byte + if ipAddr.Is4() || ipAddr.Is4In6() { + netType = ipTypeV4 + as4 := ipAddr.As4() + ipBytes = as4[:] + } else { + netType = ipTypeV6 + as16 := ipAddr.As16() + ipBytes = as16[:] + } + rows, err = dbHandle.QueryContext(ctx, getIPListEntriesForIPQueryNoPg(), listType, netType, ipBytes) + if err != nil { + return entries, err + } + } + defer rows.Close() + + for rows.Next() { + entry, err := getIPListEntryFromDbRow(rows) + if err != nil { + return entries, err + } + entries = append(entries, entry) + } + return entries, rows.Err() +} + +func sqlCommonAddIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error { + if err := entry.validate(); err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + var err error + q := getAddIPListEntryQuery() + first := entry.getFirst() + last := entry.getLast() + var netType int + if first.Is4() { + netType = ipTypeV4 + } else { + netType = ipTypeV6 + } + if config.IsShared == 1 { + return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, getRemoveSoftDeletedIPListEntryQuery(), entry.Type, entry.IPOrNet) + if err != nil { + return err + } + if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName { + _, err = tx.ExecContext(ctx, q, entry.Type, entry.IPOrNet, first.String(), last.String(), + netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()), + util.GetTimeAsMsSinceEpoch(time.Now())) + } else { + _, err = tx.ExecContext(ctx, q, entry.Type, entry.IPOrNet, entry.First, 
entry.Last, + netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()), + util.GetTimeAsMsSinceEpoch(time.Now())) + } + return err + }) + } + if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName { + _, err = dbHandle.ExecContext(ctx, q, entry.Type, entry.IPOrNet, first.String(), last.String(), + netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()), + util.GetTimeAsMsSinceEpoch(time.Now())) + } else { + _, err = dbHandle.ExecContext(ctx, q, entry.Type, entry.IPOrNet, entry.First, entry.Last, + netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()), + util.GetTimeAsMsSinceEpoch(time.Now())) + } + return err +} + +func sqlCommonUpdateIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error { + if err := entry.validate(); err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUpdateIPListEntryQuery() + res, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description, + util.GetTimeAsMsSinceEpoch(time.Now()), entry.Type, entry.IPOrNet) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonDeleteIPListEntry(entry IPListEntry, softDelete bool, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteIPListEntryQuery(softDelete) + var args []any + if softDelete { + ts := util.GetTimeAsMsSinceEpoch(time.Now()) + args = append(args, ts, ts) + } + args = append(args, entry.Type, entry.IPOrNet) + res, err := dbHandle.ExecContext(ctx, q, args...) 
+ if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonGetRoleByName(name string, dbHandle sqlQuerier) (Role, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getRoleByNameQuery() + row := dbHandle.QueryRowContext(ctx, q, name) + role, err := getRoleFromDbRow(row) + if err != nil { + return role, err + } + role, err = getRoleWithUsers(ctx, role, dbHandle) + if err != nil { + return role, err + } + return getRoleWithAdmins(ctx, role, dbHandle) +} + +func sqlCommonDumpRoles(dbHandle sqlQuerier) ([]Role, error) { + roles := make([]Role, 0, 10) + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + q := getDumpRolesQuery() + + rows, err := dbHandle.QueryContext(ctx, q) + if err != nil { + return roles, err + } + defer rows.Close() + + for rows.Next() { + role, err := getRoleFromDbRow(rows) + if err != nil { + return roles, err + } + roles = append(roles, role) + } + return roles, rows.Err() +} + +func sqlCommonGetRoles(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Role, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getRolesQuery(order, minimal) + + roles := make([]Role, 0, limit) + rows, err := dbHandle.QueryContext(ctx, q, limit, offset) + if err != nil { + return roles, err + } + defer rows.Close() + + for rows.Next() { + var role Role + if minimal { + err = rows.Scan(&role.ID, &role.Name) + } else { + role, err = getRoleFromDbRow(rows) + } + if err != nil { + return roles, err + } + roles = append(roles, role) + } + err = rows.Err() + if err != nil { + return roles, err + } + if minimal { + return roles, nil + } + roles, err = getRolesWithUsers(ctx, roles, dbHandle) + if err != nil { + return roles, err + } + return getRolesWithAdmins(ctx, roles, dbHandle) +} + +func sqlCommonAddRole(role *Role, dbHandle *sql.DB) 
error { + if err := role.validate(); err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAddRoleQuery() + _, err := dbHandle.ExecContext(ctx, q, role.Name, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), + util.GetTimeAsMsSinceEpoch(time.Now())) + return err +} + +func sqlCommonUpdateRole(role *Role, dbHandle *sql.DB) error { + if err := role.validate(); err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUpdateRoleQuery() + res, err := dbHandle.ExecContext(ctx, q, role.Description, util.GetTimeAsMsSinceEpoch(time.Now()), role.Name) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonDeleteRole(role Role, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteRoleQuery() + res, err := dbHandle.ExecContext(ctx, q, role.Name) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonGetGroupByName(name string, dbHandle sqlQuerier) (Group, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getGroupByNameQuery() + + row := dbHandle.QueryRowContext(ctx, q, name) + group, err := getGroupFromDbRow(row) + if err != nil { + return group, err + } + group, err = getGroupWithVirtualFolders(ctx, group, dbHandle) + if err != nil { + return group, err + } + group, err = getGroupWithUsers(ctx, group, dbHandle) + if err != nil { + return group, err + } + return getGroupWithAdmins(ctx, group, dbHandle) +} + +func sqlCommonDumpGroups(dbHandle sqlQuerier) ([]Group, error) { + groups := make([]Group, 0, 50) + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + q := getDumpGroupsQuery() + + rows, err := 
dbHandle.QueryContext(ctx, q)
	if err != nil {
		return groups, err
	}
	defer rows.Close()

	for rows.Next() {
		group, err := getGroupFromDbRow(rows)
		if err != nil {
			return groups, err
		}
		groups = append(groups, group)
	}
	err = rows.Err()
	if err != nil {
		return groups, err
	}
	return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
}

// sqlCommonGetUsersInGroups returns the usernames of the users that are
// members of any of the given groups. Names are looked up in batches no
// larger than the number of available SQL placeholders.
func sqlCommonGetUsersInGroups(names []string, dbHandle sqlQuerier) ([]string, error) {
	if len(names) == 0 {
		return nil, nil
	}
	batchSize := len(sqlPlaceholders)
	usernames := make([]string, 0, len(names))

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	for len(names) > 0 {
		if batchSize > len(names) {
			batchSize = len(names)
		}

		q := getUsersInGroupsQuery(batchSize)
		args := make([]any, 0, batchSize)
		for _, name := range names[:batchSize] {
			args = append(args, name)
		}

		rows, err := dbHandle.QueryContext(ctx, q, args...)
		if err != nil {
			return nil, err
		}
		// Consume and close the result set within this iteration: a
		// "defer rows.Close()" here would keep every batch's result set
		// open until the function returns.
		err = func() error {
			defer rows.Close()

			for rows.Next() {
				var username string
				if err := rows.Scan(&username); err != nil {
					return err
				}
				usernames = append(usernames, username)
			}
			return rows.Err()
		}()
		if err != nil {
			return usernames, err
		}
		names = names[batchSize:]
	}
	return usernames, nil
}

// sqlCommonGetGroupsWithNames returns the groups, including their virtual
// folders, matching the given names. Names are looked up in batches no
// larger than the number of available SQL placeholders.
func sqlCommonGetGroupsWithNames(names []string, dbHandle sqlQuerier) ([]Group, error) {
	if len(names) == 0 {
		return nil, nil
	}
	batchSize := len(sqlPlaceholders)
	groups := make([]Group, 0, len(names))

	// A single context covers all the batches, as in
	// sqlCommonGetUsersInGroups; creating one context per iteration with a
	// deferred cancel would leak timers until the function returns.
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	for len(names) > 0 {
		if batchSize > len(names) {
			batchSize = len(names)
		}

		q := getGroupsWithNamesQuery(batchSize)
		args := make([]any, 0, batchSize)
		for _, name := range names[:batchSize] {
			args = append(args, name)
		}
		rows, err := dbHandle.QueryContext(ctx, q, args...)
		if err != nil {
			return groups, err
		}
		// Close each batch's result set before starting the next one
		// instead of deferring all Close calls to function exit.
		err = func() error {
			defer rows.Close()

			for rows.Next() {
				group, err := getGroupFromDbRow(rows)
				if err != nil {
					return err
				}
				groups = append(groups, group)
			}
			return rows.Err()
		}()
		if err != nil {
			return groups, err
		}
		names = names[batchSize:]
	}

	return getGroupsWithVirtualFolders(ctx, groups, dbHandle)
}

// sqlCommonGetGroups returns groups, ordered as requested, using the given
// limit and offset. If minimal is true only the ID and name are populated,
// otherwise the related virtual folders, users and admins are loaded too and
// the groups are prepared for rendering.
func sqlCommonGetGroups(limit int, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]Group, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getGroupsQuery(order, minimal)

	groups := make([]Group, 0, limit)
	rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
	if err != nil {
		return groups, err
	}
	defer rows.Close()

	for rows.Next() {
		var group Group
		if minimal {
			err = rows.Scan(&group.ID, &group.Name)
		} else {
			group, err = getGroupFromDbRow(rows)
		}
		if err != nil {
			return groups, err
		}
		groups = append(groups, group)
	}
	err = rows.Err()
	if err != nil {
		return groups, err
	}
	if minimal {
		return groups, nil
	}
	groups, err = getGroupsWithVirtualFolders(ctx, groups, dbHandle)
	if err != nil {
		return groups, err
	}
	groups, err = getGroupsWithUsers(ctx, groups, dbHandle)
	if err != nil {
		return groups, err
	}
	groups, err = getGroupsWithAdmins(ctx, groups, dbHandle)
	if err != nil {
		return groups, err
	}
	for idx := range groups {
		groups[idx].PrepareForRendering()
	}
	return groups, nil
}

func sqlCommonAddGroup(group *Group, dbHandle *sql.DB) error {
	if err := group.validate(); err != nil {
		return err
	}
	settings, err := json.Marshal(group.UserSettings)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx)
error { + q := getAddGroupQuery() + _, err := tx.ExecContext(ctx, q, group.Name, group.Description, util.GetTimeAsMsSinceEpoch(time.Now()), + util.GetTimeAsMsSinceEpoch(time.Now()), settings) + if err != nil { + return err + } + return generateGroupVirtualFoldersMapping(ctx, group, tx) + }) +} + +func sqlCommonUpdateGroup(group *Group, dbHandle *sql.DB) error { + if err := group.validate(); err != nil { + return err + } + + settings, err := json.Marshal(group.UserSettings) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { + q := getUpdateGroupQuery() + _, err := tx.ExecContext(ctx, q, group.Description, settings, util.GetTimeAsMsSinceEpoch(time.Now()), group.Name) + if err != nil { + return err + } + return generateGroupVirtualFoldersMapping(ctx, group, tx) + }) +} + +func sqlCommonDeleteGroup(group Group, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteGroupQuery() + res, err := dbHandle.ExecContext(ctx, q, group.Name) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonGetUserByUsername(username, role string, dbHandle sqlQuerier) (User, error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUserByUsernameQuery(role) + args := []any{username} + if role != "" { + args = append(args, role) + } + row := dbHandle.QueryRowContext(ctx, q, args...) 
+ user, err := getUserFromDbRow(row) + if err != nil { + return user, err + } + user, err = getUserWithVirtualFolders(ctx, user, dbHandle) + if err != nil { + return user, err + } + return getUserWithGroups(ctx, user, dbHandle) +} + +func sqlCommonValidateUserAndPass(username, password, ip, protocol string, dbHandle *sql.DB) (User, error) { + user, err := sqlCommonGetUserByUsername(username, "", dbHandle) + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) + return user, err + } + return checkUserAndPass(&user, password, ip, protocol) +} + +func sqlCommonValidateUserAndTLSCertificate(username, protocol string, tlsCert *x509.Certificate, dbHandle *sql.DB) (User, error) { + var user User + if tlsCert == nil { + return user, errors.New("TLS certificate cannot be null or empty") + } + user, err := sqlCommonGetUserByUsername(username, "", dbHandle) + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) + return user, err + } + return checkUserAndTLSCertificate(&user, protocol, tlsCert) +} + +func sqlCommonValidateUserAndPubKey(username string, pubKey []byte, isSSHCert bool, dbHandle *sql.DB) (User, string, error) { + var user User + if len(pubKey) == 0 { + return user, "", errors.New("credentials cannot be null or empty") + } + user, err := sqlCommonGetUserByUsername(username, "", dbHandle) + if err != nil { + providerLog(logger.LevelWarn, "error authenticating user %q: %v", username, err) + return user, "", err + } + return checkUserAndPubKey(&user, pubKey, isSSHCert) +} + +func sqlCommonCheckAvailability(dbHandle *sql.DB) (err error) { + defer func() { + if r := recover(); r != nil { + providerLog(logger.LevelError, "panic in check provider availability, stack trace: %s", string(debug.Stack())) + err = errors.New("unable to check provider status") + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + err = 
dbHandle.PingContext(ctx) + return +} + +func sqlCommonUpdateTransferQuota(username string, uploadSize, downloadSize int64, reset bool, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUpdateTransferQuotaQuery(reset) + _, err := dbHandle.ExecContext(ctx, q, uploadSize, downloadSize, util.GetTimeAsMsSinceEpoch(time.Now()), username) + if err == nil { + providerLog(logger.LevelDebug, "transfer quota updated for user %q, ul increment: %d dl increment: %d is reset? %t", + username, uploadSize, downloadSize, reset) + } else { + providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err) + } + return err +} + +func sqlCommonUpdateQuota(username string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUpdateQuotaQuery(reset) + _, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), username) + if err == nil { + providerLog(logger.LevelDebug, "quota updated for user %q, files increment: %d size increment: %d is reset? 
%t",
			username, filesAdd, sizeAdd, reset)
	} else {
		providerLog(logger.LevelError, "error updating quota for user %q: %v", username, err)
	}
	return err
}

// sqlCommonGetAdminSignature returns a signature for the given admin,
// derived from its last update timestamp.
func sqlCommonGetAdminSignature(username string, dbHandle *sql.DB) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	var updatedAt int64
	if err := dbHandle.QueryRowContext(ctx, getAdminSignatureQuery(), username).Scan(&updatedAt); err != nil {
		return "", err
	}
	return strconv.FormatInt(updatedAt, 10), nil
}

// sqlCommonGetUserSignature returns a signature for the given user, derived
// from its last update timestamp.
func sqlCommonGetUserSignature(username string, dbHandle *sql.DB) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	var updatedAt int64
	if err := dbHandle.QueryRowContext(ctx, getUserSignatureQuery(), username).Scan(&updatedAt); err != nil {
		return "", err
	}
	return strconv.FormatInt(updatedAt, 10), nil
}

// sqlCommonGetUsedQuota returns the used number of files, size, upload data
// transfer and download data transfer for the given user.
func sqlCommonGetUsedQuota(username string, dbHandle *sql.DB) (int, int64, int64, int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	var usedFiles int
	var usedSize int64
	var usedUploadSize int64
	var usedDownloadSize int64
	err := dbHandle.QueryRowContext(ctx, getQuotaQuery(), username).
		Scan(&usedSize, &usedFiles, &usedUploadSize, &usedDownloadSize)
	if err != nil {
		providerLog(logger.LevelError, "error getting quota for user: %v, error: %v", username, err)
		return 0, 0, 0, 0, err
	}
	return usedFiles, usedSize, usedUploadSize, usedDownloadSize, err
}

// sqlCommonUpdateShareLastUse updates the last use timestamp and increments
// the used tokens counter for the given share.
func sqlCommonUpdateShareLastUse(shareID string, numTokens int, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	_, err := dbHandle.ExecContext(ctx, getUpdateShareLastUseQuery(),
		util.GetTimeAsMsSinceEpoch(time.Now()), numTokens, shareID)
	if err != nil {
		providerLog(logger.LevelWarn, "error updating last use for shared object %q: %v", shareID, err)
		return err
	}
	providerLog(logger.LevelDebug, "last use updated for shared object %q", shareID)
	return nil
}

// sqlCommonUpdateAPIKeyLastUse updates the last use timestamp for the given
// API key.
func sqlCommonUpdateAPIKeyLastUse(keyID string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	_, err := dbHandle.ExecContext(ctx, getUpdateAPIKeyLastUseQuery(),
		util.GetTimeAsMsSinceEpoch(time.Now()), keyID)
	if err != nil {
		providerLog(logger.LevelWarn, "error updating last use for key %q: %v", keyID, err)
		return err
	}
	providerLog(logger.LevelDebug, "last use updated for key %q", keyID)
	return nil
}

// sqlCommonUpdateAdminLastLogin updates the last login timestamp for the
// given admin.
func sqlCommonUpdateAdminLastLogin(username string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	_, err := dbHandle.ExecContext(ctx, getUpdateAdminLastLoginQuery(),
		util.GetTimeAsMsSinceEpoch(time.Now()), username)
	if err != nil {
		providerLog(logger.LevelWarn, "error updating last login for admin %q: %v", username, err)
		return err
	}
	providerLog(logger.LevelDebug, "last login updated for admin %q", username)
	return nil
}

// sqlCommonSetUpdatedAt sets the updated_at field for the given user to the
// current time. Errors are logged and not returned: this is a best-effort
// update.
func sqlCommonSetUpdatedAt(username string, dbHandle *sql.DB) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	if _, err := dbHandle.ExecContext(ctx, getSetUpdateAtQuery(),
		util.GetTimeAsMsSinceEpoch(time.Now()), username); err != nil {
		providerLog(logger.LevelWarn, "error setting updated_at for user %q: %v", username, err)
		return
	}
	providerLog(logger.LevelDebug, "updated_at set for user %q", username)
}

func sqlCommonSetFirstDownloadTimestamp(username string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getSetFirstDownloadQuery()
	res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()),
username) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonSetFirstUploadTimestamp(username string, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getSetFirstUploadQuery() + res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonUpdateLastLogin(username string, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUpdateLastLoginQuery() + _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), username) + if err == nil { + providerLog(logger.LevelDebug, "last login updated for user %q", username) + } else { + providerLog(logger.LevelWarn, "error updating last login for user %q: %v", username, err) + } + return err +} + +func sqlCommonAddUser(user *User, dbHandle *sql.DB) error { + err := ValidateUser(user) + if err != nil { + return err + } + + permissions, err := user.GetPermissionsAsJSON() + if err != nil { + return err + } + publicKeys, err := user.GetPublicKeysAsJSON() + if err != nil { + return err + } + filters, err := user.GetFiltersAsJSON() + if err != nil { + return err + } + fsConfig, err := user.GetFsConfigAsJSON() + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { + if config.IsShared == 1 { + _, err := tx.ExecContext(ctx, getRemoveSoftDeletedUserQuery(), user.Username) + if err != nil { + return err + } + } + q := getAddUserQuery(user.Role) + _, err := tx.ExecContext(ctx, q, user.Username, user.Password, publicKeys, user.HomeDir, user.UID, user.GID, + user.MaxSessions, user.QuotaSize, user.QuotaFiles, permissions, 
user.UploadBandwidth, + user.DownloadBandwidth, user.Status, user.ExpirationDate, filters, fsConfig, user.AdditionalInfo, + user.Description, user.Email, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(time.Now()), + user.UploadDataTransfer, user.DownloadDataTransfer, user.TotalDataTransfer, user.Role, user.LastPasswordChange) + if err != nil { + return err + } + if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil { + return err + } + return generateUserGroupMapping(ctx, user, tx) + }) +} + +func sqlCommonUpdateUserPassword(username, password string, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUpdateUserPasswordQuery() + res, err := dbHandle.ExecContext(ctx, q, password, util.GetTimeAsMsSinceEpoch(time.Now()), username) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonUpdateUser(user *User, dbHandle *sql.DB) error { + err := ValidateUser(user) + if err != nil { + return err + } + + permissions, err := user.GetPermissionsAsJSON() + if err != nil { + return err + } + publicKeys, err := user.GetPublicKeysAsJSON() + if err != nil { + return err + } + filters, err := user.GetFiltersAsJSON() + if err != nil { + return err + } + fsConfig, err := user.GetFsConfigAsJSON() + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { + q := getUpdateUserQuery(user.Role) + res, err := tx.ExecContext(ctx, q, user.Password, publicKeys, user.HomeDir, user.UID, user.GID, user.MaxSessions, + user.QuotaSize, user.QuotaFiles, permissions, user.UploadBandwidth, user.DownloadBandwidth, user.Status, + user.ExpirationDate, filters, fsConfig, user.AdditionalInfo, user.Description, user.Email, + util.GetTimeAsMsSinceEpoch(time.Now()), user.UploadDataTransfer, 
user.DownloadDataTransfer, user.TotalDataTransfer, + user.Role, user.LastPasswordChange, user.Username) + if err != nil { + return err + } + if err := sqlCommonRequireRowAffected(res); err != nil { + return err + } + if err := generateUserVirtualFoldersMapping(ctx, user, tx); err != nil { + return err + } + return generateUserGroupMapping(ctx, user, tx) + }) +} + +func sqlCommonDeleteUser(user User, softDelete bool, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteUserQuery(softDelete) + if softDelete { + return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { + if err := sqlCommonClearUserFolderMapping(ctx, &user, tx); err != nil { + return err + } + if err := sqlCommonClearUserGroupMapping(ctx, &user, tx); err != nil { + return err + } + ts := util.GetTimeAsMsSinceEpoch(time.Now()) + res, err := tx.ExecContext(ctx, q, ts, ts, user.Username) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) + }) + } + res, err := dbHandle.ExecContext(ctx, q, user.Username) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonDumpUsers(dbHandle sqlQuerier) ([]User, error) { + users := make([]User, 0, 100) + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + q := getDumpUsersQuery() + rows, err := dbHandle.QueryContext(ctx, q) + if err != nil { + return users, err + } + + defer rows.Close() + for rows.Next() { + u, err := getUserFromDbRow(rows) + if err != nil { + return users, err + } + users = append(users, u) + } + err = rows.Err() + if err != nil { + return users, err + } + users, err = getUsersWithVirtualFolders(ctx, users, dbHandle) + if err != nil { + return users, err + } + return getUsersWithGroups(ctx, users, dbHandle) +} + +func sqlCommonGetRecentlyUpdatedUsers(after int64, dbHandle sqlQuerier) ([]User, error) { + users := make([]User, 0, 10) + ctx, 
cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getRecentlyUpdatedUsersQuery() + + rows, err := dbHandle.QueryContext(ctx, q, after) + if err != nil { + return users, err + } + defer rows.Close() + + for rows.Next() { + u, err := getUserFromDbRow(rows) + if err != nil { + return users, err + } + users = append(users, u) + } + err = rows.Err() + if err != nil { + return users, err + } + users, err = getUsersWithVirtualFolders(ctx, users, dbHandle) + if err != nil { + return users, err + } + users, err = getUsersWithGroups(ctx, users, dbHandle) + if err != nil { + return users, err + } + var groupNames []string + for _, u := range users { + for _, g := range u.Groups { + groupNames = append(groupNames, g.Name) + } + } + groupNames = util.RemoveDuplicates(groupNames, false) + if len(groupNames) == 0 { + return users, nil + } + groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle) + if err != nil { + return users, err + } + if len(groups) == 0 { + return users, nil + } + groupsMapping := make(map[string]Group) + for idx := range groups { + groupsMapping[groups[idx].Name] = groups[idx] + } + for idx := range users { + ref := &users[idx] + ref.applyGroupSettings(groupsMapping) + } + return users, nil +} + +func sqlGetMaxUsersForQuotaCheckRange() int { + maxUsers := 50 + if maxUsers > len(sqlPlaceholders) { + maxUsers = len(sqlPlaceholders) + } + return maxUsers +} + +func sqlCommonGetUsersForQuotaCheck(toFetch map[string]bool, dbHandle sqlQuerier) ([]User, error) { + maxUsers := sqlGetMaxUsersForQuotaCheckRange() + users := make([]User, 0, maxUsers) + + usernames := make([]string, 0, len(toFetch)) + for k := range toFetch { + usernames = append(usernames, k) + } + + for len(usernames) > 0 { + if maxUsers > len(usernames) { + maxUsers = len(usernames) + } + usersRange, err := sqlCommonGetUsersRangeForQuotaCheck(usernames[:maxUsers], dbHandle) + if err != nil { + return users, err + } + users = append(users, 
usersRange...) + usernames = usernames[maxUsers:] + } + + var usersWithFolders []User + + validIdx := 0 + for _, user := range users { + if toFetch[user.Username] { + usersWithFolders = append(usersWithFolders, user) + } else { + users[validIdx] = user + validIdx++ + } + } + users = users[:validIdx] + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + usersWithFolders, err := getUsersWithVirtualFolders(ctx, usersWithFolders, dbHandle) + if err != nil { + return users, err + } + users = append(users, usersWithFolders...) + users, err = getUsersWithGroups(ctx, users, dbHandle) + if err != nil { + return users, err + } + var groupNames []string + for _, u := range users { + for _, g := range u.Groups { + groupNames = append(groupNames, g.Name) + } + } + groupNames = util.RemoveDuplicates(groupNames, false) + if len(groupNames) == 0 { + return users, nil + } + groups, err := sqlCommonGetGroupsWithNames(groupNames, dbHandle) + if err != nil { + return users, err + } + groupsMapping := make(map[string]Group) + for idx := range groups { + groupsMapping[groups[idx].Name] = groups[idx] + } + for idx := range users { + ref := &users[idx] + ref.applyGroupSettings(groupsMapping) + } + return users, nil +} + +func sqlCommonGetUsersRangeForQuotaCheck(usernames []string, dbHandle sqlQuerier) ([]User, error) { + users := make([]User, 0, len(usernames)) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUsersForQuotaCheckQuery(len(usernames)) + queryArgs := make([]any, 0, len(usernames)) + for idx := range usernames { + queryArgs = append(queryArgs, usernames[idx]) + } + + rows, err := dbHandle.QueryContext(ctx, q, queryArgs...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var user User + var filters []byte + err = rows.Scan(&user.ID, &user.Username, &user.QuotaSize, &user.UsedQuotaSize, &user.TotalDataTransfer, + &user.UploadDataTransfer, &user.DownloadDataTransfer, &user.UsedUploadDataTransfer, + &user.UsedDownloadDataTransfer, &filters) + if err != nil { + return users, err + } + var userFilters UserFilters + err = json.Unmarshal(filters, &userFilters) + if err == nil { + user.Filters = userFilters + } + users = append(users, user) + } + + return users, rows.Err() +} + +func sqlCommonAddActiveTransfer(transfer ActiveTransfer, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAddActiveTransferQuery() + now := util.GetTimeAsMsSinceEpoch(time.Now()) + _, err := dbHandle.ExecContext(ctx, q, transfer.ID, transfer.ConnID, transfer.Type, transfer.Username, + transfer.FolderName, transfer.IP, transfer.TruncatedSize, transfer.CurrentULSize, transfer.CurrentDLSize, + now, now) + return err +} + +func sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUpdateActiveTransferSizesQuery() + _, err := dbHandle.ExecContext(ctx, q, ulSize, dlSize, util.GetTimeAsMsSinceEpoch(time.Now()), connectionID, transferID) + return err +} + +func sqlCommonRemoveActiveTransfer(transferID int64, connectionID string, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getRemoveActiveTransferQuery() + _, err := dbHandle.ExecContext(ctx, q, connectionID, transferID) + return err +} + +func sqlCommonCleanupActiveTransfers(before time.Time, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), 
defaultSQLQueryTimeout) + defer cancel() + + q := getCleanupActiveTransfersQuery() + _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(before)) + return err +} + +func sqlCommonGetActiveTransfers(from time.Time, dbHandle sqlQuerier) ([]ActiveTransfer, error) { + transfers := make([]ActiveTransfer, 0, 30) + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + q := getActiveTransfersQuery() + rows, err := dbHandle.QueryContext(ctx, q, util.GetTimeAsMsSinceEpoch(from)) + if err != nil { + return nil, err + } + + defer rows.Close() + for rows.Next() { + var transfer ActiveTransfer + var folderName sql.NullString + err = rows.Scan(&transfer.ID, &transfer.ConnID, &transfer.Type, &transfer.Username, &folderName, &transfer.IP, + &transfer.TruncatedSize, &transfer.CurrentULSize, &transfer.CurrentDLSize, &transfer.CreatedAt, + &transfer.UpdatedAt) + if err != nil { + return transfers, err + } + if folderName.Valid { + transfer.FolderName = folderName.String + } + transfers = append(transfers, transfer) + } + + return transfers, rows.Err() +} + +func sqlCommonGetUsers(limit int, offset int, order, role string, dbHandle sqlQuerier) ([]User, error) { + users := make([]User, 0, limit) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUsersQuery(order, role) + var args []any + if role == "" { + args = append(args, limit, offset) + } else { + args = append(args, role, limit, offset) + } + rows, err := dbHandle.QueryContext(ctx, q, args...) 
+ if err != nil { + return users, err + } + defer rows.Close() + + for rows.Next() { + u, err := getUserFromDbRow(rows) + if err != nil { + return users, err + } + users = append(users, u) + } + err = rows.Err() + if err != nil { + return users, err + } + users, err = getUsersWithVirtualFolders(ctx, users, dbHandle) + if err != nil { + return users, err + } + users, err = getUsersWithGroups(ctx, users, dbHandle) + if err != nil { + return users, err + } + for idx := range users { + users[idx].PrepareForRendering() + } + return users, nil +} + +func sqlCommonGetDefenderHosts(from int64, limit int, dbHandle sqlQuerier) ([]DefenderEntry, error) { + hosts := make([]DefenderEntry, 0, 100) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDefenderHostsQuery() + rows, err := dbHandle.QueryContext(ctx, q, from, limit) + if err != nil { + providerLog(logger.LevelError, "unable to get defender hosts: %v", err) + return hosts, err + } + defer rows.Close() + + var idForScores []int64 + + for rows.Next() { + var banTime sql.NullInt64 + host := DefenderEntry{} + err = rows.Scan(&host.ID, &host.IP, &banTime) + if err != nil { + providerLog(logger.LevelError, "unable to scan defender host row: %v", err) + return hosts, err + } + var hostBanTime time.Time + if banTime.Valid && banTime.Int64 > 0 { + hostBanTime = util.GetTimeFromMsecSinceEpoch(banTime.Int64) + } + if hostBanTime.IsZero() || hostBanTime.Before(time.Now()) { + idForScores = append(idForScores, host.ID) + } else { + host.BanTime = hostBanTime + } + hosts = append(hosts, host) + } + err = rows.Err() + if err != nil { + providerLog(logger.LevelError, "unable to iterate over defender host rows: %v", err) + return hosts, err + } + + return getDefenderHostsWithScores(ctx, hosts, from, idForScores, dbHandle) +} + +func sqlCommonIsDefenderHostBanned(ip string, dbHandle sqlQuerier) (DefenderEntry, error) { + var host DefenderEntry + + ctx, cancel := 
context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDefenderIsHostBannedQuery() + row := dbHandle.QueryRowContext(ctx, q, ip, util.GetTimeAsMsSinceEpoch(time.Now())) + err := row.Scan(&host.ID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return host, util.NewRecordNotFoundError("host not found") + } + providerLog(logger.LevelError, "unable to check ban status for host %q: %v", ip, err) + return host, err + } + + return host, nil +} + +func sqlCommonGetDefenderHostByIP(ip string, from int64, dbHandle sqlQuerier) (DefenderEntry, error) { + var host DefenderEntry + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDefenderHostQuery() + row := dbHandle.QueryRowContext(ctx, q, ip, from) + var banTime sql.NullInt64 + err := row.Scan(&host.ID, &host.IP, &banTime) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return host, util.NewRecordNotFoundError("host not found") + } + providerLog(logger.LevelError, "unable to get host for ip %q: %v", ip, err) + return host, err + } + if banTime.Valid && banTime.Int64 > 0 { + hostBanTime := util.GetTimeFromMsecSinceEpoch(banTime.Int64) + if !hostBanTime.IsZero() && hostBanTime.After(time.Now()) { + host.BanTime = hostBanTime + return host, nil + } + } + + hosts, err := getDefenderHostsWithScores(ctx, []DefenderEntry{host}, from, []int64{host.ID}, dbHandle) + if err != nil { + return host, err + } + if len(hosts) == 0 { + return host, util.NewRecordNotFoundError("host not found") + } + + return hosts[0], nil +} + +func sqlCommonDefenderIncrementBanTime(ip string, minutesToAdd int, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDefenderIncrementBanTimeQuery() + _, err := dbHandle.ExecContext(ctx, q, minutesToAdd*60000, ip) + if err == nil { + providerLog(logger.LevelDebug, "ban time updated for ip %q, increment (minutes): 
%v", + ip, minutesToAdd) + } else { + providerLog(logger.LevelError, "error updating ban time for ip %q: %v", ip, err) + } + return err +} + +func sqlCommonSetDefenderBanTime(ip string, banTime int64, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDefenderSetBanTimeQuery() + _, err := dbHandle.ExecContext(ctx, q, banTime, ip) + if err == nil { + providerLog(logger.LevelDebug, "ip %q banned until %v", ip, util.GetTimeFromMsecSinceEpoch(banTime)) + } else { + providerLog(logger.LevelError, "error setting ban time for ip %q: %v", ip, err) + } + return err +} + +func sqlCommonDeleteDefenderHost(ip string, dbHandle sqlQuerier) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteDefenderHostQuery() + res, err := dbHandle.ExecContext(ctx, q, ip) + if err != nil { + providerLog(logger.LevelError, "unable to delete defender host %q: %v", ip, err) + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonAddDefenderHostAndEvent(ip string, score int, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error { + if err := sqlCommonAddDefenderHost(ctx, ip, tx); err != nil { + return err + } + return sqlCommonAddDefenderEvent(ctx, ip, score, tx) + }) +} + +func sqlCommonDefenderCleanup(from int64, dbHandler *sql.DB) error { + if err := sqlCommonCleanupDefenderEvents(from, dbHandler); err != nil { + return err + } + return sqlCommonCleanupDefenderHosts(from, dbHandler) +} + +func sqlCommonAddDefenderHost(ctx context.Context, ip string, tx *sql.Tx) error { + q := getAddDefenderHostQuery() + _, err := tx.ExecContext(ctx, q, ip, util.GetTimeAsMsSinceEpoch(time.Now())) + if err != nil { + providerLog(logger.LevelError, "unable to add defender host %q: %v", ip, err) + 
} + return err +} + +func sqlCommonAddDefenderEvent(ctx context.Context, ip string, score int, tx *sql.Tx) error { + q := getAddDefenderEventQuery() + _, err := tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), score, ip) + if err != nil { + providerLog(logger.LevelError, "unable to add defender event for %q: %v", ip, err) + } + return err +} + +func sqlCommonCleanupDefenderHosts(from int64, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDefenderHostsCleanupQuery() + _, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), from) + if err != nil { + providerLog(logger.LevelError, "unable to cleanup defender hosts: %v", err) + } + return err +} + +func sqlCommonCleanupDefenderEvents(from int64, dbHandle *sql.DB) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDefenderEventsCleanupQuery() + _, err := dbHandle.ExecContext(ctx, q, from) + if err != nil { + providerLog(logger.LevelError, "unable to cleanup defender events: %v", err) + } + return err +} + +func getShareFromDbRow(row sqlScanner) (Share, error) { + var share Share + var description, password sql.NullString + var allowFrom, paths []byte + + err := row.Scan(&share.ShareID, &share.Name, &description, &share.Scope, + &paths, &share.Username, &share.CreatedAt, &share.UpdatedAt, + &share.LastUseAt, &share.ExpiresAt, &password, &share.MaxTokens, + &share.UsedTokens, &allowFrom) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return share, util.NewRecordNotFoundError(err.Error()) + } + return share, err + } + var list []string + err = json.Unmarshal(paths, &list) + if err != nil { + return share, err + } + share.Paths = list + if description.Valid { + share.Description = description.String + } + if password.Valid { + share.Password = password.String + } + list = nil + err = json.Unmarshal(allowFrom, &list) + if 
err == nil { + share.AllowFrom = list + } + return share, nil +} + +func getAPIKeyFromDbRow(row sqlScanner) (APIKey, error) { + var apiKey APIKey + var userID, adminID sql.NullInt64 + var description sql.NullString + + err := row.Scan(&apiKey.KeyID, &apiKey.Name, &apiKey.Key, &apiKey.Scope, &apiKey.CreatedAt, &apiKey.UpdatedAt, + &apiKey.LastUseAt, &apiKey.ExpiresAt, &description, &userID, &adminID) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return apiKey, util.NewRecordNotFoundError(err.Error()) + } + return apiKey, err + } + + if userID.Valid { + apiKey.userID = userID.Int64 + } + if adminID.Valid { + apiKey.adminID = adminID.Int64 + } + if description.Valid { + apiKey.Description = description.String + } + + return apiKey, nil +} + +func getAdminFromDbRow(row sqlScanner) (Admin, error) { + var admin Admin + var email, additionalInfo, description, role sql.NullString + var permissions, filters []byte + + err := row.Scan(&admin.ID, &admin.Username, &admin.Password, &admin.Status, &email, &permissions, + &filters, &additionalInfo, &description, &admin.CreatedAt, &admin.UpdatedAt, &admin.LastLogin, &role) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return admin, util.NewRecordNotFoundError(err.Error()) + } + return admin, err + } + + var perms []string + err = json.Unmarshal(permissions, &perms) + if err != nil { + return admin, err + } + admin.Permissions = perms + + if email.Valid { + admin.Email = email.String + } + + var adminFilters AdminFilters + err = json.Unmarshal(filters, &adminFilters) + if err == nil { + admin.Filters = adminFilters + } + if additionalInfo.Valid { + admin.AdditionalInfo = additionalInfo.String + } + if description.Valid { + admin.Description = description.String + } + if role.Valid { + admin.Role = role.String + } + + admin.SetEmptySecretsIfNil() + return admin, nil +} + +func getEventActionFromDbRow(row sqlScanner) (BaseEventAction, error) { + var action BaseEventAction + var description sql.NullString + 
var options []byte + + err := row.Scan(&action.ID, &action.Name, &description, &action.Type, &options) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return action, util.NewRecordNotFoundError(err.Error()) + } + return action, err + } + if description.Valid { + action.Description = description.String + } + var actionOptions BaseEventActionOptions + err = json.Unmarshal(options, &actionOptions) + if err == nil { + action.Options = actionOptions + } + return action, nil +} + +func getEventRuleFromDbRow(row sqlScanner) (EventRule, error) { + var rule EventRule + var description sql.NullString + var conditions []byte + + err := row.Scan(&rule.ID, &rule.Name, &description, &rule.CreatedAt, &rule.UpdatedAt, &rule.Trigger, + &conditions, &rule.DeletedAt, &rule.Status) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return rule, util.NewRecordNotFoundError(err.Error()) + } + return rule, err + } + var ruleConditions EventConditions + err = json.Unmarshal(conditions, &ruleConditions) + if err == nil { + rule.Conditions = ruleConditions + } + + if description.Valid { + rule.Description = description.String + } + return rule, nil +} + +func getIPListEntryFromDbRow(row sqlScanner) (IPListEntry, error) { + var entry IPListEntry + var description sql.NullString + + err := row.Scan(&entry.Type, &entry.IPOrNet, &entry.Mode, &entry.Protocols, &description, + &entry.CreatedAt, &entry.UpdatedAt, &entry.DeletedAt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return entry, util.NewRecordNotFoundError(err.Error()) + } + return entry, err + } + if description.Valid { + entry.Description = description.String + } + return entry, err +} + +func getRoleFromDbRow(row sqlScanner) (Role, error) { + var role Role + var description sql.NullString + + err := row.Scan(&role.ID, &role.Name, &description, &role.CreatedAt, &role.UpdatedAt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return role, util.NewRecordNotFoundError(err.Error()) + } + return role, err 
+ } + if description.Valid { + role.Description = description.String + } + + return role, nil +} + +func getGroupFromDbRow(row sqlScanner) (Group, error) { + var group Group + var description sql.NullString + var userSettings []byte + + err := row.Scan(&group.ID, &group.Name, &description, &group.CreatedAt, &group.UpdatedAt, &userSettings) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return group, util.NewRecordNotFoundError(err.Error()) + } + return group, err + } + if description.Valid { + group.Description = description.String + } + + var settings GroupUserSettings + err = json.Unmarshal(userSettings, &settings) + if err == nil { + group.UserSettings = settings + } + + return group, nil +} + +func getUserFromDbRow(row sqlScanner) (User, error) { + var user User + var password sql.NullString + var permissions, publicKey, filters, fsConfig []byte + var additionalInfo, description, email, role sql.NullString + + err := row.Scan(&user.ID, &user.Username, &password, &publicKey, &user.HomeDir, &user.UID, &user.GID, &user.MaxSessions, + &user.QuotaSize, &user.QuotaFiles, &permissions, &user.UsedQuotaSize, &user.UsedQuotaFiles, &user.LastQuotaUpdate, + &user.UploadBandwidth, &user.DownloadBandwidth, &user.ExpirationDate, &user.LastLogin, &user.Status, &filters, &fsConfig, + &additionalInfo, &description, &email, &user.CreatedAt, &user.UpdatedAt, &user.UploadDataTransfer, &user.DownloadDataTransfer, + &user.TotalDataTransfer, &user.UsedUploadDataTransfer, &user.UsedDownloadDataTransfer, &user.DeletedAt, &user.FirstDownload, + &user.FirstUpload, &role, &user.LastPasswordChange) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return user, util.NewRecordNotFoundError(err.Error()) + } + return user, err + } + if password.Valid { + user.Password = password.String + } + perms := make(map[string][]string) + err = json.Unmarshal(permissions, &perms) + if err != nil { + providerLog(logger.LevelError, "unable to deserialize permissions for user %q: %v", 
user.Username, err) + return user, fmt.Errorf("unable to deserialize permissions for user %q: %v", user.Username, err) + } + user.Permissions = perms + // we can have a empty string or an invalid json in null string + // so we do a relaxed test if the field is optional, for example we + // populate public keys only if unmarshal does not return an error + var pKeys []string + err = json.Unmarshal(publicKey, &pKeys) + if err == nil { + user.PublicKeys = pKeys + } + var userFilters UserFilters + err = json.Unmarshal(filters, &userFilters) + if err == nil { + user.Filters = userFilters + } + var fs vfs.Filesystem + err = json.Unmarshal(fsConfig, &fs) + if err == nil { + user.FsConfig = fs + } + if additionalInfo.Valid { + user.AdditionalInfo = additionalInfo.String + } + if description.Valid { + user.Description = description.String + } + if email.Valid { + user.Email = email.String + } + if role.Valid { + user.Role = role.String + } + user.SetEmptySecretsIfNil() + return user, nil +} + +func sqlCommonGetFolder(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) { + var folder vfs.BaseVirtualFolder + q := getFolderByNameQuery() + row := dbHandle.QueryRowContext(ctx, q, name) + var mappedPath, description sql.NullString + var fsConfig []byte + err := row.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, &folder.LastQuotaUpdate, + &folder.Name, &description, &fsConfig) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return folder, util.NewRecordNotFoundError(err.Error()) + } + return folder, err + } + if mappedPath.Valid { + folder.MappedPath = mappedPath.String + } + if description.Valid { + folder.Description = description.String + } + var fs vfs.Filesystem + err = json.Unmarshal(fsConfig, &fs) + if err == nil { + folder.FsConfig = fs + } + return folder, err +} + +func sqlCommonGetFolderByName(ctx context.Context, name string, dbHandle sqlQuerier) (vfs.BaseVirtualFolder, error) { + folder, err := 
sqlCommonGetFolder(ctx, name, dbHandle) + if err != nil { + return folder, err + } + folders, err := getVirtualFoldersWithUsers([]vfs.BaseVirtualFolder{folder}, dbHandle) + if err != nil { + return folder, err + } + if len(folders) != 1 { + return folder, fmt.Errorf("unable to associate users with folder %q", name) + } + folders, err = getVirtualFoldersWithGroups([]vfs.BaseVirtualFolder{folders[0]}, dbHandle) + if err != nil { + return folder, err + } + if len(folders) != 1 { + return folder, fmt.Errorf("unable to associate groups with folder %q", name) + } + return folders[0], nil +} + +func sqlCommonAddFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error { + err := ValidateFolder(folder) + if err != nil { + return err + } + fsConfig, err := json.Marshal(folder.FsConfig) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAddFolderQuery() + _, err = dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.UsedQuotaSize, folder.UsedQuotaFiles, + folder.LastQuotaUpdate, folder.Name, folder.Description, fsConfig) + return err +} + +func sqlCommonUpdateFolder(folder *vfs.BaseVirtualFolder, dbHandle sqlQuerier) error { + err := ValidateFolder(folder) + if err != nil { + return err + } + fsConfig, err := json.Marshal(folder.FsConfig) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getUpdateFolderQuery() + res, err := dbHandle.ExecContext(ctx, q, folder.MappedPath, folder.Description, fsConfig, folder.Name) + if err != nil { + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonDeleteFolder(folder vfs.BaseVirtualFolder, dbHandle sqlQuerier) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getDeleteFolderQuery() + res, err := dbHandle.ExecContext(ctx, q, folder.Name) + if err != nil 
{ + return err + } + return sqlCommonRequireRowAffected(res) +} + +func sqlCommonDumpFolders(dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) { + folders := make([]vfs.BaseVirtualFolder, 0, 50) + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + q := getDumpFoldersQuery() + rows, err := dbHandle.QueryContext(ctx, q) + if err != nil { + return folders, err + } + defer rows.Close() + for rows.Next() { + var folder vfs.BaseVirtualFolder + var mappedPath, description sql.NullString + var fsConfig []byte + err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, + &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig) + if err != nil { + return folders, err + } + if mappedPath.Valid { + folder.MappedPath = mappedPath.String + } + if description.Valid { + folder.Description = description.String + } + var fs vfs.Filesystem + err = json.Unmarshal(fsConfig, &fs) + if err == nil { + folder.FsConfig = fs + } + folders = append(folders, folder) + } + return folders, rows.Err() +} + +func sqlCommonGetFolders(limit, offset int, order string, minimal bool, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) { + folders := make([]vfs.BaseVirtualFolder, 0, limit) + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getFoldersQuery(order, minimal) + rows, err := dbHandle.QueryContext(ctx, q, limit, offset) + if err != nil { + return folders, err + } + defer rows.Close() + for rows.Next() { + var folder vfs.BaseVirtualFolder + if minimal { + err = rows.Scan(&folder.ID, &folder.Name) + if err != nil { + return folders, err + } + } else { + var mappedPath, description sql.NullString + var fsConfig []byte + err = rows.Scan(&folder.ID, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles, + &folder.LastQuotaUpdate, &folder.Name, &description, &fsConfig) + if err != nil { + return folders, err + } + if mappedPath.Valid { + 
// sqlCommonClearUserFolderMapping removes all folder mappings for the user.
func sqlCommonClearUserFolderMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
	q := getClearUserFolderMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, user.Username)
	return err
}

// sqlCommonClearGroupFolderMapping removes all folder mappings for the group.
func sqlCommonClearGroupFolderMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
	q := getClearGroupFolderMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, group.Name)
	return err
}

// sqlCommonClearUserGroupMapping removes all group memberships for the user.
func sqlCommonClearUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
	q := getClearUserGroupMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, user.Username)
	return err
}

// sqlCommonAddUserFolderMapping links a virtual folder to a user, preserving
// the caller-provided sort order.
func sqlCommonAddUserFolderMapping(ctx context.Context, user *User, folder *vfs.VirtualFolder, sortOrder int, dbHandle sqlQuerier) error {
	q := getAddUserFolderMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, user.Username, sortOrder)
	return err
}

// sqlCommonClearAdminGroupMapping removes all group mappings for the admin.
func sqlCommonClearAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle sqlQuerier) error {
	q := getClearAdminGroupMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, admin.Username)
	return err
}

// sqlCommonAddGroupFolderMapping links a virtual folder to a group, preserving
// the caller-provided sort order.
func sqlCommonAddGroupFolderMapping(ctx context.Context, group *Group, folder *vfs.VirtualFolder, sortOrder int,
	dbHandle sqlQuerier,
) error {
	q := getAddGroupFolderMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q,
		folder.VirtualPath, folder.QuotaSize, folder.QuotaFiles, folder.Name, group.Name, sortOrder)
	return err
}

// sqlCommonAddUserGroupMapping adds a user->group membership of the given type.
func sqlCommonAddUserGroupMapping(ctx context.Context, username, groupName string, groupType, sortOrder int, dbHandle sqlQuerier) error {
	q := getAddUserGroupMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, username, groupName, groupType, sortOrder)
	return err
}

// sqlCommonAddAdminGroupMapping adds an admin->group mapping; the mapping
// options are persisted as a JSON blob.
func sqlCommonAddAdminGroupMapping(ctx context.Context, username, groupName string, mappingOptions AdminGroupMappingOptions,
	sortOrder int, dbHandle sqlQuerier,
) error {
	options, err := json.Marshal(mappingOptions)
	if err != nil {
		return err
	}
	q := getAddAdminGroupMappingQuery()
	_, err = dbHandle.ExecContext(ctx, q, username, groupName, options, sortOrder)
	return err
}

// generateGroupVirtualFoldersMapping replaces the group's folder mappings
// with the ones currently set on the group (clear + re-add, ordered by index).
func generateGroupVirtualFoldersMapping(ctx context.Context, group *Group, dbHandle sqlQuerier) error {
	err := sqlCommonClearGroupFolderMapping(ctx, group, dbHandle)
	if err != nil {
		return err
	}
	for idx := range group.VirtualFolders {
		vfolder := &group.VirtualFolders[idx]
		err = sqlCommonAddGroupFolderMapping(ctx, group, vfolder, idx, dbHandle)
		if err != nil {
			return err
		}
	}
	return err
}

// generateUserVirtualFoldersMapping replaces the user's folder mappings
// with the ones currently set on the user (clear + re-add, ordered by index).
func generateUserVirtualFoldersMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
	err := sqlCommonClearUserFolderMapping(ctx, user, dbHandle)
	if err != nil {
		return err
	}
	for idx := range user.VirtualFolders {
		vfolder := &user.VirtualFolders[idx]
		err = sqlCommonAddUserFolderMapping(ctx, user, vfolder, idx, dbHandle)
		if err != nil {
			return err
		}
	}
	return err
}

// generateUserGroupMapping replaces the user's group memberships with the
// ones currently set on the user (clear + re-add, ordered by index).
func generateUserGroupMapping(ctx context.Context, user *User, dbHandle sqlQuerier) error {
	err := sqlCommonClearUserGroupMapping(ctx, user, dbHandle)
	if err != nil {
		return err
	}
	for idx, group := range user.Groups {
		err = sqlCommonAddUserGroupMapping(ctx, user.Username, group.Name, group.Type, idx, dbHandle)
		if err != nil {
			return err
		}
	}
	return err
}

// generateAdminGroupMapping replaces the admin's group mappings with the
// ones currently set on the admin (clear + re-add, ordered by index).
func generateAdminGroupMapping(ctx context.Context, admin *Admin, dbHandle sqlQuerier) error {
	err := sqlCommonClearAdminGroupMapping(ctx, admin, dbHandle)
	if err != nil {
		return err
	}
	for idx, group := range admin.Groups {
		err = sqlCommonAddAdminGroupMapping(ctx, admin.Username, group.Name, group.Options, idx, dbHandle)
		if err != nil {
			return err
		}
	}
	return err
}
// getDefenderHostsWithScores loads the event scores accumulated since `from`
// for the given host IDs and returns only the hosts that are relevant: hosts
// with a positive score or with an active ban time.
func getDefenderHostsWithScores(ctx context.Context, hosts []DefenderEntry, from int64, idForScores []int64,
	dbHandle sqlQuerier) (
	[]DefenderEntry,
	error,
) {
	if len(idForScores) == 0 {
		return hosts, nil
	}

	hostsWithScores := make(map[int64]int)
	q := getDefenderEventsQuery(idForScores)
	rows, err := dbHandle.QueryContext(ctx, q, from)
	if err != nil {
		providerLog(logger.LevelError, "unable to get score for hosts with id %+v: %v", idForScores, err)
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var hostID int64
		var score int
		err = rows.Scan(&hostID, &score)
		if err != nil {
			providerLog(logger.LevelError, "error scanning host score row: %v", err)
			return hosts, err
		}
		if score > 0 {
			hostsWithScores[hostID] = score
		}
	}

	err = rows.Err()
	if err != nil {
		return hosts, err
	}

	result := make([]DefenderEntry, 0, len(hosts))

	// Keep a host only if it still has a score or an active ban.
	for idx := range hosts {
		hosts[idx].Score = hostsWithScores[hosts[idx].ID]
		if hosts[idx].Score > 0 || !hosts[idx].BanTime.IsZero() {
			result = append(result, hosts[idx])
		}
	}

	return result, nil
}

// getAdminWithGroups is the single-admin convenience wrapper around
// getAdminsWithGroups.
func getAdminWithGroups(ctx context.Context, admin Admin, dbHandle sqlQuerier) (Admin, error) {
	admins, err := getAdminsWithGroups(ctx, []Admin{admin}, dbHandle)
	if err != nil {
		return admin, err
	}
	if len(admins) == 0 {
		return admin, errSQLGroupsAssociation
	}
	return admins[0], err
}

// getAdminsWithGroups populates the Groups field of each admin by joining
// against the admin/group mapping table in a single query.
func getAdminsWithGroups(ctx context.Context, admins []Admin, dbHandle sqlQuerier) ([]Admin, error) {
	if len(admins) == 0 {
		return admins, nil
	}
	adminsGroups := make(map[int64][]AdminGroupMapping)
	q := getRelatedGroupsForAdminsQuery(admins)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var group AdminGroupMapping
		var adminID int64
		var options []byte
		err = rows.Scan(&group.Name, &options, &adminID)
		if err != nil {
			return admins, err
		}
		err = json.Unmarshal(options, &group.Options)
		if err != nil {
			return admins, err
		}
		adminsGroups[adminID] = append(adminsGroups[adminID], group)
	}
	err = rows.Err()
	if err != nil {
		return admins, err
	}
	if len(adminsGroups) == 0 {
		return admins, err
	}
	for idx := range admins {
		ref := &admins[idx]
		ref.Groups = adminsGroups[ref.ID]
	}
	return admins, err
}

// getUserWithVirtualFolders is the single-user convenience wrapper around
// getUsersWithVirtualFolders.
func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
	users, err := getUsersWithVirtualFolders(ctx, []User{user}, dbHandle)
	if err != nil {
		return user, err
	}
	if len(users) == 0 {
		return user, errSQLFoldersAssociation
	}
	return users[0], err
}

// getUsersWithVirtualFolders populates the VirtualFolders field of each user
// by joining against the user/folder mapping table in a single query.
func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
	if len(users) == 0 {
		return users, nil
	}

	usersVirtualFolders := make(map[int64][]vfs.VirtualFolder)
	q := getRelatedFoldersForUsersQuery(users)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		var folder vfs.VirtualFolder
		var userID int64
		var mappedPath, description sql.NullString
		var fsConfig []byte
		err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
			&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &userID, &fsConfig,
			&description)
		if err != nil {
			return users, err
		}
		if mappedPath.Valid {
			folder.MappedPath = mappedPath.String
		}
		if description.Valid {
			folder.Description = description.String
		}
		// A malformed fs config blob is tolerated: the folder keeps the
		// zero-value Filesystem.
		var fs vfs.Filesystem
		err = json.Unmarshal(fsConfig, &fs)
		if err == nil {
			folder.FsConfig = fs
		}
		usersVirtualFolders[userID] = append(usersVirtualFolders[userID], folder)
	}
	err = rows.Err()
	if err != nil {
		return users, err
	}
	if len(usersVirtualFolders) == 0 {
		return users, err
	}
	for idx := range users {
		ref := &users[idx]
		ref.VirtualFolders = usersVirtualFolders[ref.ID]
	}
	return users, err
}
// getUserWithGroups is the single-user convenience wrapper around
// getUsersWithGroups.
func getUserWithGroups(ctx context.Context, user User, dbHandle sqlQuerier) (User, error) {
	users, err := getUsersWithGroups(ctx, []User{user}, dbHandle)
	if err != nil {
		return user, err
	}
	if len(users) == 0 {
		return user, errSQLGroupsAssociation
	}
	return users[0], err
}

// getUsersWithGroups populates the Groups field of each user by joining
// against the user/group mapping table in a single query.
func getUsersWithGroups(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
	if len(users) == 0 {
		return users, nil
	}
	usersGroups := make(map[int64][]sdk.GroupMapping)
	q := getRelatedGroupsForUsersQuery(users)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var group sdk.GroupMapping
		var userID int64
		err = rows.Scan(&group.Name, &group.Type, &userID)
		if err != nil {
			return users, err
		}
		usersGroups[userID] = append(usersGroups[userID], group)
	}
	err = rows.Err()
	if err != nil {
		return users, err
	}
	if len(usersGroups) == 0 {
		return users, err
	}
	for idx := range users {
		ref := &users[idx]
		ref.Groups = usersGroups[ref.ID]
	}
	return users, err
}

// getGroupWithUsers is the single-group convenience wrapper around
// getGroupsWithUsers.
func getGroupWithUsers(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
	groups, err := getGroupsWithUsers(ctx, []Group{group}, dbHandle)
	if err != nil {
		return group, err
	}
	if len(groups) == 0 {
		return group, errSQLUsersAssociation
	}
	return groups[0], err
}

// getRoleWithUsers is the single-role convenience wrapper around
// getRolesWithUsers.
func getRoleWithUsers(ctx context.Context, role Role, dbHandle sqlQuerier) (Role, error) {
	roles, err := getRolesWithUsers(ctx, []Role{role}, dbHandle)
	if err != nil {
		return role, err
	}
	if len(roles) == 0 {
		return role, errors.New("unable to associate users with role")
	}
	return roles[0], err
}

// getRoleWithAdmins is the single-role convenience wrapper around
// getRolesWithAdmins.
func getRoleWithAdmins(ctx context.Context, role Role, dbHandle sqlQuerier) (Role, error) {
	roles, err := getRolesWithAdmins(ctx, []Role{role}, dbHandle)
	if err != nil {
		return role, err
	}
	if len(roles) == 0 {
		return role, errors.New("unable to associate admins with role")
	}
	return roles[0], err
}

// getGroupWithAdmins is the single-group convenience wrapper around
// getGroupsWithAdmins.
func getGroupWithAdmins(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
	groups, err := getGroupsWithAdmins(ctx, []Group{group}, dbHandle)
	if err != nil {
		return group, err
	}
	if len(groups) == 0 {
		return group, errSQLUsersAssociation
	}
	return groups[0], err
}

// getGroupWithVirtualFolders is the single-group convenience wrapper around
// getGroupsWithVirtualFolders.
func getGroupWithVirtualFolders(ctx context.Context, group Group, dbHandle sqlQuerier) (Group, error) {
	groups, err := getGroupsWithVirtualFolders(ctx, []Group{group}, dbHandle)
	if err != nil {
		return group, err
	}
	if len(groups) == 0 {
		return group, errSQLFoldersAssociation
	}
	return groups[0], err
}

// getGroupsWithVirtualFolders populates the VirtualFolders field of each
// group by joining against the group/folder mapping table in a single query.
func getGroupsWithVirtualFolders(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
	if len(groups) == 0 {
		return groups, nil
	}
	q := getRelatedFoldersForGroupsQuery(groups)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	groupsVirtualFolders := make(map[int64][]vfs.VirtualFolder)

	for rows.Next() {
		var groupID int64
		var folder vfs.VirtualFolder
		var mappedPath, description sql.NullString
		var fsConfig []byte
		err = rows.Scan(&folder.ID, &folder.Name, &mappedPath, &folder.UsedQuotaSize, &folder.UsedQuotaFiles,
			&folder.LastQuotaUpdate, &folder.VirtualPath, &folder.QuotaSize, &folder.QuotaFiles, &groupID, &fsConfig,
			&description)
		if err != nil {
			return groups, err
		}
		if mappedPath.Valid {
			folder.MappedPath = mappedPath.String
		}
		if description.Valid {
			folder.Description = description.String
		}
		// A malformed fs config blob is tolerated: the folder keeps the
		// zero-value Filesystem.
		var fs vfs.Filesystem
		err = json.Unmarshal(fsConfig, &fs)
		if err == nil {
			folder.FsConfig = fs
		}
		groupsVirtualFolders[groupID] = append(groupsVirtualFolders[groupID], folder)
	}
	err = rows.Err()
	if err != nil {
		return groups, err
	}
	if len(groupsVirtualFolders) == 0 {
		return groups, err
	}
	for idx := range groups {
		ref := &groups[idx]
		ref.VirtualFolders = groupsVirtualFolders[ref.ID]
	}
	return groups, err
}
// getGroupsWithUsers populates the Users field of each group with the
// usernames that belong to it, using a single join query.
func getGroupsWithUsers(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
	if len(groups) == 0 {
		return groups, nil
	}
	q := getRelatedUsersForGroupsQuery(groups)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	groupsUsers := make(map[int64][]string)

	for rows.Next() {
		var username string
		var groupID int64
		err = rows.Scan(&groupID, &username)
		if err != nil {
			return groups, err
		}
		groupsUsers[groupID] = append(groupsUsers[groupID], username)
	}
	err = rows.Err()
	if err != nil {
		return groups, err
	}
	if len(groupsUsers) == 0 {
		return groups, err
	}
	for idx := range groups {
		ref := &groups[idx]
		ref.Users = groupsUsers[ref.ID]
	}
	return groups, err
}

// getRolesWithUsers populates the Users field of each role with the
// usernames assigned to it, using a single join query.
func getRolesWithUsers(ctx context.Context, roles []Role, dbHandle sqlQuerier) ([]Role, error) {
	if len(roles) == 0 {
		return roles, nil
	}
	rows, err := dbHandle.QueryContext(ctx, getUsersWithRolesQuery(roles))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	rolesUsers := make(map[int64][]string)
	for rows.Next() {
		var roleID int64
		var username string
		err = rows.Scan(&roleID, &username)
		if err != nil {
			return roles, err
		}
		rolesUsers[roleID] = append(rolesUsers[roleID], username)
	}
	err = rows.Err()
	if err != nil {
		return roles, err
	}
	if len(rolesUsers) > 0 {
		for idx := range roles {
			ref := &roles[idx]
			ref.Users = rolesUsers[ref.ID]
		}
	}
	return roles, nil
}

// getRolesWithAdmins populates the Admins field of each role with the
// admin usernames assigned to it, using a single join query.
func getRolesWithAdmins(ctx context.Context, roles []Role, dbHandle sqlQuerier) ([]Role, error) {
	if len(roles) == 0 {
		return roles, nil
	}
	rows, err := dbHandle.QueryContext(ctx, getAdminsWithRolesQuery(roles))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	rolesAdmins := make(map[int64][]string)
	for rows.Next() {
		var roleID int64
		var username string
		err = rows.Scan(&roleID, &username)
		if err != nil {
			return roles, err
		}
		rolesAdmins[roleID] = append(rolesAdmins[roleID], username)
	}
	if err = rows.Err(); err != nil {
		return roles, err
	}
	if len(rolesAdmins) > 0 {
		for idx := range roles {
			ref := &roles[idx]
			ref.Admins = rolesAdmins[ref.ID]
		}
	}
	return roles, nil
}

// getGroupsWithAdmins populates the Admins field of each group with the
// admin usernames mapped to it, using a single join query.
func getGroupsWithAdmins(ctx context.Context, groups []Group, dbHandle sqlQuerier) ([]Group, error) {
	if len(groups) == 0 {
		return groups, nil
	}
	q := getRelatedAdminsForGroupsQuery(groups)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	groupsAdmins := make(map[int64][]string)
	for rows.Next() {
		var groupID int64
		var username string
		err = rows.Scan(&groupID, &username)
		if err != nil {
			return groups, err
		}
		groupsAdmins[groupID] = append(groupsAdmins[groupID], username)
	}
	err = rows.Err()
	if err != nil {
		return groups, err
	}
	if len(groupsAdmins) > 0 {
		for idx := range groups {
			ref := &groups[idx]
			ref.Admins = groupsAdmins[ref.ID]
		}
	}
	return groups, nil
}

// getVirtualFoldersWithGroups populates the Groups field of each folder with
// the names of the groups referencing it. Unlike most helpers here it
// creates its own query context since callers don't pass one.
func getVirtualFoldersWithGroups(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
	if len(folders) == 0 {
		return folders, nil
	}
	vFoldersGroups := make(map[int64][]string)
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getRelatedGroupsForFoldersQuery(folders)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		var name string
		var folderID int64
		err = rows.Scan(&folderID, &name)
		if err != nil {
			return folders, err
		}
		vFoldersGroups[folderID] = append(vFoldersGroups[folderID], name)
	}
	err = rows.Err()
	if err != nil {
		return folders, err
	}
	if len(vFoldersGroups) == 0 {
		return folders, err
	}
	for idx := range folders {
		ref := &folders[idx]
		ref.Groups = vFoldersGroups[ref.ID]
	}
	return folders, err
}
// getVirtualFoldersWithUsers populates the Users field of each folder with
// the usernames referencing it; it creates its own query context.
func getVirtualFoldersWithUsers(folders []vfs.BaseVirtualFolder, dbHandle sqlQuerier) ([]vfs.BaseVirtualFolder, error) {
	if len(folders) == 0 {
		return folders, nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getRelatedUsersForFoldersQuery(folders)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	vFoldersUsers := make(map[int64][]string)
	for rows.Next() {
		var username string
		var folderID int64
		err = rows.Scan(&folderID, &username)
		if err != nil {
			return folders, err
		}
		vFoldersUsers[folderID] = append(vFoldersUsers[folderID], username)
	}
	err = rows.Err()
	if err != nil {
		return folders, err
	}
	if len(vFoldersUsers) == 0 {
		return folders, err
	}
	for idx := range folders {
		ref := &folders[idx]
		ref.Users = vFoldersUsers[ref.ID]
	}
	return folders, err
}

// sqlCommonUpdateFolderQuota adds the given deltas to the folder's used
// quota (or resets it when reset is true) and refreshes the last-update
// timestamp. The outcome is logged at debug/warn level.
func sqlCommonUpdateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUpdateFolderQuotaQuery(reset)
	_, err := dbHandle.ExecContext(ctx, q, sizeAdd, filesAdd, util.GetTimeAsMsSinceEpoch(time.Now()), name)
	if err == nil {
		providerLog(logger.LevelDebug, "quota updated for folder %q, files increment: %d size increment: %d is reset? %t",
			name, filesAdd, sizeAdd, reset)
	} else {
		providerLog(logger.LevelWarn, "error updating quota for folder %q: %v", name, err)
	}
	return err
}

// sqlCommonGetFolderUsedQuota returns the used files and size for the folder
// with the given mapped path.
func sqlCommonGetFolderUsedQuota(mappedPath string, dbHandle *sql.DB) (int, int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getQuotaFolderQuery()
	var usedFiles int
	var usedSize int64
	err := dbHandle.QueryRowContext(ctx, q, mappedPath).Scan(&usedSize, &usedFiles)
	if err != nil {
		providerLog(logger.LevelError, "error getting quota for folder: %v, error: %v", mappedPath, err)
		return 0, 0, err
	}
	return usedFiles, usedSize, err
}

// getAPIKeyWithRelatedFields resolves the User or Admin name for the key,
// choosing the scope from the populated internal ID.
func getAPIKeyWithRelatedFields(ctx context.Context, apiKey APIKey, dbHandle sqlQuerier) (APIKey, error) {
	var apiKeys []APIKey
	var err error

	scope := APIKeyScopeAdmin
	if apiKey.userID > 0 {
		scope = APIKeyScopeUser
	}
	apiKeys, err = getRelatedValuesForAPIKeys(ctx, []APIKey{apiKey}, dbHandle, scope)
	if err != nil {
		return apiKey, err
	}
	if len(apiKeys) > 0 {
		apiKey = apiKeys[0]
	}
	return apiKey, nil
}

// getRelatedValuesForAPIKeys fills User (user scope) or Admin (admin scope)
// on each key by resolving the related principal names in a single query.
func getRelatedValuesForAPIKeys(ctx context.Context, apiKeys []APIKey, dbHandle sqlQuerier, scope APIKeyScope) ([]APIKey, error) {
	if len(apiKeys) == 0 {
		return apiKeys, nil
	}
	values := make(map[int64]string)
	var q string
	if scope == APIKeyScopeUser {
		q = getRelatedUsersForAPIKeysQuery(apiKeys)
	} else {
		q = getRelatedAdminsForAPIKeysQuery(apiKeys)
	}
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for rows.Next() {
		var valueID int64
		var valueName string
		err = rows.Scan(&valueID, &valueName)
		if err != nil {
			return apiKeys, err
		}
		values[valueID] = valueName
	}
	err = rows.Err()
	if err != nil {
		return apiKeys, err
	}
	if len(values) == 0 {
		return apiKeys, nil
	}
	for idx := range apiKeys {
		ref := &apiKeys[idx]
		if scope == APIKeyScopeUser {
			ref.User = values[ref.userID]
		} else {
			ref.Admin = values[ref.adminID]
		}
	}
	return apiKeys, nil
}
values[ref.userID] + } else { + ref.Admin = values[ref.adminID] + } + } + return apiKeys, nil +} + +func sqlCommonGetAPIKeyRelatedIDs(apiKey *APIKey) (sql.NullInt64, sql.NullInt64, error) { + var userID, adminID sql.NullInt64 + if apiKey.User != "" { + u, err := provider.userExists(apiKey.User, "") + if err != nil { + return userID, adminID, util.NewGenericError(fmt.Sprintf("unable to validate user %v", apiKey.User)) + } + userID.Valid = true + userID.Int64 = u.ID + } + if apiKey.Admin != "" { + a, err := provider.adminExists(apiKey.Admin) + if err != nil { + return userID, adminID, util.NewValidationError(fmt.Sprintf("unable to validate admin %v", apiKey.Admin)) + } + adminID.Valid = true + adminID.Int64 = a.ID + } + return userID, adminID, nil +} + +func sqlCommonAddSession(session Session, dbHandle *sql.DB) error { + if err := session.validate(); err != nil { + return err + } + data, err := json.Marshal(session.Data) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getAddSessionQuery() + _, err = dbHandle.ExecContext(ctx, q, session.Key, data, session.Type, session.Timestamp) + return err +} + +func sqlCommonGetSession(key string, sessionType SessionType, dbHandle sqlQuerier) (Session, error) { + var session Session + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + q := getSessionQuery() + var data []byte // type hint, some driver will use string instead of []byte if the type is any + err := dbHandle.QueryRowContext(ctx, q, key, sessionType).Scan(&session.Key, &data, &session.Type, &session.Timestamp) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return session, util.NewRecordNotFoundError(err.Error()) + } + return session, err + } + session.Data = data + return session, nil +} + +func sqlCommonDeleteSession(key string, sessionType SessionType, dbHandle *sql.DB) error { + ctx, cancel := 
// sqlCommonCleanupSessions deletes all sessions of the given type older
// than the provided timestamp.
func sqlCommonCleanupSessions(sessionType SessionType, before int64, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getCleanupSessionsQuery()
	_, err := dbHandle.ExecContext(ctx, q, sessionType, before)
	return err
}

// getActionsWithRuleNames populates the Rules field of each action with the
// names of the event rules referencing it, using a single join query.
func getActionsWithRuleNames(ctx context.Context, actions []BaseEventAction, dbHandle sqlQuerier,
) ([]BaseEventAction, error) {
	if len(actions) == 0 {
		return actions, nil
	}
	q := getRelatedRulesForActionsQuery(actions)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	actionsRules := make(map[int64][]string)
	for rows.Next() {
		var name string
		var actionID int64
		if err = rows.Scan(&actionID, &name); err != nil {
			return nil, err
		}
		actionsRules[actionID] = append(actionsRules[actionID], name)
	}
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	if len(actionsRules) == 0 {
		return actions, nil
	}
	for idx := range actions {
		ref := &actions[idx]
		ref.Rules = actionsRules[ref.ID]
	}
	return actions, nil
}

// getRulesWithActions populates the Actions field of each rule; base and
// type-specific options are stored as JSON blobs and must unmarshal cleanly.
func getRulesWithActions(ctx context.Context, rules []EventRule, dbHandle sqlQuerier) ([]EventRule, error) {
	if len(rules) == 0 {
		return rules, nil
	}
	rulesActions := make(map[int64][]EventAction)
	q := getRelatedActionsForRulesQuery(rules)
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var action EventAction
		var ruleID int64
		var description sql.NullString
		var baseOptions, options []byte
		err = rows.Scan(&action.ID, &action.Name, &description, &action.Type, &baseOptions, &options,
			&action.Order, &ruleID)
		if err != nil {
			return rules, err
		}
		if len(baseOptions) > 0 {
			err = json.Unmarshal(baseOptions, &action.BaseEventAction.Options)
			if err != nil {
				return rules, err
			}
		}
		if len(options) > 0 {
			err = json.Unmarshal(options, &action.Options)
			if err != nil {
				return rules, err
			}
		}
		// Ensure secret fields are non-nil so later rendering/encryption
		// code can operate on them.
		action.BaseEventAction.Options.SetEmptySecretsIfNil()
		rulesActions[ruleID] = append(rulesActions[ruleID], action)
	}
	err = rows.Err()
	if err != nil {
		return rules, err
	}
	if len(rulesActions) == 0 {
		return rules, nil
	}
	for idx := range rules {
		ref := &rules[idx]
		ref.Actions = rulesActions[ref.ID]
	}
	return rules, nil
}

// generateEventRuleActionsMapping replaces the rule's action mappings with
// the ones currently set on the rule (clear + re-add, preserving order).
func generateEventRuleActionsMapping(ctx context.Context, rule *EventRule, dbHandle sqlQuerier) error {
	q := getClearRuleActionMappingQuery()
	_, err := dbHandle.ExecContext(ctx, q, rule.Name)
	if err != nil {
		return err
	}
	for _, action := range rule.Actions {
		options, err := json.Marshal(action.Options)
		if err != nil {
			return err
		}
		q = getAddRuleActionMappingQuery()
		_, err = dbHandle.ExecContext(ctx, q, rule.Name, action.Name, action.Order, options)
		if err != nil {
			return err
		}
	}
	return nil
}

// sqlCommonGetEventActionByName returns the action with the given name
// together with the names of the rules that reference it.
func sqlCommonGetEventActionByName(name string, dbHandle sqlQuerier) (BaseEventAction, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getEventActionByNameQuery()
	row := dbHandle.QueryRowContext(ctx, q, name)

	action, err := getEventActionFromDbRow(row)
	if err != nil {
		return action, err
	}
	actions, err := getActionsWithRuleNames(ctx, []BaseEventAction{action}, dbHandle)
	if err != nil {
		return action, err
	}
	if len(actions) != 1 {
		return action, fmt.Errorf("unable to associate rules with action %q", name)
	}
	return actions[0], nil
}

// sqlCommonDumpEventActions returns all event actions without rule
// associations, e.g. for backups.
func sqlCommonDumpEventActions(dbHandle sqlQuerier) ([]BaseEventAction, error) {
	actions := make([]BaseEventAction, 0, 10)
	ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
	defer cancel()

	q := getDumpEventActionsQuery()
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return actions, err
	}
	defer rows.Close()

	for rows.Next() {
		action, err := getEventActionFromDbRow(rows)
		if err != nil {
			return actions, err
		}
		actions = append(actions, action)
	}
	return actions, rows.Err()
}
// sqlCommonGetEventActions returns event actions with limit/offset
// pagination; in minimal mode only ID and Name are populated and the
// rule associations are skipped.
func sqlCommonGetEventActions(limit int, offset int, order string, minimal bool,
	dbHandle sqlQuerier,
) ([]BaseEventAction, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getEventsActionsQuery(order, minimal)

	actions := make([]BaseEventAction, 0, limit)
	rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
	if err != nil {
		return actions, err
	}
	defer rows.Close()

	for rows.Next() {
		var action BaseEventAction
		if minimal {
			err = rows.Scan(&action.ID, &action.Name)
		} else {
			action, err = getEventActionFromDbRow(rows)
		}
		if err != nil {
			return actions, err
		}
		actions = append(actions, action)
	}
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	if minimal {
		return actions, nil
	}
	actions, err = getActionsWithRuleNames(ctx, actions, dbHandle)
	if err != nil {
		return nil, err
	}
	// Hide secrets before returning full objects.
	for idx := range actions {
		actions[idx].PrepareForRendering()
	}
	return actions, nil
}

// sqlCommonAddEventAction validates the action and inserts it; Options is
// persisted as a JSON blob.
func sqlCommonAddEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
	if err := action.validate(); err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getAddEventActionQuery()
	options, err := json.Marshal(action.Options)
	if err != nil {
		return err
	}
	_, err = dbHandle.ExecContext(ctx, q, action.Name, action.Description, action.Type, options)
	return err
}

// sqlCommonUpdateEventAction updates the action and, in the same
// transaction, bumps the update timestamp of the rules referencing it so
// other nodes reload them.
func sqlCommonUpdateEventAction(action *BaseEventAction, dbHandle *sql.DB) error {
	if err := action.validate(); err != nil {
		return err
	}
	options, err := json.Marshal(action.Options)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
		q := getUpdateEventActionQuery()
		res, err := tx.ExecContext(ctx, q, action.Description, action.Type, options, action.Name)
		if err != nil {
			return err
		}
		if err := sqlCommonRequireRowAffected(res); err != nil {
			return err
		}
		q = getUpdateRulesTimestampQuery()
		_, err = tx.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), action.Name)
		return err
	})
}

// sqlCommonDeleteEventAction removes the action matching action.Name; it is
// an error if no row is affected.
func sqlCommonDeleteEventAction(action BaseEventAction, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDeleteEventActionQuery()
	res, err := dbHandle.ExecContext(ctx, q, action.Name)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonGetEventRuleByName returns the rule with the given name together
// with its ordered actions.
func sqlCommonGetEventRuleByName(name string, dbHandle sqlQuerier) (EventRule, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getEventRulesByNameQuery()
	row := dbHandle.QueryRowContext(ctx, q, name)
	rule, err := getEventRuleFromDbRow(row)
	if err != nil {
		return rule, err
	}
	rules, err := getRulesWithActions(ctx, []EventRule{rule}, dbHandle)
	if err != nil {
		return rule, err
	}
	if len(rules) != 1 {
		return rule, fmt.Errorf("unable to associate rule %q with actions", name)
	}
	return rules[0], nil
}

// sqlCommonDumpEventRules returns all event rules with their actions,
// e.g. for backups.
func sqlCommonDumpEventRules(dbHandle sqlQuerier) ([]EventRule, error) {
	rules := make([]EventRule, 0, 10)
	ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
	defer cancel()

	q := getDumpEventRulesQuery()
	rows, err := dbHandle.QueryContext(ctx, q)
	if err != nil {
		return rules, err
	}
	defer rows.Close()

	for rows.Next() {
		rule, err := getEventRuleFromDbRow(rows)
		if err != nil {
			return rules, err
		}
		rules = append(rules, rule)
	}
	err = rows.Err()
	if err != nil {
		return rules, err
	}
	return getRulesWithActions(ctx, rules, dbHandle)
}
// sqlCommonGetRecentlyUpdatedRules returns the rules updated after the given
// timestamp, with their actions; used to refresh the in-memory rule cache.
func sqlCommonGetRecentlyUpdatedRules(after int64, dbHandle sqlQuerier) ([]EventRule, error) {
	rules := make([]EventRule, 0, 10)
	ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
	defer cancel()

	q := getRecentlyUpdatedRulesQuery()
	rows, err := dbHandle.QueryContext(ctx, q, after)
	if err != nil {
		return rules, err
	}
	defer rows.Close()

	for rows.Next() {
		rule, err := getEventRuleFromDbRow(rows)
		if err != nil {
			return rules, err
		}
		rules = append(rules, rule)
	}
	err = rows.Err()
	if err != nil {
		return rules, err
	}
	return getRulesWithActions(ctx, rules, dbHandle)
}

// sqlCommonGetEventRules returns event rules with limit/offset pagination,
// including their actions, prepared for rendering.
func sqlCommonGetEventRules(limit int, offset int, order string, dbHandle sqlQuerier) ([]EventRule, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getEventRulesQuery(order)

	rules := make([]EventRule, 0, limit)
	rows, err := dbHandle.QueryContext(ctx, q, limit, offset)
	if err != nil {
		return rules, err
	}
	defer rows.Close()

	for rows.Next() {
		rule, err := getEventRuleFromDbRow(rows)
		if err != nil {
			return rules, err
		}
		rules = append(rules, rule)
	}
	err = rows.Err()
	if err != nil {
		return rules, err
	}
	rules, err = getRulesWithActions(ctx, rules, dbHandle)
	if err != nil {
		return rules, err
	}
	for idx := range rules {
		rules[idx].PrepareForRendering()
	}
	return rules, nil
}

// sqlCommonAddEventRule validates the rule and inserts it with its action
// mappings in a single transaction. With a shared provider a previously
// soft-deleted rule with the same name is purged first.
func sqlCommonAddEventRule(rule *EventRule, dbHandle *sql.DB) error {
	if err := rule.validate(); err != nil {
		return err
	}
	conditions, err := json.Marshal(rule.Conditions)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
		if config.IsShared == 1 {
			_, err := tx.ExecContext(ctx, getRemoveSoftDeletedRuleQuery(), rule.Name)
			if err != nil {
				return err
			}
		}
		q := getAddEventRuleQuery()
		_, err := tx.ExecContext(ctx, q, rule.Name, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
			util.GetTimeAsMsSinceEpoch(time.Now()), rule.Trigger, conditions, rule.Status)
		if err != nil {
			return err
		}
		return generateEventRuleActionsMapping(ctx, rule, tx)
	})
}

// sqlCommonUpdateEventRule updates the rule and regenerates its action
// mappings in a single transaction.
func sqlCommonUpdateEventRule(rule *EventRule, dbHandle *sql.DB) error {
	if err := rule.validate(); err != nil {
		return err
	}
	conditions, err := json.Marshal(rule.Conditions)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
		q := getUpdateEventRuleQuery()
		_, err := tx.ExecContext(ctx, q, rule.Description, util.GetTimeAsMsSinceEpoch(time.Now()),
			rule.Trigger, conditions, rule.Status, rule.Name)
		if err != nil {
			return err
		}
		return generateEventRuleActionsMapping(ctx, rule, tx)
	})
}

// sqlCommonDeleteEventRule deletes the rule in a transaction. With
// softDelete the action mappings are cleared and the row is only marked
// deleted (so other nodes can notice); otherwise the row and any pending
// task for the rule are removed.
func sqlCommonDeleteEventRule(rule EventRule, softDelete bool, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
		if softDelete {
			q := getClearRuleActionMappingQuery()
			_, err := tx.ExecContext(ctx, q, rule.Name)
			if err != nil {
				return err
			}
		}
		q := getDeleteEventRuleQuery(softDelete)
		if softDelete {
			ts := util.GetTimeAsMsSinceEpoch(time.Now())
			res, err := tx.ExecContext(ctx, q, ts, ts, rule.Name)
			if err != nil {
				return err
			}
			return sqlCommonRequireRowAffected(res)
		}
		res, err := tx.ExecContext(ctx, q, rule.Name)
		if err != nil {
			return err
		}
		if err = sqlCommonRequireRowAffected(res); err != nil {
			return err
		}
		return sqlCommonDeleteTask(rule.Name, tx)
	})
}
return err
		}
		return sqlCommonDeleteTask(rule.Name, tx)
	})
}

// sqlCommonGetTaskByName returns the task with the given name.
// A util.RecordNotFoundError is returned if no matching row exists.
func sqlCommonGetTaskByName(name string, dbHandle sqlQuerier) (Task, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	task := Task{
		Name: name,
	}
	q := getTaskByNameQuery()
	row := dbHandle.QueryRowContext(ctx, q, name)
	err := row.Scan(&task.UpdateAt, &task.Version)
	if err != nil {
		// map "no rows" to the provider not-found error, any other scan
		// error is returned unchanged below
		if errors.Is(err, sql.ErrNoRows) {
			return task, util.NewRecordNotFoundError(err.Error())
		}
	}
	return task, err
}

// sqlCommonAddTask inserts a new task with the current timestamp.
func sqlCommonAddTask(name string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getAddTaskQuery()
	_, err := dbHandle.ExecContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now()))
	return err
}

// sqlCommonUpdateTask updates the named task passing the given version to
// the query (presumably an optimistic-concurrency guard, the SQL text is
// defined elsewhere — confirm). Not-found is returned if no row is affected.
func sqlCommonUpdateTask(name string, version int64, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUpdateTaskQuery()
	res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name, version)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonUpdateTaskTimestamp refreshes the update time of the named task.
// Not-found is returned if no row is affected.
func sqlCommonUpdateTaskTimestamp(name string, dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUpdateTaskTimestampQuery()
	res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), name)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonDeleteTask removes the named task. It does not fail if no
// matching row exists.
func sqlCommonDeleteTask(name string, dbHandle sqlQuerier) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDeleteTaskQuery()
	_, err := dbHandle.ExecContext(ctx, q, name)
	return err
}

// sqlCommonAddNode registers the current node within the cluster, storing
// its data as JSON together with creation/update timestamps.
func sqlCommonAddNode(dbHandle *sql.DB) error {
	if err := currentNode.validate(); err != nil {
		return fmt.Errorf("unable to register cluster node: %w", err)
	}
	data, err := json.Marshal(currentNode.Data)
	if err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getAddNodeQuery()
	_, err = dbHandle.ExecContext(ctx, q, currentNode.Name, data, util.GetTimeAsMsSinceEpoch(time.Now()),
		util.GetTimeAsMsSinceEpoch(time.Now()))
	if err != nil {
		return fmt.Errorf("unable to register cluster node: %w", err)
	}
	providerLog(logger.LevelInfo, "registered as cluster node %q, port: %d, proto: %s",
		currentNode.Name, currentNode.Data.Port, currentNode.Data.Proto)

	return nil
}

// sqlCommonGetNodeByName returns the node with the given name. The query
// also receives a cutoff timestamp (now + activeNodeTimeDiff), presumably to
// match only recently updated, i.e. active, nodes.
func sqlCommonGetNodeByName(name string, dbHandle *sql.DB) (Node, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	var data []byte
	var node Node

	q := getNodeByNameQuery()
	row := dbHandle.QueryRowContext(ctx, q, name, util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
	err := row.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return node, util.NewRecordNotFoundError(err.Error())
		}
		return node, err
	}
	err = json.Unmarshal(data, &node.Data)
	return node, err
}

// sqlCommonGetNodes returns the cluster nodes. The query receives the
// current node name and an activity cutoff timestamp, presumably to exclude
// the current node and stale nodes — the SQL text is defined elsewhere.
func sqlCommonGetNodes(dbHandle *sql.DB) ([]Node, error) {
	var nodes []Node
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getNodesQuery()
	rows, err := dbHandle.QueryContext(ctx, q, currentNode.Name,
		util.GetTimeAsMsSinceEpoch(time.Now().Add(activeNodeTimeDiff)))
	if err != nil {
		return nodes, err
	}
	defer rows.Close()
	for rows.Next() {
		var node Node
		var data []byte

		err = rows.Scan(&node.Name, &data, &node.CreatedAt, &node.UpdatedAt)
		if err != nil {
			return nodes, err
		}
		err = json.Unmarshal(data, &node.Data)
		if err != nil {
			return nodes, err
		}
		nodes = append(nodes, node)
	}

	return nodes, rows.Err()
}

// sqlCommonUpdateNodeTimestamp refreshes the heartbeat timestamp of the
// current node. Not-found is returned if no row is affected.
func sqlCommonUpdateNodeTimestamp(dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUpdateNodeTimestampQuery()
	res, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now()), currentNode.Name)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonCleanupNodes removes stale nodes, the cutoff passed to the query
// is computed as now + 10*activeNodeTimeDiff.
func sqlCommonCleanupNodes(dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getCleanupNodesQuery()
	_, err := dbHandle.ExecContext(ctx, q, util.GetTimeAsMsSinceEpoch(time.Now().Add(10*activeNodeTimeDiff)))
	return err
}

// sqlCommonGetConfigs loads the global configs stored as a JSON blob.
func sqlCommonGetConfigs(dbHandle sqlQuerier) (Configs, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	var result Configs
	var configs []byte
	q := getConfigsQuery()
	err := dbHandle.QueryRowContext(ctx, q).Scan(&configs)
	if err != nil {
		return result, err
	}
	err = json.Unmarshal(configs, &result)
	return result, err
}

// sqlCommonSetConfigs validates and persists the global configs as JSON.
// Not-found is returned if no row is affected.
func sqlCommonSetConfigs(configs *Configs, dbHandle *sql.DB) error {
	if err := configs.validate(); err != nil {
		return err
	}
	asJSON, err := json.Marshal(configs)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getUpdateConfigsQuery()
	res, err := dbHandle.ExecContext(ctx, q, asJSON)
	if err != nil {
		return err
	}
	return sqlCommonRequireRowAffected(res)
}

// sqlCommonGetDatabaseVersion returns the current schema version. With
// showInitWarn a console hint is printed when the schema version table
// seems to be missing.
func sqlCommonGetDatabaseVersion(dbHandle sqlQuerier, showInitWarn bool) (schemaVersion, error) {
	var result schemaVersion
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	q := getDatabaseVersionQuery()
	stmt, err := dbHandle.PrepareContext(ctx, q)
	if err != nil {
		providerLog(logger.LevelError, "error preparing database query %q: %v", q, err)
+ if showInitWarn && strings.Contains(err.Error(), sqlTableSchemaVersion) { + logger.WarnToConsole("database query error, did you forgot to run the \"initprovider\" command?") + } + return result, err + } + defer stmt.Close() + row := stmt.QueryRowContext(ctx) + err = row.Scan(&result.Version) + return result, err +} + +func sqlCommonRequireRowAffected(res sql.Result) error { + affected, err := res.RowsAffected() + if err == nil && affected == 0 { + return util.NewRecordNotFoundError(sql.ErrNoRows.Error()) + } + return nil +} + +func sqlCommonUpdateDatabaseVersion(ctx context.Context, dbHandle sqlQuerier, version int) error { + q := getUpdateDBVersionQuery() + _, err := dbHandle.ExecContext(ctx, q, version) + return err +} + +func sqlCommonExecSQLAndUpdateDBVersion(dbHandle *sql.DB, sqlQueries []string, newVersion int, isUp bool) error { + ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + conn, err := dbHandle.Conn(ctx) + if err != nil { + return fmt.Errorf("unable to get connection from pool: %w", err) + } + defer conn.Close() + + if err := sqlAcquireLock(conn); err != nil { + return err + } + defer sqlReleaseLock(conn) + + if newVersion > 0 { + currentVersion, err := sqlCommonGetDatabaseVersion(conn, false) + if err == nil { + if (isUp && currentVersion.Version >= newVersion) || (!isUp && currentVersion.Version <= newVersion) { + providerLog(logger.LevelInfo, "current schema version: %v, requested: %v, did you execute simultaneous migrations?", + currentVersion.Version, newVersion) + return nil + } + } + } + + return sqlCommonExecuteTxOnConn(ctx, conn, func(tx *sql.Tx) error { + for _, q := range sqlQueries { + if strings.TrimSpace(q) == "" { + continue + } + _, err := tx.ExecContext(ctx, q) + if err != nil { + return err + } + } + if newVersion == 0 { + return nil + } + return sqlCommonUpdateDatabaseVersion(ctx, tx, newVersion) + }) +} + +func sqlAcquireLock(dbHandle *sql.Conn) error { + ctx, cancel := 
context.WithTimeout(context.Background(), longSQLQueryTimeout) + defer cancel() + + switch config.Driver { + case PGSQLDataProviderName: + _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_lock(101,1)`) + if err != nil { + return fmt.Errorf("unable to get advisory lock: %w", err) + } + providerLog(logger.LevelInfo, "acquired database lock") + case MySQLDataProviderName: + var lockResult sql.NullInt64 + err := dbHandle.QueryRowContext(ctx, `SELECT GET_LOCK('sftpgo.migration',30)`).Scan(&lockResult) + if err != nil { + return fmt.Errorf("unable to get lock: %w", err) + } + if !lockResult.Valid { + return errors.New("unable to get lock: null value returned") + } + if lockResult.Int64 != 1 { + return fmt.Errorf("unable to get lock, result: %d", lockResult.Int64) + } + providerLog(logger.LevelInfo, "acquired database lock") + } + + return nil +} + +func sqlReleaseLock(dbHandle *sql.Conn) { + ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout) + defer cancel() + + switch config.Driver { + case PGSQLDataProviderName: + _, err := dbHandle.ExecContext(ctx, `SELECT pg_advisory_unlock(101,1)`) + if err != nil { + providerLog(logger.LevelWarn, "unable to release lock: %v", err) + } else { + providerLog(logger.LevelInfo, "released database lock") + } + case MySQLDataProviderName: + _, err := dbHandle.ExecContext(ctx, `SELECT RELEASE_LOCK('sftpgo.migration')`) + if err != nil { + providerLog(logger.LevelWarn, "unable to release lock: %v", err) + } else { + providerLog(logger.LevelInfo, "released database lock") + } + } +} + +func sqlCommonExecuteTxOnConn(ctx context.Context, conn *sql.Conn, txFn func(*sql.Tx) error) error { + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return err + } + + err = txFn(tx) + if err != nil { + tx.Rollback() //nolint:errcheck + return err + } + return tx.Commit() +} + +func sqlCommonExecuteTx(ctx context.Context, dbHandle *sql.DB, txFn func(*sql.Tx) error) error { + if config.Driver == 
CockroachDataProviderName { + return crdb.ExecuteTx(ctx, dbHandle, nil, txFn) + } + + tx, err := dbHandle.BeginTx(ctx, nil) + if err != nil { + return err + } + + err = txFn(tx) + if err != nil { + // we don't change the returned error + tx.Rollback() //nolint:errcheck + return err + } + return tx.Commit() +} diff --git a/internal/dataprovider/sqlite.go b/internal/dataprovider/sqlite.go new file mode 100644 index 00000000..98d2d2b9 --- /dev/null +++ b/internal/dataprovider/sqlite.go @@ -0,0 +1,849 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +//go:build !nosqlite && cgo + +package dataprovider + +import ( + "context" + "crypto/x509" + "database/sql" + "errors" + "fmt" + "path/filepath" + "strings" + "time" + + "github.com/mattn/go-sqlite3" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + sqliteResetSQL = `DROP TABLE IF EXISTS "{{api_keys}}"; +DROP TABLE IF EXISTS "{{users_folders_mapping}}"; +DROP TABLE IF EXISTS "{{users_groups_mapping}}"; +DROP TABLE IF EXISTS "{{admins_groups_mapping}}"; +DROP TABLE IF EXISTS "{{groups_folders_mapping}}"; +DROP TABLE IF EXISTS "{{shares_groups_mapping}}"; +DROP TABLE IF EXISTS "{{admins}}"; +DROP TABLE IF EXISTS "{{folders}}"; +DROP TABLE IF EXISTS "{{shares}}"; +DROP TABLE IF EXISTS "{{users}}"; +DROP TABLE IF EXISTS "{{groups}}"; +DROP TABLE IF EXISTS "{{defender_events}}"; +DROP TABLE IF EXISTS "{{defender_hosts}}"; +DROP TABLE IF EXISTS "{{active_transfers}}"; +DROP TABLE IF EXISTS "{{shared_sessions}}"; +DROP TABLE IF EXISTS "{{rules_actions_mapping}}"; +DROP TABLE IF EXISTS "{{events_rules}}"; +DROP TABLE IF EXISTS "{{events_actions}}"; +DROP TABLE IF EXISTS "{{tasks}}"; +DROP TABLE IF EXISTS "{{roles}}"; +DROP TABLE IF EXISTS "{{ip_lists}}"; +DROP TABLE IF EXISTS "{{configs}}"; +DROP TABLE IF EXISTS "{{schema_version}}"; +` + sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY, "version" integer NOT NULL); +CREATE TABLE "{{roles}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, +"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); +CREATE TABLE "{{admins}}" ("id" integer NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, +"description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL, +"permissions" text NOT NULL, "filters" text 
NULL, "additional_info" text NULL, "last_login" bigint NOT NULL, +"role_id" integer NULL REFERENCES "{{roles}}" ("id") ON DELETE NO ACTION, "created_at" bigint NOT NULL, +"updated_at" bigint NOT NULL); +CREATE TABLE "{{active_transfers}}" ("id" integer NOT NULL PRIMARY KEY, "connection_id" varchar(100) NOT NULL, +"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL, +"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL, +"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL, +"updated_at" bigint NOT NULL); +CREATE TABLE "{{defender_hosts}}" ("id" integer NOT NULL PRIMARY KEY, "ip" varchar(50) NOT NULL UNIQUE, +"ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL); +CREATE TABLE "{{defender_events}}" ("id" integer NOT NULL PRIMARY KEY, "date_time" bigint NOT NULL, +"score" integer NOT NULL, "host_id" integer NOT NULL REFERENCES "{{defender_hosts}}" ("id") ON DELETE CASCADE +DEFERRABLE INITIALLY DEFERRED); +CREATE TABLE "{{folders}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, +"description" varchar(512) NULL, "path" text NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, +"last_quota_update" bigint NOT NULL, "filesystem" text NULL); +CREATE TABLE "{{groups}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, +"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL); +CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL, "type" integer NOT NULL, +"data" text NOT NULL, "timestamp" bigint NOT NULL, PRIMARY KEY ("key", "type")); +CREATE TABLE "{{users}}" ("id" integer NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, +"status" integer NOT NULL, "expiration_date" bigint NOT NULL, "description" varchar(512) NULL, "password" text NULL, +"public_keys" text NULL, "home_dir" text 
NOT NULL, "uid" bigint NOT NULL, "gid" bigint NOT NULL, +"max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, +"used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, +"upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, +"filters" text NULL, "filesystem" text NULL, "additional_info" text NULL, "created_at" bigint NOT NULL, +"updated_at" bigint NOT NULL, "email" varchar(255) NULL, "upload_data_transfer" integer NOT NULL, +"download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL, "used_upload_data_transfer" bigint NOT NULL, +"used_download_data_transfer" bigint NOT NULL, "deleted_at" bigint NOT NULL, "first_download" bigint NOT NULL, +"first_upload" bigint NOT NULL, "last_password_change" bigint NOT NULL, "role_id" integer NULL REFERENCES "{{roles}}" ("id") ON DELETE SET NULL); +CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY, +"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, +"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, +"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "sort_order" integer NOT NULL, +CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id")); +CREATE TABLE "{{users_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY, +"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, +"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE NO ACTION, +"group_type" integer NOT NULL, "sort_order" integer NOT NULL, CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id")); +CREATE TABLE "{{users_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY, +"user_id" 
integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, +"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, +"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "sort_order" integer NOT NULL, +CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id")); +CREATE TABLE "{{shares}}" ("id" integer NOT NULL PRIMARY KEY, "share_id" varchar(60) NOT NULL UNIQUE, +"name" varchar(255) NOT NULL, "description" varchar(512) NULL, "scope" integer NOT NULL, "paths" text NOT NULL, +"created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL, +"password" text NULL, "max_tokens" integer NOT NULL, "used_tokens" integer NOT NULL, "allow_from" text NULL, "options" text NULL, +"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED); +CREATE TABLE "{{api_keys}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL, +"key_id" varchar(50) NOT NULL UNIQUE, "api_key" varchar(255) NOT NULL UNIQUE, "scope" integer NOT NULL, +"created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL, +"description" text NULL, "admin_id" integer NULL REFERENCES "{{admins}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, +"user_id" integer NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED); +CREATE TABLE "{{events_rules}}" ("id" integer NOT NULL PRIMARY KEY, +"name" varchar(255) NOT NULL UNIQUE, "status" integer NOT NULL, "description" varchar(512) NULL, "created_at" bigint NOT NULL, +"updated_at" bigint NOT NULL, "trigger" integer NOT NULL, "conditions" text NOT NULL, "deleted_at" bigint NOT NULL); +CREATE TABLE "{{events_actions}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, +"description" varchar(512) 
NULL, "type" integer NOT NULL, "options" text NOT NULL); +CREATE TABLE "{{rules_actions_mapping}}" ("id" integer NOT NULL PRIMARY KEY, +"rule_id" integer NOT NULL REFERENCES "{{events_rules}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, +"action_id" integer NOT NULL REFERENCES "{{events_actions}}" ("id") ON DELETE NO ACTION DEFERRABLE INITIALLY DEFERRED, +"order" integer NOT NULL, "options" text NOT NULL, +CONSTRAINT "{{prefix}}unique_rule_action_mapping" UNIQUE ("rule_id", "action_id")); +CREATE TABLE "{{tasks}}" ("id" integer NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, +"updated_at" bigint NOT NULL, "version" bigint NOT NULL); +CREATE TABLE "{{admins_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY, +"admin_id" integer NOT NULL REFERENCES "{{admins}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, +"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, +"options" text NOT NULL, "sort_order" integer NOT NULL, CONSTRAINT "{{prefix}}unique_admin_group_mapping" UNIQUE ("admin_id", "group_id")); +CREATE TABLE "{{ip_lists}}" ("id" integer NOT NULL PRIMARY KEY, +"type" integer NOT NULL, "ipornet" varchar(50) NOT NULL, "mode" integer NOT NULL, "description" varchar(512) NULL, +"first" BLOB NOT NULL, "last" BLOB NOT NULL, "ip_type" integer NOT NULL, "protocols" integer NOT NULL, +"created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "deleted_at" bigint NOT NULL, +CONSTRAINT "{{prefix}}unique_ipornet_type_mapping" UNIQUE ("type", "ipornet")); +CREATE TABLE "{{configs}}" ("id" integer NOT NULL PRIMARY KEY, "configs" text NOT NULL); +INSERT INTO {{configs}} (configs) VALUES ('{}'); +CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id"); +CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id"); +CREATE INDEX "{{prefix}}users_folders_mapping_sort_order_idx" ON 
"{{users_folders_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id"); +CREATE INDEX "{{prefix}}users_groups_mapping_user_id_idx" ON "{{users_groups_mapping}}" ("user_id"); +CREATE INDEX "{{prefix}}users_groups_mapping_sort_order_idx" ON "{{users_groups_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}groups_folders_mapping_folder_id_idx" ON "{{groups_folders_mapping}}" ("folder_id"); +CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folders_mapping}}" ("group_id"); +CREATE INDEX "{{prefix}}groups_folders_mapping_sort_order_idx" ON "{{groups_folders_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id"); +CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id"); +CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at"); +CREATE INDEX "{{prefix}}users_deleted_at_idx" ON "{{users}}" ("deleted_at"); +CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id"); +CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at"); +CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time"); +CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time"); +CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id"); +CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id"); +CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id"); +CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at"); +CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); +CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); +CREATE INDEX "{{prefix}}events_rules_updated_at_idx" ON 
"{{events_rules}}" ("updated_at"); +CREATE INDEX "{{prefix}}events_rules_deleted_at_idx" ON "{{events_rules}}" ("deleted_at"); +CREATE INDEX "{{prefix}}events_rules_trigger_idx" ON "{{events_rules}}" ("trigger"); +CREATE INDEX "{{prefix}}rules_actions_mapping_rule_id_idx" ON "{{rules_actions_mapping}}" ("rule_id"); +CREATE INDEX "{{prefix}}rules_actions_mapping_action_id_idx" ON "{{rules_actions_mapping}}" ("action_id"); +CREATE INDEX "{{prefix}}rules_actions_mapping_order_idx" ON "{{rules_actions_mapping}}" ("order"); +CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id"); +CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_mapping}}" ("group_id"); +CREATE INDEX "{{prefix}}admins_groups_mapping_sort_order_idx" ON "{{admins_groups_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}users_role_id_idx" ON "{{users}}" ("role_id"); +CREATE INDEX "{{prefix}}admins_role_id_idx" ON "{{admins}}" ("role_id"); +CREATE INDEX "{{prefix}}ip_lists_type_idx" ON "{{ip_lists}}" ("type"); +CREATE INDEX "{{prefix}}ip_lists_ipornet_idx" ON "{{ip_lists}}" ("ipornet"); +CREATE INDEX "{{prefix}}ip_lists_ip_type_idx" ON "{{ip_lists}}" ("ip_type"); +CREATE INDEX "{{prefix}}ip_lists_ip_updated_at_idx" ON "{{ip_lists}}" ("updated_at"); +CREATE INDEX "{{prefix}}ip_lists_ip_deleted_at_idx" ON "{{ip_lists}}" ("deleted_at"); +CREATE INDEX "{{prefix}}ip_lists_first_last_idx" ON "{{ip_lists}}" ("first", "last"); +INSERT INTO {{schema_version}} (version) VALUES (33); +` + sqliteV34SQL = ` +CREATE TABLE "{{shares_groups_mapping}}" ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "share_id" integer NOT NULL REFERENCES "{{shares}}" ("id") ON DELETE CASCADE, + "group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE, + "permissions" integer NOT NULL, + "sort_order" integer NOT NULL, + CONSTRAINT "{{prefix}}unique_share_group_mapping" UNIQUE ("share_id", "group_id") +); +CREATE INDEX 
"{{prefix}}shares_groups_mapping_sort_order_idx" ON "{{shares_groups_mapping}}" ("sort_order"); +CREATE INDEX "{{prefix}}shares_groups_mapping_group_id_idx" ON "{{shares_groups_mapping}}" ("group_id"); +CREATE INDEX "{{prefix}}shares_groups_mapping_share_id_idx" ON "{{shares_groups_mapping}}" ("share_id"); +` + sqliteV34DownSQL = `DROP TABLE IF EXISTS "{{shares_groups_mapping}}";` +) + +// SQLiteProvider defines the auth provider for SQLite database +type SQLiteProvider struct { + dbHandle *sql.DB +} + +func init() { + version.AddFeature("+sqlite") +} + +func initializeSQLiteProvider(basePath string) error { + var connectionString string + + if config.ConnectionString == "" { + dbPath := config.Name + if !util.IsFileInputValid(dbPath) { + return fmt.Errorf("invalid database path: %q", dbPath) + } + if !filepath.IsAbs(dbPath) { + dbPath = filepath.Join(basePath, dbPath) + } + connectionString = fmt.Sprintf("file:%s?cache=shared&_foreign_keys=1", dbPath) + } else { + connectionString = config.ConnectionString + } + dbHandle, err := sql.Open("sqlite3", connectionString) + if err != nil { + providerLog(logger.LevelError, "error creating sqlite database handler, connection string: %q, error: %v", + connectionString, err) + return err + } + providerLog(logger.LevelDebug, "sqlite database handle created, connection string: %q", connectionString) + dbHandle.SetMaxOpenConns(1) + provider = &SQLiteProvider{dbHandle: dbHandle} + return executePragmaOptimize(dbHandle) +} + +func (p *SQLiteProvider) checkAvailability() error { + return sqlCommonCheckAvailability(p.dbHandle) +} + +func (p *SQLiteProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) { + return sqlCommonValidateUserAndPass(username, password, ip, protocol, p.dbHandle) +} + +func (p *SQLiteProvider) validateUserAndTLSCert(username, protocol string, tlsCert *x509.Certificate) (User, error) { + return sqlCommonValidateUserAndTLSCertificate(username, protocol, tlsCert, p.dbHandle) +} + 
+func (p *SQLiteProvider) validateUserAndPubKey(username string, publicKey []byte, isSSHCert bool) (User, string, error) { + return sqlCommonValidateUserAndPubKey(username, publicKey, isSSHCert, p.dbHandle) +} + +func (p *SQLiteProvider) updateTransferQuota(username string, uploadSize, downloadSize int64, reset bool) error { + return sqlCommonUpdateTransferQuota(username, uploadSize, downloadSize, reset, p.dbHandle) +} + +func (p *SQLiteProvider) updateQuota(username string, filesAdd int, sizeAdd int64, reset bool) error { + return sqlCommonUpdateQuota(username, filesAdd, sizeAdd, reset, p.dbHandle) +} + +func (p *SQLiteProvider) getUsedQuota(username string) (int, int64, int64, int64, error) { + return sqlCommonGetUsedQuota(username, p.dbHandle) +} + +func (p *SQLiteProvider) getAdminSignature(username string) (string, error) { + return sqlCommonGetAdminSignature(username, p.dbHandle) +} + +func (p *SQLiteProvider) getUserSignature(username string) (string, error) { + return sqlCommonGetUserSignature(username, p.dbHandle) +} + +func (p *SQLiteProvider) setUpdatedAt(username string) { + sqlCommonSetUpdatedAt(username, p.dbHandle) +} + +func (p *SQLiteProvider) updateLastLogin(username string) error { + return sqlCommonUpdateLastLogin(username, p.dbHandle) +} + +func (p *SQLiteProvider) updateAdminLastLogin(username string) error { + return sqlCommonUpdateAdminLastLogin(username, p.dbHandle) +} + +func (p *SQLiteProvider) userExists(username, role string) (User, error) { + return sqlCommonGetUserByUsername(username, role, p.dbHandle) +} + +func (p *SQLiteProvider) addUser(user *User) error { + return p.normalizeError(sqlCommonAddUser(user, p.dbHandle), fieldUsername) +} + +func (p *SQLiteProvider) updateUser(user *User) error { + return p.normalizeError(sqlCommonUpdateUser(user, p.dbHandle), -1) +} + +func (p *SQLiteProvider) deleteUser(user User, softDelete bool) error { + return sqlCommonDeleteUser(user, softDelete, p.dbHandle) +} + +func (p *SQLiteProvider) 
// Tail of the SQLiteProvider implementation. Most methods below are thin
// delegation wrappers: they forward to the driver-independent sqlCommon*
// helpers, passing this provider's database handle. Only the schema
// management functions and normalizeError near the end contain
// SQLite-specific behavior.
//
// NOTE(review): the "func (p *SQLiteProvider) " prefix of the first method
// lies before the visible region; the fragment is kept exactly as captured.
updateUserPassword(username, password string) error {
	return sqlCommonUpdateUserPassword(username, password, p.dbHandle)
}

func (p *SQLiteProvider) dumpUsers() ([]User, error) {
	return sqlCommonDumpUsers(p.dbHandle)
}

func (p *SQLiteProvider) getRecentlyUpdatedUsers(after int64) ([]User, error) {
	return sqlCommonGetRecentlyUpdatedUsers(after, p.dbHandle)
}

func (p *SQLiteProvider) getUsers(limit int, offset int, order, role string) ([]User, error) {
	return sqlCommonGetUsers(limit, offset, order, role, p.dbHandle)
}

func (p *SQLiteProvider) getUsersForQuotaCheck(toFetch map[string]bool) ([]User, error) {
	return sqlCommonGetUsersForQuotaCheck(toFetch, p.dbHandle)
}

func (p *SQLiteProvider) dumpFolders() ([]vfs.BaseVirtualFolder, error) {
	return sqlCommonDumpFolders(p.dbHandle)
}

func (p *SQLiteProvider) getFolders(limit, offset int, order string, minimal bool) ([]vfs.BaseVirtualFolder, error) {
	return sqlCommonGetFolders(limit, offset, order, minimal, p.dbHandle)
}

// getFolderByName builds its own bounded context because the shared helper
// takes one explicitly, unlike the other wrappers in this file.
func (p *SQLiteProvider) getFolderByName(name string) (vfs.BaseVirtualFolder, error) {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()
	return sqlCommonGetFolderByName(ctx, name, p.dbHandle)
}

// The add*/update* methods pass a field hint (fieldName, fieldUsername,
// fieldIPNet, or -1 for "none") to normalizeError so that unique-constraint
// violations can be mapped to a field-specific i18n message.
func (p *SQLiteProvider) addFolder(folder *vfs.BaseVirtualFolder) error {
	return p.normalizeError(sqlCommonAddFolder(folder, p.dbHandle), fieldName)
}

func (p *SQLiteProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
	return sqlCommonUpdateFolder(folder, p.dbHandle)
}

func (p *SQLiteProvider) deleteFolder(folder vfs.BaseVirtualFolder) error {
	return sqlCommonDeleteFolder(folder, p.dbHandle)
}

func (p *SQLiteProvider) updateFolderQuota(name string, filesAdd int, sizeAdd int64, reset bool) error {
	return sqlCommonUpdateFolderQuota(name, filesAdd, sizeAdd, reset, p.dbHandle)
}

func (p *SQLiteProvider) getUsedFolderQuota(name string) (int, int64, error) {
	return sqlCommonGetFolderUsedQuota(name, p.dbHandle)
}

func (p *SQLiteProvider) getGroups(limit, offset int, order string, minimal bool) ([]Group, error) {
	return sqlCommonGetGroups(limit, offset, order, minimal, p.dbHandle)
}

func (p *SQLiteProvider) getGroupsWithNames(names []string) ([]Group, error) {
	return sqlCommonGetGroupsWithNames(names, p.dbHandle)
}

func (p *SQLiteProvider) getUsersInGroups(names []string) ([]string, error) {
	return sqlCommonGetUsersInGroups(names, p.dbHandle)
}

func (p *SQLiteProvider) groupExists(name string) (Group, error) {
	return sqlCommonGetGroupByName(name, p.dbHandle)
}

func (p *SQLiteProvider) addGroup(group *Group) error {
	return p.normalizeError(sqlCommonAddGroup(group, p.dbHandle), fieldName)
}

func (p *SQLiteProvider) updateGroup(group *Group) error {
	return sqlCommonUpdateGroup(group, p.dbHandle)
}

func (p *SQLiteProvider) deleteGroup(group Group) error {
	return sqlCommonDeleteGroup(group, p.dbHandle)
}

func (p *SQLiteProvider) dumpGroups() ([]Group, error) {
	return sqlCommonDumpGroups(p.dbHandle)
}

func (p *SQLiteProvider) adminExists(username string) (Admin, error) {
	return sqlCommonGetAdminByUsername(username, p.dbHandle)
}

func (p *SQLiteProvider) addAdmin(admin *Admin) error {
	return p.normalizeError(sqlCommonAddAdmin(admin, p.dbHandle), fieldUsername)
}

func (p *SQLiteProvider) updateAdmin(admin *Admin) error {
	return p.normalizeError(sqlCommonUpdateAdmin(admin, p.dbHandle), -1)
}

func (p *SQLiteProvider) deleteAdmin(admin Admin) error {
	return sqlCommonDeleteAdmin(admin, p.dbHandle)
}

func (p *SQLiteProvider) getAdmins(limit int, offset int, order string) ([]Admin, error) {
	return sqlCommonGetAdmins(limit, offset, order, p.dbHandle)
}

func (p *SQLiteProvider) dumpAdmins() ([]Admin, error) {
	return sqlCommonDumpAdmins(p.dbHandle)
}

func (p *SQLiteProvider) validateAdminAndPass(username, password, ip string) (Admin, error) {
	return sqlCommonValidateAdminAndPass(username, password, ip, p.dbHandle)
}

func (p *SQLiteProvider) apiKeyExists(keyID string) (APIKey, error) {
	return sqlCommonGetAPIKeyByID(keyID, p.dbHandle)
}

func (p *SQLiteProvider) addAPIKey(apiKey *APIKey) error {
	return p.normalizeError(sqlCommonAddAPIKey(apiKey, p.dbHandle), -1)
}

func (p *SQLiteProvider) updateAPIKey(apiKey *APIKey) error {
	return p.normalizeError(sqlCommonUpdateAPIKey(apiKey, p.dbHandle), -1)
}

func (p *SQLiteProvider) deleteAPIKey(apiKey APIKey) error {
	return sqlCommonDeleteAPIKey(apiKey, p.dbHandle)
}

func (p *SQLiteProvider) getAPIKeys(limit int, offset int, order string) ([]APIKey, error) {
	return sqlCommonGetAPIKeys(limit, offset, order, p.dbHandle)
}

func (p *SQLiteProvider) dumpAPIKeys() ([]APIKey, error) {
	return sqlCommonDumpAPIKeys(p.dbHandle)
}

func (p *SQLiteProvider) updateAPIKeyLastUse(keyID string) error {
	return sqlCommonUpdateAPIKeyLastUse(keyID, p.dbHandle)
}

func (p *SQLiteProvider) shareExists(shareID, username string) (Share, error) {
	return sqlCommonGetShareByID(shareID, username, p.dbHandle)
}

func (p *SQLiteProvider) addShare(share *Share) error {
	return p.normalizeError(sqlCommonAddShare(share, p.dbHandle), fieldName)
}

func (p *SQLiteProvider) updateShare(share *Share) error {
	return p.normalizeError(sqlCommonUpdateShare(share, p.dbHandle), -1)
}

func (p *SQLiteProvider) deleteShare(share Share) error {
	return sqlCommonDeleteShare(share, p.dbHandle)
}

func (p *SQLiteProvider) getShares(limit int, offset int, order, username string) ([]Share, error) {
	return sqlCommonGetShares(limit, offset, order, username, p.dbHandle)
}

func (p *SQLiteProvider) dumpShares() ([]Share, error) {
	return sqlCommonDumpShares(p.dbHandle)
}

func (p *SQLiteProvider) updateShareLastUse(shareID string, numTokens int) error {
	return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle)
}

func (p *SQLiteProvider) getDefenderHosts(from int64, limit int) ([]DefenderEntry, error) {
	return sqlCommonGetDefenderHosts(from, limit, p.dbHandle)
}

func (p *SQLiteProvider) getDefenderHostByIP(ip string, from int64) (DefenderEntry, error) {
	return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle)
}

func (p *SQLiteProvider) isDefenderHostBanned(ip string) (DefenderEntry, error) {
	return sqlCommonIsDefenderHostBanned(ip, p.dbHandle)
}

func (p *SQLiteProvider) updateDefenderBanTime(ip string, minutes int) error {
	return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle)
}

func (p *SQLiteProvider) deleteDefenderHost(ip string) error {
	return sqlCommonDeleteDefenderHost(ip, p.dbHandle)
}

// addDefenderEvent upserts the host row and records the scored event in one
// shared helper call.
func (p *SQLiteProvider) addDefenderEvent(ip string, score int) error {
	return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle)
}

func (p *SQLiteProvider) setDefenderBanTime(ip string, banTime int64) error {
	return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle)
}

func (p *SQLiteProvider) cleanupDefender(from int64) error {
	return sqlCommonDefenderCleanup(from, p.dbHandle)
}

func (p *SQLiteProvider) addActiveTransfer(transfer ActiveTransfer) error {
	return sqlCommonAddActiveTransfer(transfer, p.dbHandle)
}

func (p *SQLiteProvider) updateActiveTransferSizes(ulSize, dlSize, transferID int64, connectionID string) error {
	return sqlCommonUpdateActiveTransferSizes(ulSize, dlSize, transferID, connectionID, p.dbHandle)
}

func (p *SQLiteProvider) removeActiveTransfer(transferID int64, connectionID string) error {
	return sqlCommonRemoveActiveTransfer(transferID, connectionID, p.dbHandle)
}

func (p *SQLiteProvider) cleanupActiveTransfers(before time.Time) error {
	return sqlCommonCleanupActiveTransfers(before, p.dbHandle)
}

func (p *SQLiteProvider) getActiveTransfers(from time.Time) ([]ActiveTransfer, error) {
	return sqlCommonGetActiveTransfers(from, p.dbHandle)
}

func (p *SQLiteProvider) addSharedSession(session Session) error {
	return sqlCommonAddSession(session, p.dbHandle)
}

func (p *SQLiteProvider) deleteSharedSession(key string, sessionType SessionType) error {
	return sqlCommonDeleteSession(key, sessionType, p.dbHandle)
}

func (p *SQLiteProvider) getSharedSession(key string, sessionType SessionType) (Session, error) {
	return sqlCommonGetSession(key, sessionType, p.dbHandle)
}

func (p *SQLiteProvider) cleanupSharedSessions(sessionType SessionType, before int64) error {
	return sqlCommonCleanupSessions(sessionType, before, p.dbHandle)
}

func (p *SQLiteProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
	return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle)
}

func (p *SQLiteProvider) dumpEventActions() ([]BaseEventAction, error) {
	return sqlCommonDumpEventActions(p.dbHandle)
}

func (p *SQLiteProvider) eventActionExists(name string) (BaseEventAction, error) {
	return sqlCommonGetEventActionByName(name, p.dbHandle)
}

func (p *SQLiteProvider) addEventAction(action *BaseEventAction) error {
	return p.normalizeError(sqlCommonAddEventAction(action, p.dbHandle), fieldName)
}

func (p *SQLiteProvider) updateEventAction(action *BaseEventAction) error {
	return sqlCommonUpdateEventAction(action, p.dbHandle)
}

func (p *SQLiteProvider) deleteEventAction(action BaseEventAction) error {
	return sqlCommonDeleteEventAction(action, p.dbHandle)
}

func (p *SQLiteProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
	return sqlCommonGetEventRules(limit, offset, order, p.dbHandle)
}

func (p *SQLiteProvider) dumpEventRules() ([]EventRule, error) {
	return sqlCommonDumpEventRules(p.dbHandle)
}

func (p *SQLiteProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
	return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle)
}

func (p *SQLiteProvider) eventRuleExists(name string) (EventRule, error) {
	return sqlCommonGetEventRuleByName(name, p.dbHandle)
}

func (p *SQLiteProvider) addEventRule(rule *EventRule) error {
	return p.normalizeError(sqlCommonAddEventRule(rule, p.dbHandle), fieldName)
}

func (p *SQLiteProvider) updateEventRule(rule *EventRule) error {
	return sqlCommonUpdateEventRule(rule, p.dbHandle)
}

func (p *SQLiteProvider) deleteEventRule(rule EventRule, softDelete bool) error {
	return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle)
}

func (p *SQLiteProvider) getTaskByName(name string) (Task, error) {
	return sqlCommonGetTaskByName(name, p.dbHandle)
}

func (p *SQLiteProvider) addTask(name string) error {
	return sqlCommonAddTask(name, p.dbHandle)
}

func (p *SQLiteProvider) updateTask(name string, version int64) error {
	return sqlCommonUpdateTask(name, version, p.dbHandle)
}

func (p *SQLiteProvider) updateTaskTimestamp(name string) error {
	return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
}

// Node bookkeeping is not implemented for this provider: all node methods
// return ErrNotImplemented.
func (*SQLiteProvider) addNode() error {
	return ErrNotImplemented
}

func (*SQLiteProvider) getNodeByName(_ string) (Node, error) {
	return Node{}, ErrNotImplemented
}

func (*SQLiteProvider) getNodes() ([]Node, error) {
	return nil, ErrNotImplemented
}

func (*SQLiteProvider) updateNodeTimestamp() error {
	return ErrNotImplemented
}

func (*SQLiteProvider) cleanupNodes() error {
	return ErrNotImplemented
}

func (p *SQLiteProvider) roleExists(name string) (Role, error) {
	return sqlCommonGetRoleByName(name, p.dbHandle)
}

func (p *SQLiteProvider) addRole(role *Role) error {
	return p.normalizeError(sqlCommonAddRole(role, p.dbHandle), fieldName)
}

func (p *SQLiteProvider) updateRole(role *Role) error {
	return sqlCommonUpdateRole(role, p.dbHandle)
}

func (p *SQLiteProvider) deleteRole(role Role) error {
	return sqlCommonDeleteRole(role, p.dbHandle)
}

func (p *SQLiteProvider) getRoles(limit int, offset int, order string, minimal bool) ([]Role, error) {
	return sqlCommonGetRoles(limit, offset, order, minimal, p.dbHandle)
}

func (p *SQLiteProvider) dumpRoles() ([]Role, error) {
	return sqlCommonDumpRoles(p.dbHandle)
}

func (p *SQLiteProvider) ipListEntryExists(ipOrNet string, listType IPListType) (IPListEntry, error) {
	return sqlCommonGetIPListEntry(ipOrNet, listType, p.dbHandle)
}

func (p *SQLiteProvider) addIPListEntry(entry *IPListEntry) error {
	return p.normalizeError(sqlCommonAddIPListEntry(entry, p.dbHandle), fieldIPNet)
}

func (p *SQLiteProvider) updateIPListEntry(entry *IPListEntry) error {
	return sqlCommonUpdateIPListEntry(entry, p.dbHandle)
}

func (p *SQLiteProvider) deleteIPListEntry(entry IPListEntry, softDelete bool) error {
	return sqlCommonDeleteIPListEntry(entry, softDelete, p.dbHandle)
}

func (p *SQLiteProvider) getIPListEntries(listType IPListType, filter, from, order string, limit int) ([]IPListEntry, error) {
	return sqlCommonGetIPListEntries(listType, filter, from, order, limit, p.dbHandle)
}

func (p *SQLiteProvider) getRecentlyUpdatedIPListEntries(after int64) ([]IPListEntry, error) {
	return sqlCommonGetRecentlyUpdatedIPListEntries(after, p.dbHandle)
}

func (p *SQLiteProvider) dumpIPListEntries() ([]IPListEntry, error) {
	return sqlCommonDumpIPListEntries(p.dbHandle)
}

func (p *SQLiteProvider) countIPListEntries(listType IPListType) (int64, error) {
	return sqlCommonCountIPListEntries(listType, p.dbHandle)
}

func (p *SQLiteProvider) getListEntriesForIP(ip string, listType IPListType) ([]IPListEntry, error) {
	return sqlCommonGetListEntriesForIP(ip, listType, p.dbHandle)
}

func (p *SQLiteProvider) getConfigs() (Configs, error) {
	return sqlCommonGetConfigs(p.dbHandle)
}

func (p *SQLiteProvider) setConfigs(configs *Configs) error {
	return sqlCommonSetConfigs(configs, p.dbHandle)
}

func (p *SQLiteProvider) setFirstDownloadTimestamp(username string) error {
	return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
}

func (p *SQLiteProvider) setFirstUploadTimestamp(username string) error {
	return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle)
}

// close releases the underlying database handle.
func (p *SQLiteProvider) close() error {
	return p.dbHandle.Close()
}

// reloadConfig is a no-op for this provider.
func (p *SQLiteProvider) reloadConfig() error {
	return nil
}

// initializeDatabase creates the initial database structure
func (p *SQLiteProvider) initializeDatabase() error {
	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, false)
	if err == nil && dbVersion.Version > 0 {
		// the schema versions table already has a row: nothing to create
		return ErrNoInitRequired
	}
	if errors.Is(err, sql.ErrNoRows) {
		return errSchemaVersionEmpty
	}
	logger.InfoToConsole("creating initial database schema, version 33")
	providerLog(logger.LevelInfo, "creating initial database schema, version 33")
	// NOTE: this local deliberately shadows the database/sql package from here on
	sql := sqlReplaceAll(sqliteInitialSQL)
	return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 33, true)
}

// migrateDatabase upgrades the schema to the current sqlDatabaseVersion.
// Versions older than 33 are rejected; versions newer than the supported one
// are tolerated (logged, no error) so an older binary can still run.
func (p *SQLiteProvider) migrateDatabase() error {
	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
	if err != nil {
		return err
	}

	switch version := dbVersion.Version; {
	case version == sqlDatabaseVersion:
		providerLog(logger.LevelDebug, "sql database is up to date, current version: %d", version)
		return ErrNoInitRequired
	case version < 33:
		err = errSchemaVersionTooOld(version)
		providerLog(logger.LevelError, "%v", err)
		logger.ErrorToConsole("%v", err)
		return err
	case version == 33:
		return updateSQLiteDatabaseFromV33(p.dbHandle)
	default:
		if version > sqlDatabaseVersion {
			providerLog(logger.LevelError, "database schema version %d is newer than the supported one: %d", version,
				sqlDatabaseVersion)
			logger.WarnToConsole("database schema version %d is newer than the supported one: %d", version,
				sqlDatabaseVersion)
			return nil
		}
		return fmt.Errorf("database schema version not handled: %d", version)
	}
}

// revertDatabase downgrades the schema toward targetVersion. Only a
// downgrade from version 34 is currently handled.
func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
	dbVersion, err := sqlCommonGetDatabaseVersion(p.dbHandle, true)
	if err != nil {
		return err
	}
	if dbVersion.Version == targetVersion {
		return errors.New("current version match target version, nothing to do")
	}

	switch dbVersion.Version {
	case 34:
		return downgradeSQLiteDatabaseFromV34(p.dbHandle)
	default:
		return fmt.Errorf("database schema version not handled: %d", dbVersion.Version)
	}
}

// resetDatabase drops all provider tables and resets the schema version to 0.
func (p *SQLiteProvider) resetDatabase() error {
	sql := sqlReplaceAll(sqliteResetSQL)
	return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0, false)
}

// normalizeError converts go-sqlite3 driver errors into the provider's
// generic errors. SQLite extended result codes used below:
// 1555 = SQLITE_CONSTRAINT_PRIMARYKEY, 2067 = SQLITE_CONSTRAINT_UNIQUE,
// 787 = SQLITE_CONSTRAINT_FOREIGNKEY. fieldType selects the i18n message
// attached to duplicate-key errors; any other value (callers pass -1) falls
// back to the generic duplicated-name message.
func (p *SQLiteProvider) normalizeError(err error, fieldType int) error {
	if err == nil {
		return nil
	}
	if e, ok := err.(sqlite3.Error); ok {
		switch e.ExtendedCode {
		case 1555, 2067:
			var message string
			switch fieldType {
			case fieldUsername:
				message = util.I18nErrorDuplicatedUsername
			case fieldIPNet:
				message = util.I18nErrorDuplicatedIPNet
			default:
				message = util.I18nErrorDuplicatedName
			}
			return util.NewI18nError(
				fmt.Errorf("%w: %s", ErrDuplicatedKey, err.Error()),
				message,
			)
		case 787:
			return fmt.Errorf("%w: %s", ErrForeignKeyViolated, err.Error())
		}
	}
	return err
}

// executePragmaOptimize runs SQLite's query-planner maintenance pragma with a
// bounded timeout.
func executePragmaOptimize(dbHandle *sql.DB) error {
	ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
	defer cancel()

	_, err := dbHandle.ExecContext(ctx, "PRAGMA optimize;")
	return err
}

func updateSQLiteDatabaseFromV33(dbHandle *sql.DB) error {
	return updateSQLiteDatabaseFrom33To34(dbHandle)
}

func downgradeSQLiteDatabaseFromV34(dbHandle *sql.DB) error {
	return downgradeSQLiteDatabaseFrom34To33(dbHandle)
}

// updateSQLiteDatabaseFrom33To34 applies the v34 migration script after
// substituting the table-name template variables.
func updateSQLiteDatabaseFrom33To34(dbHandle *sql.DB) error {
	logger.InfoToConsole("updating database schema version: 33 -> 34")
	providerLog(logger.LevelInfo, "updating database schema version: 33 -> 34")

	sql := strings.ReplaceAll(sqliteV34SQL, "{{prefix}}", config.SQLTablesPrefix)
	sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
	sql = strings.ReplaceAll(sql, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping)
	sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 34, true)
}

// downgradeSQLiteDatabaseFrom34To33 reverts the v34 migration.
func downgradeSQLiteDatabaseFrom34To33(dbHandle *sql.DB) error {
	logger.InfoToConsole("downgrading database schema version: 34 -> 33")
	providerLog(logger.LevelInfo, "downgrading database schema version: 34 -> 33")

	sql := strings.ReplaceAll(sqliteV34DownSQL, "{{shares_groups_mapping}}", sqlTableSharesGroupsMapping)
	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 33, false)
}

/*func setPragmaFK(dbHandle *sql.DB, value string) error {
	ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
	defer cancel()

	sql := fmt.Sprintf("PRAGMA foreign_keys=%v;", value)

	_, err := dbHandle.ExecContext(ctx, sql)
	return err
}*/
diff --git a/internal/dataprovider/sqlite_disabled.go b/internal/dataprovider/sqlite_disabled.go
new file mode 100644
index 00000000..22138967
--- /dev/null
+++ b/internal/dataprovider/sqlite_disabled.go
@@ -0,0 +1,31 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//go:build nosqlite || !cgo

package dataprovider

import (
	"errors"

	"github.com/drakkan/sftpgo/v2/internal/version"
)

// init registers the "-sqlite" feature marker with the version package when
// this stub is compiled in (SQLite excluded via the nosqlite build tag, or
// cgo unavailable).
func init() {
	version.AddFeature("-sqlite")
}

// initializeSQLiteProvider is the stub replacing the real initializer in
// builds without SQLite support; it always fails with an explanatory error.
func initializeSQLiteProvider(_ string) error {
	return errors.New("SQLite disabled at build time")
}
diff --git a/internal/dataprovider/sqlqueries.go b/internal/dataprovider/sqlqueries.go
new file mode 100644
index 00000000..2d5aaf6b
--- /dev/null
+++ b/internal/dataprovider/sqlqueries.go
@@ -0,0 +1,1194 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package dataprovider

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// Column lists shared by the SELECT queries below. The u./a./s./r. prefixes
// match the table aliases used by the join queries that embed these lists.
const (
	selectUserFields = "u.id,u.username,u.password,u.public_keys,u.home_dir,u.uid,u.gid,u.max_sessions,u.quota_size,u.quota_files," +
		"u.permissions,u.used_quota_size,u.used_quota_files,u.last_quota_update,u.upload_bandwidth,u.download_bandwidth," +
		"u.expiration_date,u.last_login,u.status,u.filters,u.filesystem,u.additional_info,u.description,u.email,u.created_at," +
		"u.updated_at,u.upload_data_transfer,u.download_data_transfer,u.total_data_transfer," +
		"u.used_upload_data_transfer,u.used_download_data_transfer,u.deleted_at,u.first_download,u.first_upload,r.name,u.last_password_change"
	selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem"
	selectAdminFields  = "a.id,a.username,a.password,a.status,a.email,a.permissions,a.filters,a.additional_info,a.description,a.created_at,a.updated_at,a.last_login,r.name"
	selectAPIKeyFields = "key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id"
	selectShareFields  = "s.share_id,s.name,s.description,s.scope,s.paths,u.username,s.created_at,s.updated_at,s.last_use_at," +
		"s.expires_at,s.password,s.max_tokens,s.used_tokens,s.allow_from"
	selectGroupFields       = "id,name,description,created_at,updated_at,user_settings"
	selectEventActionFields = "id,name,description,type,options"
	selectRoleFields        = "id,name,description,created_at,updated_at"
	selectIPListEntryFields = "type,ipornet,mode,protocols,description,created_at,updated_at,deleted_at"
	selectMinimalFields     = "id,name"
)

// getSQLPlaceholders pre-builds the first 100 query placeholders for the
// configured driver: positional $1..$100 for PostgreSQL/CockroachDB, "?"
// otherwise.
// NOTE(review): a query with more than 100 parameters would index past this
// slice — presumably no query does; confirm before raising parameter counts.
func getSQLPlaceholders() []string {
	var placeholders []string
	for i := 1; i <= 100; i++ {
		if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
			placeholders = append(placeholders, fmt.Sprintf("$%d", i))
		} else {
			placeholders = append(placeholders, "?")
		}
	}
	return placeholders
}

// getSQLQuotedName quotes an identifier using the configured driver's syntax:
// backticks for MySQL, double quotes otherwise.
func getSQLQuotedName(name string) string {
	if config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("`%s`", name)
	}

	return fmt.Sprintf(`"%s"`, name)
}

// getSelectEventRuleFields returns the event-rule column list; "trigger" is a
// reserved word and needs driver-specific quoting.
func getSelectEventRuleFields() string {
	if config.Driver == MySQLDataProviderName {
		return "id,name,description,created_at,updated_at,`trigger`,conditions,deleted_at,status"
	}

	return `id,name,description,created_at,updated_at,"trigger",conditions,deleted_at,status`
}

// getCoalesceDefaultForRole returns the COALESCE fallback used when resolving
// a role name to its id: "0" when a role name was supplied (presumably so a
// missing role fails the foreign key instead of being silently dropped —
// confirm), "NULL" when no role is requested.
func getCoalesceDefaultForRole(role string) string {
	if role != "" {
		return "0"
	}
	return "NULL"
}

// getAddSessionQuery returns the shared-session upsert; "key" is reserved in
// MySQL, hence the backtick variant and ON DUPLICATE KEY instead of ON CONFLICT.
func getAddSessionQuery() string {
	if config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("INSERT INTO %s (`key`,`data`,`type`,`timestamp`) VALUES (%s,%s,%s,%s) "+
			"ON DUPLICATE KEY UPDATE `data`=VALUES(`data`), `timestamp`=VALUES(`timestamp`)",
			sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
	}
	return fmt.Sprintf(`INSERT INTO %s (key,data,type,timestamp) VALUES (%s,%s,%s,%s) ON CONFLICT(key,type) DO UPDATE SET data=
		EXCLUDED.data, timestamp=EXCLUDED.timestamp`,
		sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}

func getDeleteSessionQuery() string {
	if config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("DELETE FROM %s WHERE `key` = %s AND `type` = %s",
			sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
	}
	return fmt.Sprintf(`DELETE FROM %s WHERE key = %s AND type = %s`,
		sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getSessionQuery() string {
	if config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("SELECT `key`,`data`,`type`,`timestamp` FROM %s WHERE `key` = %s AND `type` = %s",
			sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
	}
	return fmt.Sprintf(`SELECT key,data,type,timestamp FROM %s WHERE key = %s AND type = %s`,
		sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getCleanupSessionsQuery() string {
	return fmt.Sprintf(`DELETE from %s WHERE type = %s AND timestamp < %s`,
		sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getAddDefenderHostQuery upserts a defender host keyed by IP, refreshing
// updated_at on conflict.
func getAddDefenderHostQuery() string {
	if config.Driver == MySQLDataProviderName {
		return fmt.Sprintf("INSERT INTO %s (`ip`,`updated_at`,`ban_time`) VALUES (%s,%s,0) ON DUPLICATE KEY UPDATE `updated_at`=VALUES(`updated_at`)",
			sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
	}
	return fmt.Sprintf(`INSERT INTO %s (ip,updated_at,ban_time) VALUES (%s,%s,0) ON CONFLICT (ip) DO UPDATE SET updated_at = EXCLUDED.updated_at RETURNING id`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getAddDefenderEventQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (date_time,score,host_id) VALUES (%s,%s,(SELECT id from %s WHERE ip = %s))`,
		sqlTableDefenderEvents, sqlPlaceholders[0], sqlPlaceholders[1], sqlTableDefenderHosts, sqlPlaceholders[2])
}

func getDefenderHostsQuery() string {
	return fmt.Sprintf(`SELECT id,ip,ban_time FROM %s WHERE updated_at >= %s OR ban_time > 0 ORDER BY updated_at DESC LIMIT %s`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getDefenderHostQuery() string {
	return fmt.Sprintf(`SELECT id,ip,ban_time FROM %s WHERE ip = %s AND (updated_at >= %s OR ban_time > 0)`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getDefenderEventsQuery aggregates scores per host for the given host ids.
// The IN list is built inline from int64 ids (not user input); "(0)" keeps
// the SQL valid when hostIDs is empty.
func getDefenderEventsQuery(hostIDs []int64) string {
	var sb strings.Builder
	for _, hID := range hostIDs {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(strconv.FormatInt(hID, 10))
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("(0)")
	}
	return fmt.Sprintf(`SELECT host_id,SUM(score) FROM %s WHERE date_time >= %s AND host_id IN %s GROUP BY host_id`,
		sqlTableDefenderEvents, sqlPlaceholders[0], sb.String())
}

func getDefenderIsHostBannedQuery() string {
	return fmt.Sprintf(`SELECT id FROM %s WHERE ip = %s AND ban_time >= %s`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getDefenderIncrementBanTimeQuery() string {
	return fmt.Sprintf(`UPDATE %s SET ban_time = ban_time + %s WHERE ip = %s`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getDefenderSetBanTimeQuery() string {
	return fmt.Sprintf(`UPDATE %s SET ban_time = %s WHERE ip = %s`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getDeleteDefenderHostQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE ip = %s`, sqlTableDefenderHosts, sqlPlaceholders[0])
}

// getDefenderHostsCleanupQuery removes unbanned hosts that have no recent
// events left.
func getDefenderHostsCleanupQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE ban_time < %s AND NOT EXISTS (
		SELECT id FROM %s WHERE %s.host_id = %s.id AND %s.date_time > %s)`,
		sqlTableDefenderHosts, sqlPlaceholders[0], sqlTableDefenderEvents, sqlTableDefenderEvents, sqlTableDefenderHosts,
		sqlTableDefenderEvents, sqlPlaceholders[1])
}

func getDefenderEventsCleanupQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE date_time < %s`, sqlTableDefenderEvents, sqlPlaceholders[0])
}

func getIPListEntryQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE type = %s AND ipornet = %s AND deleted_at = 0`,
		selectIPListEntryFields, sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getIPListEntriesQuery builds a paginated listing; "from" is the cursor
// (last ipornet of the previous page) and "filter" a LIKE pattern. Soft
// deleted rows (deleted_at > 0) are excluded.
func getIPListEntriesQuery(filter, from, order string, limit int) string {
	var sb strings.Builder
	var idx int

	sb.WriteString("SELECT ")
	sb.WriteString(selectIPListEntryFields)
	sb.WriteString(" FROM ")
	sb.WriteString(sqlTableIPLists)
	sb.WriteString(" WHERE type = ")
	sb.WriteString(sqlPlaceholders[idx])
	idx++
	if from != "" {
		if order == OrderASC {
			sb.WriteString(" AND ipornet > ")
		} else {
			sb.WriteString(" AND ipornet < ")
		}
		sb.WriteString(sqlPlaceholders[idx])
		idx++
	}
	if filter != "" {
		sb.WriteString(" AND ipornet LIKE ")
		sb.WriteString(sqlPlaceholders[idx])
		idx++
	}
	sb.WriteString(" AND deleted_at = 0 ")
	sb.WriteString(" ORDER BY ipornet ")
	sb.WriteString(order)
	if limit > 0 {
		sb.WriteString(" LIMIT ")
		sb.WriteString(sqlPlaceholders[idx])
	}
	return sb.String()
}

func getCountIPListEntriesQuery() string {
	return fmt.Sprintf(`SELECT count(ipornet) FROM %s WHERE type = %s AND deleted_at = 0`, sqlTableIPLists, sqlPlaceholders[0])
}

func getCountAllIPListEntriesQuery() string {
	return fmt.Sprintf(`SELECT count(ipornet) FROM %s WHERE deleted_at = 0`, sqlTableIPLists)
}

// getIPListEntriesForIPQueryPg matches an IP against stored ranges using the
// PostgreSQL inet type; the NoPg variant below compares pre-computed
// first/last bounds plus an ip_type discriminator instead.
func getIPListEntriesForIPQueryPg() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE type = %s AND deleted_at = 0 AND %s::inet BETWEEN first AND last`,
		selectIPListEntryFields, sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getIPListEntriesForIPQueryNoPg() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE type = %s AND deleted_at = 0 AND ip_type = %s AND %s BETWEEN first AND last`,
		selectIPListEntryFields, sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
}

// getRecentlyUpdatedIPListQuery also returns soft-deleted rows so callers can
// drop them from in-memory caches.
func getRecentlyUpdatedIPListQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE updated_at >= %s OR deleted_at > 0`,
		selectIPListEntryFields, sqlTableIPLists, sqlPlaceholders[0])
}

func getDumpListEntriesQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at = 0`, selectIPListEntryFields, sqlTableIPLists)
}

func getAddIPListEntryQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (type,ipornet,first,last,ip_type,protocols,description,mode,created_at,updated_at,deleted_at)
		VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,0)`, sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1],
		sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5],
		sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9])
}

func getUpdateIPListEntryQuery() string {
	return fmt.Sprintf(`UPDATE %s SET mode=%s,protocols=%s,description=%s,updated_at=%s WHERE type = %s AND ipornet = %s`,
		sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
		sqlPlaceholders[4], sqlPlaceholders[5])
}

// getDeleteIPListEntryQuery soft-deletes (marks deleted_at) or hard-deletes
// depending on the flag; soft-deleted rows are purged later via
// getRemoveSoftDeletedIPListEntryQuery.
func getDeleteIPListEntryQuery(softDelete bool) string {
	if softDelete {
		return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE type = %s AND ipornet = %s`,
			sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
	}
	return fmt.Sprintf(`DELETE FROM %s WHERE type = %s AND ipornet = %s`,
		sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getRemoveSoftDeletedIPListEntryQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE type = %s AND ipornet = %s AND deleted_at > 0`,
		sqlTableIPLists, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getConfigsQuery() string {
	return fmt.Sprintf(`SELECT configs FROM %s LIMIT 1`, sqlTableConfigs)
}

// getUpdateConfigsQuery has no WHERE clause: the configs table is expected to
// hold a single row (see the LIMIT 1 in getConfigsQuery).
func getUpdateConfigsQuery() string {
	return fmt.Sprintf(`UPDATE %s SET configs = %s`, sqlTableConfigs, sqlPlaceholders[0])
}

func getRoleByNameQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectRoleFields, sqlTableRoles,
		sqlPlaceholders[0])
}

func getRolesQuery(order string, minimal bool) string {
	var fieldSelection string
	if minimal {
		fieldSelection = selectMinimalFields
	} else {
		fieldSelection = selectRoleFields
	}
	return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %s OFFSET %s`, fieldSelection,
		sqlTableRoles, order, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getUsersWithRolesQuery lists (role id, username) pairs for the given roles.
// The IN list is built inline from int64 ids (not user input).
// NOTE(review): unlike getDefenderEventsQuery there is no fallback when roles
// is empty, so an empty slice yields invalid SQL ("IN " with no list) —
// presumably callers guarantee a non-empty slice; confirm.
func getUsersWithRolesQuery(roles []Role) string {
	var sb strings.Builder
	for _, r := range roles {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(strconv.FormatInt(r.ID, 10))
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	}
	return fmt.Sprintf(`SELECT r.id, u.username FROM %s u INNER JOIN %s r ON u.role_id = r.id WHERE u.role_id IN %s`,
		sqlTableUsers, sqlTableRoles, sb.String())
}

// getAdminsWithRolesQuery mirrors getUsersWithRolesQuery for admins; the same
// empty-slice caveat applies.
func getAdminsWithRolesQuery(roles []Role) string {
	var sb strings.Builder
	for _, r := range roles {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(strconv.FormatInt(r.ID, 10))
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	}
	return fmt.Sprintf(`SELECT r.id, a.username FROM %s a INNER JOIN %s r ON a.role_id = r.id WHERE a.role_id IN %s`,
		sqlTableAdmins, sqlTableRoles, sb.String())
}

func getDumpRolesQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s`, selectRoleFields, sqlTableRoles)
}

func getAddRoleQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (name,description,created_at,updated_at)
		VALUES (%s,%s,%s,%s)`, sqlTableRoles, sqlPlaceholders[0], sqlPlaceholders[1],
		sqlPlaceholders[2], sqlPlaceholders[3])
}

func getUpdateRoleQuery() string {
	return fmt.Sprintf(`UPDATE %s SET description=%s,updated_at=%s
		WHERE name = %s`, sqlTableRoles, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
}

func getDeleteRoleQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableRoles, sqlPlaceholders[0])
}

// Group queries quote the table name ("groups" is a reserved word on some
// backends, hence getSQLQuotedName).
func getGroupByNameQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectGroupFields, getSQLQuotedName(sqlTableGroups),
		sqlPlaceholders[0])
}

func getGroupsQuery(order string, minimal bool) string {
	var fieldSelection string
	if minimal {
		fieldSelection = selectMinimalFields
	} else {
		fieldSelection = selectGroupFields
	}
	return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %s OFFSET %s`, fieldSelection,
		getSQLQuotedName(sqlTableGroups), order, sqlPlaceholders[0], sqlPlaceholders[1])
}

// getGroupsWithNamesQuery builds a placeholder IN list for numArgs names;
// "('')" keeps the SQL valid when numArgs is zero.
func getGroupsWithNamesQuery(numArgs int) string {
	var sb strings.Builder
	for idx := 0; idx < numArgs; idx++ {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(sqlPlaceholders[idx])
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("('')")
	}
	return fmt.Sprintf(`SELECT %s FROM %s WHERE name in %s`, selectGroupFields, getSQLQuotedName(sqlTableGroups), sb.String())
}

func getUsersInGroupsQuery(numArgs int) string {
	var sb strings.Builder
	for idx := 0; idx < numArgs; idx++ {
		if sb.Len() == 0 {
			sb.WriteString("(")
		} else {
			sb.WriteString(",")
		}
		sb.WriteString(sqlPlaceholders[idx])
	}
	if sb.Len() > 0 {
		sb.WriteString(")")
	} else {
		sb.WriteString("('')")
	}
	return fmt.Sprintf(`SELECT username FROM %s WHERE id IN (SELECT user_id from %s WHERE group_id IN (SELECT id FROM %s WHERE name IN %s))`,
		sqlTableUsers, sqlTableUsersGroupsMapping, getSQLQuotedName(sqlTableGroups), sb.String())
}

func getDumpGroupsQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s`, selectGroupFields, getSQLQuotedName(sqlTableGroups))
}

func getAddGroupQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (name,description,created_at,updated_at,user_settings)
		VALUES (%s,%s,%s,%s,%s)`, getSQLQuotedName(sqlTableGroups), sqlPlaceholders[0], sqlPlaceholders[1],
		sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4])
}

func getUpdateGroupQuery() string {
	return fmt.Sprintf(`UPDATE %s SET description=%s,user_settings=%s,updated_at=%s
		WHERE name = %s`, getSQLQuotedName(sqlTableGroups), sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
		sqlPlaceholders[3])
}

func getDeleteGroupQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, getSQLQuotedName(sqlTableGroups), sqlPlaceholders[0])
}

// Admin queries LEFT JOIN roles so admins without a role are still returned
// (r.name is then NULL).
func getAdminByUsernameQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s a LEFT JOIN %s r on r.id = a.role_id WHERE a.username = %s`,
		selectAdminFields, sqlTableAdmins, sqlTableRoles, sqlPlaceholders[0])
}

func getAdminsQuery(order string) string {
	return fmt.Sprintf(`SELECT %s FROM %s a LEFT JOIN %s r on r.id = a.role_id ORDER BY a.username %s LIMIT %s OFFSET %s`,
		selectAdminFields, sqlTableAdmins, sqlTableRoles, order, sqlPlaceholders[0], sqlPlaceholders[1])
}

func getDumpAdminsQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s a LEFT JOIN %s r on r.id = a.role_id`,
		selectAdminFields, sqlTableAdmins, sqlTableRoles)
}

// getAddAdminQuery resolves the role name to an id via a subquery; the
// COALESCE fallback comes from getCoalesceDefaultForRole.
func getAddAdminQuery(role string) string {
	return fmt.Sprintf(`INSERT INTO %s (username,password,status,email,permissions,filters,additional_info,description,created_at,updated_at,last_login,role_id)
		VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,0,COALESCE((SELECT id from %s WHERE name = %s),%s))`,
		sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
		sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
		sqlTableRoles, sqlPlaceholders[10], getCoalesceDefaultForRole(role))
}

func getUpdateAdminQuery(role string) string {
	return fmt.Sprintf(`UPDATE %s SET password=%s,status=%s,email=%s,permissions=%s,filters=%s,additional_info=%s,description=%s,updated_at=%s,
		role_id=COALESCE((SELECT id from %s WHERE name = %s),%s) WHERE username = %s`, sqlTableAdmins, sqlPlaceholders[0],
		sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6],
		sqlPlaceholders[7], sqlTableRoles, sqlPlaceholders[8], getCoalesceDefaultForRole(role), sqlPlaceholders[9])
}

func getDeleteAdminQuery() string {
	return fmt.Sprintf(`DELETE FROM %s WHERE username = %s`, sqlTableAdmins, sqlPlaceholders[0])
}

// getShareByIDQuery optionally restricts the lookup to the owning username.
func getShareByIDQuery(filterUser bool) string {
	if filterUser {
		return fmt.Sprintf(`SELECT %s FROM %s s INNER JOIN %s u ON s.user_id = u.id WHERE s.share_id = %s AND u.username = %s`,
			selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
	}
	return fmt.Sprintf(`SELECT %s FROM %s s INNER JOIN %s u ON s.user_id = u.id WHERE s.share_id = %s`,
		selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0])
}

func getSharesQuery(order string) string {
	return fmt.Sprintf(`SELECT %s FROM %s s INNER JOIN %s u ON s.user_id = u.id WHERE u.username = %s ORDER BY s.share_id %s LIMIT %s OFFSET %s`,
		selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0], order, sqlPlaceholders[1], sqlPlaceholders[2])
}

func getDumpSharesQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s s INNER JOIN %s u ON s.user_id = u.id`,
		selectShareFields, sqlTableShares, sqlTableUsers)
}

func getAddShareQuery() string {
	return fmt.Sprintf(`INSERT INTO %s (share_id,name,description,scope,paths,created_at,updated_at,last_use_at,
		expires_at,password,max_tokens,used_tokens,allow_from,user_id) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)`,
		sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1],
		sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6],
		sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11],
		sqlPlaceholders[12], sqlPlaceholders[13])
}

// getUpdateShareRestoreQuery also restores usage counters/timestamps and is
// used when reimporting a dumped share; getUpdateShareQuery below preserves
// them.
func getUpdateShareRestoreQuery() string {
	return fmt.Sprintf(`UPDATE %s SET name=%s,description=%s,scope=%s,paths=%s,created_at=%s,updated_at=%s,
		last_use_at=%s,expires_at=%s,password=%s,max_tokens=%s,used_tokens=%s,allow_from=%s,user_id=%s WHERE share_id = %s`, sqlTableShares,
		sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
		sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
		sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13])
}

// NOTE(review): the remainder of this definition lies past the visible
// region; it is reproduced as captured.
func getUpdateShareQuery() string {
	return fmt.Sprintf(`UPDATE %s SET name=%s,description=%s,scope=%s,paths=%s,updated_at=%s,expires_at=%s,
		password=%s,max_tokens=%s,allow_from=%s,user_id=%s WHERE share_id = %s`, sqlTableShares,
		sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
		sqlPlaceholders[5],
sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], + sqlPlaceholders[10]) +} + +func getDeleteShareQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE share_id = %s`, sqlTableShares, sqlPlaceholders[0]) +} + +func getAPIKeyByIDQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s WHERE key_id = %s`, selectAPIKeyFields, sqlTableAPIKeys, sqlPlaceholders[0]) +} + +func getAPIKeysQuery(order string) string { + return fmt.Sprintf(`SELECT %s FROM %s ORDER BY key_id %s LIMIT %s OFFSET %s`, selectAPIKeyFields, sqlTableAPIKeys, + order, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getDumpAPIKeysQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s`, selectAPIKeyFields, sqlTableAPIKeys) +} + +func getAddAPIKeyQuery() string { + return fmt.Sprintf(`INSERT INTO %s (key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id) + VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)`, sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1], + sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], + sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10]) +} + +func getUpdateAPIKeyQuery() string { + return fmt.Sprintf(`UPDATE %s SET name=%s,scope=%s,expires_at=%s,user_id=%s,admin_id=%s,description=%s,updated_at=%s + WHERE key_id = %s`, sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], + sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7]) +} + +func getDeleteAPIKeyQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE key_id = %s`, sqlTableAPIKeys, sqlPlaceholders[0]) +} + +func getRelatedUsersForAPIKeysQuery(apiKeys []APIKey) string { + var sb strings.Builder + for _, k := range apiKeys { + if k.userID == 0 { + continue + } + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(k.userID, 10)) + } + 
if sb.Len() > 0 { + sb.WriteString(")") + } else { + sb.WriteString("(0)") + } + return fmt.Sprintf(`SELECT id,username FROM %s WHERE id IN %s ORDER BY username`, sqlTableUsers, sb.String()) +} + +func getRelatedAdminsForAPIKeysQuery(apiKeys []APIKey) string { + var sb strings.Builder + for _, k := range apiKeys { + if k.adminID == 0 { + continue + } + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(k.adminID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } else { + sb.WriteString("(0)") + } + return fmt.Sprintf(`SELECT id,username FROM %s WHERE id IN %s ORDER BY username`, sqlTableAdmins, sb.String()) +} + +func getUserByUsernameQuery(role string) string { + if role == "" { + return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.username = %s AND u.deleted_at = 0`, + selectUserFields, sqlTableUsers, sqlTableRoles, sqlPlaceholders[0]) + } + return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.username = %s AND u.deleted_at = 0 + AND u.role_id is NOT NULL AND r.name = %s`, + selectUserFields, sqlTableUsers, sqlTableRoles, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getUsersQuery(order, role string) string { + if role == "" { + return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE + u.deleted_at = 0 ORDER BY u.username %s LIMIT %s OFFSET %s`, + selectUserFields, sqlTableUsers, sqlTableRoles, order, sqlPlaceholders[0], sqlPlaceholders[1]) + } + return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE + u.deleted_at = 0 AND u.role_id is NOT NULL AND r.name = %s ORDER BY u.username %s LIMIT %s OFFSET %s`, + selectUserFields, sqlTableUsers, sqlTableRoles, sqlPlaceholders[0], order, sqlPlaceholders[1], sqlPlaceholders[2]) +} + +func getUsersForQuotaCheckQuery(numArgs int) string { + var sb strings.Builder + for idx := 0; idx < numArgs; idx++ { + if sb.Len() == 0 { + sb.WriteString("(") 
+ } else { + sb.WriteString(",") + } + sb.WriteString(sqlPlaceholders[idx]) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size,total_data_transfer,upload_data_transfer, + download_data_transfer,used_upload_data_transfer,used_download_data_transfer,filters FROM %s WHERE username IN %s`, + sqlTableUsers, sb.String()) +} + +func getRecentlyUpdatedUsersQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.updated_at >= %s OR u.deleted_at > 0`, + selectUserFields, sqlTableUsers, sqlTableRoles, sqlPlaceholders[0]) +} + +func getDumpUsersQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s u LEFT JOIN %s r on r.id = u.role_id WHERE u.deleted_at = 0`, + selectUserFields, sqlTableUsers, sqlTableRoles) +} + +func getDumpFoldersQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s`, selectFolderFields, sqlTableFolders) +} + +func getUpdateTransferQuotaQuery(reset bool) string { + if reset { + return fmt.Sprintf(`UPDATE %s SET used_upload_data_transfer = %s,used_download_data_transfer = %s,last_quota_update = %s + WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) + } + return fmt.Sprintf(`UPDATE %s SET used_upload_data_transfer = used_upload_data_transfer + %s, + used_download_data_transfer = used_download_data_transfer + %s,last_quota_update = %s + WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getUpdateQuotaQuery(reset bool) string { + if reset { + return fmt.Sprintf(`UPDATE %s SET used_quota_size = %s,used_quota_files = %s,last_quota_update = %s + WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) + } + return fmt.Sprintf(`UPDATE %s SET used_quota_size = used_quota_size + %s,used_quota_files = used_quota_files + %s,last_quota_update = %s + 
WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getAdminSignatureQuery() string { + return fmt.Sprintf(`SELECT updated_at FROM %s WHERE username = %s`, sqlTableAdmins, sqlPlaceholders[0]) +} + +func getUserSignatureQuery() string { + return fmt.Sprintf(`SELECT updated_at FROM %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0]) +} + +func getSetUpdateAtQuery() string { + return fmt.Sprintf(`UPDATE %s SET updated_at = %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getSetFirstUploadQuery() string { + return fmt.Sprintf(`UPDATE %s SET first_upload = %s WHERE username = %s AND first_upload = 0`, + sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getSetFirstDownloadQuery() string { + return fmt.Sprintf(`UPDATE %s SET first_download = %s WHERE username = %s AND first_download = 0`, + sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getUpdateLastLoginQuery() string { + return fmt.Sprintf(`UPDATE %s SET last_login = %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getUpdateAdminLastLoginQuery() string { + return fmt.Sprintf(`UPDATE %s SET last_login = %s WHERE username = %s`, sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getUpdateAPIKeyLastUseQuery() string { + return fmt.Sprintf(`UPDATE %s SET last_use_at = %s WHERE key_id = %s`, sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getUpdateShareLastUseQuery() string { + return fmt.Sprintf(`UPDATE %s SET last_use_at = %s, used_tokens = used_tokens +%s WHERE share_id = %s`, + sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2]) +} + +func getQuotaQuery() string { + return fmt.Sprintf(`SELECT used_quota_size,used_quota_files,used_upload_data_transfer, + used_download_data_transfer FROM %s WHERE username = %s`, + sqlTableUsers, sqlPlaceholders[0]) +} + +func 
getAddUserQuery(role string) string { + return fmt.Sprintf(`INSERT INTO %s (username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions, + used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,status,last_login,expiration_date,filters, + filesystem,additional_info,description,email,created_at,updated_at,upload_data_transfer,download_data_transfer,total_data_transfer, + used_upload_data_transfer,used_download_data_transfer,deleted_at,first_download,first_upload,role_id,last_password_change) + VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,0,0,0,%s,%s,%s,0,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,0,0,0,0,0, + COALESCE((SELECT id from %s WHERE name=%s),%s),%s)`, + sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], + sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], + sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14], + sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19], + sqlPlaceholders[20], sqlPlaceholders[21], sqlPlaceholders[22], sqlPlaceholders[23], sqlTableRoles, + sqlPlaceholders[24], getCoalesceDefaultForRole(role), sqlPlaceholders[25]) +} + +func getUpdateUserQuery(role string) string { + return fmt.Sprintf(`UPDATE %s SET password=%s,public_keys=%s,home_dir=%s,uid=%s,gid=%s,max_sessions=%s,quota_size=%s, + quota_files=%s,permissions=%s,upload_bandwidth=%s,download_bandwidth=%s,status=%s,expiration_date=%s,filters=%s,filesystem=%s, + additional_info=%s,description=%s,email=%s,updated_at=%s,upload_data_transfer=%s,download_data_transfer=%s, + total_data_transfer=%s,role_id=COALESCE((SELECT id from %s WHERE name=%s),%s),last_password_change=%s WHERE username = %s`, + sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], + 
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], + sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14], + sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19], + sqlPlaceholders[20], sqlPlaceholders[21], sqlTableRoles, sqlPlaceholders[22], getCoalesceDefaultForRole(role), + sqlPlaceholders[23], sqlPlaceholders[24]) +} + +func getUpdateUserPasswordQuery() string { + return fmt.Sprintf(`UPDATE %s SET password=%s,updated_at=%s WHERE username = %s`, + sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2]) +} + +func getDeleteUserQuery(softDelete bool) string { + if softDelete { + return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE username = %s`, + sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2]) + } + return fmt.Sprintf(`DELETE FROM %s WHERE username = %s`, sqlTableUsers, sqlPlaceholders[0]) +} + +func getRemoveSoftDeletedUserQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE username = %s AND deleted_at > 0`, sqlTableUsers, sqlPlaceholders[0]) +} + +func getFolderByNameQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectFolderFields, sqlTableFolders, sqlPlaceholders[0]) +} + +func getAddFolderQuery() string { + return fmt.Sprintf(`INSERT INTO %s (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem) + VALUES (%s,%s,%s,%s,%s,%s,%s)`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], + sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6]) +} + +func getUpdateFolderQuery() string { + return fmt.Sprintf(`UPDATE %s SET path=%s,description=%s,filesystem=%s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0], + sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getDeleteFolderQuery() string { + return fmt.Sprintf(`DELETE 
FROM %s WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0]) +} + +func getClearUserGroupMappingQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE user_id = (SELECT id FROM %s WHERE username = %s)`, sqlTableUsersGroupsMapping, + sqlTableUsers, sqlPlaceholders[0]) +} + +func getAddUserGroupMappingQuery() string { + return fmt.Sprintf(`INSERT INTO %s (user_id,group_id,group_type,sort_order) VALUES ((SELECT id FROM %s WHERE username = %s), + (SELECT id FROM %s WHERE name = %s),%s,%s)`, + sqlTableUsersGroupsMapping, sqlTableUsers, sqlPlaceholders[0], getSQLQuotedName(sqlTableGroups), + sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getClearAdminGroupMappingQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE admin_id = (SELECT id FROM %s WHERE username = %s)`, sqlTableAdminsGroupsMapping, + sqlTableAdmins, sqlPlaceholders[0]) +} + +func getAddAdminGroupMappingQuery() string { + return fmt.Sprintf(`INSERT INTO %s (admin_id,group_id,options,sort_order) VALUES ((SELECT id FROM %s WHERE username = %s), + (SELECT id FROM %s WHERE name = %s),%s,%s)`, + sqlTableAdminsGroupsMapping, sqlTableAdmins, sqlPlaceholders[0], getSQLQuotedName(sqlTableGroups), + sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getClearGroupFolderMappingQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE group_id = (SELECT id FROM %s WHERE name = %s)`, sqlTableGroupsFoldersMapping, + getSQLQuotedName(sqlTableGroups), sqlPlaceholders[0]) +} + +func getAddGroupFolderMappingQuery() string { + return fmt.Sprintf(`INSERT INTO %s (virtual_path,quota_size,quota_files,folder_id,group_id,sort_order) + VALUES (%s,%s,%s,(SELECT id FROM %s WHERE name = %s),(SELECT id FROM %s WHERE name = %s),%s)`, + sqlTableGroupsFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders, + sqlPlaceholders[3], getSQLQuotedName(sqlTableGroups), sqlPlaceholders[4], sqlPlaceholders[5]) +} + +func getClearUserFolderMappingQuery() 
string { + return fmt.Sprintf(`DELETE FROM %s WHERE user_id = (SELECT id FROM %s WHERE username = %s)`, sqlTableUsersFoldersMapping, + sqlTableUsers, sqlPlaceholders[0]) +} + +func getAddUserFolderMappingQuery() string { + return fmt.Sprintf(`INSERT INTO %s (virtual_path,quota_size,quota_files,folder_id,user_id,sort_order) + VALUES (%s,%s,%s,(SELECT id FROM %s WHERE name = %s),(SELECT id FROM %s WHERE username = %s),%s)`, + sqlTableUsersFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders, + sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4], sqlPlaceholders[5]) +} + +func getFoldersQuery(order string, minimal bool) string { + var fieldSelection string + if minimal { + fieldSelection = selectMinimalFields + } else { + fieldSelection = selectFolderFields + } + return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %s OFFSET %s`, fieldSelection, sqlTableFolders, + order, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getUpdateFolderQuotaQuery(reset bool) string { + if reset { + return fmt.Sprintf(`UPDATE %s SET used_quota_size = %s,used_quota_files = %s,last_quota_update = %s + WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) + } + return fmt.Sprintf(`UPDATE %s SET used_quota_size = used_quota_size + %s,used_quota_files = used_quota_files + %s,last_quota_update = %s + WHERE name = %s`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getQuotaFolderQuery() string { + return fmt.Sprintf(`SELECT used_quota_size,used_quota_files FROM %s WHERE name = %s`, sqlTableFolders, + sqlPlaceholders[0]) +} + +func getRelatedGroupsForUsersQuery(users []User) string { + var sb strings.Builder + for _, u := range users { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(u.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return 
fmt.Sprintf(`SELECT g.name,ug.group_type,ug.user_id FROM %s g INNER JOIN %s ug ON g.id = ug.group_id WHERE + ug.user_id IN %s ORDER BY ug.sort_order`, getSQLQuotedName(sqlTableGroups), sqlTableUsersGroupsMapping, sb.String()) +} + +func getRelatedGroupsForAdminsQuery(admins []Admin) string { + var sb strings.Builder + for _, a := range admins { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(a.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT g.name,ag.options,ag.admin_id FROM %s g INNER JOIN %s ag ON g.id = ag.group_id WHERE + ag.admin_id IN %s ORDER BY ag.sort_order`, getSQLQuotedName(sqlTableGroups), sqlTableAdminsGroupsMapping, sb.String()) +} + +func getRelatedFoldersForUsersQuery(users []User) string { + var sb strings.Builder + for _, u := range users { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(u.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path, + fm.quota_size,fm.quota_files,fm.user_id,f.filesystem,f.description FROM %s f INNER JOIN %s fm ON f.id = fm.folder_id WHERE + fm.user_id IN %s ORDER BY fm.sort_order`, sqlTableFolders, sqlTableUsersFoldersMapping, sb.String()) +} + +func getRelatedUsersForFoldersQuery(folders []vfs.BaseVirtualFolder) string { + var sb strings.Builder + for _, f := range folders { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(f.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT fm.folder_id,u.username FROM %s fm INNER JOIN %s u ON fm.user_id = u.id + WHERE fm.folder_id IN %s ORDER BY u.username`, sqlTableUsersFoldersMapping, sqlTableUsers, sb.String()) +} + +func getRelatedGroupsForFoldersQuery(folders []vfs.BaseVirtualFolder) 
string { + var sb strings.Builder + for _, f := range folders { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(f.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT fm.folder_id,g.name FROM %s fm INNER JOIN %s g ON fm.group_id = g.id + WHERE fm.folder_id IN %s ORDER BY g.name`, sqlTableGroupsFoldersMapping, getSQLQuotedName(sqlTableGroups), + sb.String()) +} + +func getRelatedUsersForGroupsQuery(groups []Group) string { + var sb strings.Builder + for _, g := range groups { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(g.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT um.group_id,u.username FROM %s um INNER JOIN %s u ON um.user_id = u.id + WHERE um.group_id IN %s ORDER BY u.username`, sqlTableUsersGroupsMapping, sqlTableUsers, sb.String()) +} + +func getRelatedAdminsForGroupsQuery(groups []Group) string { + var sb strings.Builder + for _, g := range groups { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(g.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT am.group_id,a.username FROM %s am INNER JOIN %s a ON am.admin_id = a.id + WHERE am.group_id IN %s ORDER BY a.username`, sqlTableAdminsGroupsMapping, sqlTableAdmins, sb.String()) +} + +func getRelatedFoldersForGroupsQuery(groups []Group) string { + var sb strings.Builder + for _, g := range groups { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(g.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path, + fm.quota_size,fm.quota_files,fm.group_id,f.filesystem,f.description FROM %s f INNER JOIN %s fm ON f.id = 
fm.folder_id WHERE + fm.group_id IN %s ORDER BY fm.sort_order`, sqlTableFolders, sqlTableGroupsFoldersMapping, sb.String()) +} + +func getActiveTransfersQuery() string { + return fmt.Sprintf(`SELECT transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size, + current_ul_size,current_dl_size,created_at,updated_at FROM %s WHERE updated_at > %s`, + sqlTableActiveTransfers, sqlPlaceholders[0]) +} + +func getAddActiveTransferQuery() string { + return fmt.Sprintf(`INSERT INTO %s (transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size, + current_ul_size,current_dl_size,created_at,updated_at) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)`, + sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], + sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], + sqlPlaceholders[9], sqlPlaceholders[10]) +} + +func getUpdateActiveTransferSizesQuery() string { + return fmt.Sprintf(`UPDATE %s SET current_ul_size=%s,current_dl_size=%s,updated_at=%s WHERE connection_id = %s AND transfer_id = %s`, + sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4]) +} + +func getRemoveActiveTransferQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE connection_id = %s AND transfer_id = %s`, + sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getCleanupActiveTransfersQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE updated_at < %s`, sqlTableActiveTransfers, sqlPlaceholders[0]) +} + +func getRelatedRulesForActionsQuery(actions []BaseEventAction) string { + var sb strings.Builder + for _, a := range actions { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(a.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT am.action_id,r.name FROM %s am INNER JOIN %s r 
ON am.rule_id = r.id + WHERE am.action_id IN %s ORDER BY r.name ASC`, sqlTableRulesActionsMapping, sqlTableEventsRules, sb.String()) +} + +func getEventsActionsQuery(order string, minimal bool) string { + var fieldSelection string + if minimal { + fieldSelection = selectMinimalFields + } else { + fieldSelection = selectEventActionFields + } + return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %s OFFSET %s`, fieldSelection, + sqlTableEventsActions, order, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getDumpEventActionsQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s`, selectEventActionFields, sqlTableEventsActions) +} + +func getEventActionByNameQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectEventActionFields, sqlTableEventsActions, + sqlPlaceholders[0]) +} + +func getAddEventActionQuery() string { + return fmt.Sprintf(`INSERT INTO %s (name,description,type,options) VALUES (%s,%s,%s,%s)`, + sqlTableEventsActions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getUpdateEventActionQuery() string { + return fmt.Sprintf(`UPDATE %s SET description=%s,type=%s,options=%s WHERE name = %s`, sqlTableEventsActions, + sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getDeleteEventActionQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableEventsActions, sqlPlaceholders[0]) +} + +func getEventRulesQuery(order string) string { + return fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at = 0 ORDER BY name %s LIMIT %s OFFSET %s`, + getSelectEventRuleFields(), sqlTableEventsRules, order, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getDumpEventRulesQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at = 0`, getSelectEventRuleFields(), sqlTableEventsRules) +} + +func getRecentlyUpdatedRulesQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s WHERE updated_at >= %s OR deleted_at > 0`, 
getSelectEventRuleFields(), + sqlTableEventsRules, sqlPlaceholders[0]) +} + +func getEventRulesByNameQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s AND deleted_at = 0`, getSelectEventRuleFields(), sqlTableEventsRules, + sqlPlaceholders[0]) +} + +func getAddEventRuleQuery() string { + return fmt.Sprintf(`INSERT INTO %s (name,description,created_at,updated_at,%s,conditions,deleted_at,status) + VALUES (%s,%s,%s,%s,%s,%s,0,%s)`, + sqlTableEventsRules, getSQLQuotedName("trigger"), sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], + sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6]) +} + +func getUpdateEventRuleQuery() string { + return fmt.Sprintf(`UPDATE %s SET description=%s,updated_at=%s,%s=%s,conditions=%s,status=%s WHERE name = %s`, + sqlTableEventsRules, sqlPlaceholders[0], sqlPlaceholders[1], getSQLQuotedName("trigger"), sqlPlaceholders[2], + sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5]) +} + +func getDeleteEventRuleQuery(softDelete bool) string { + if softDelete { + return fmt.Sprintf(`UPDATE %s SET updated_at=%s,deleted_at=%s WHERE name = %s`, + sqlTableEventsRules, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2]) + } + return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableEventsRules, sqlPlaceholders[0]) +} + +func getRemoveSoftDeletedRuleQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE name = %s AND deleted_at > 0`, sqlTableEventsRules, sqlPlaceholders[0]) +} + +func getClearRuleActionMappingQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE rule_id = (SELECT id FROM %s WHERE name = %s)`, sqlTableRulesActionsMapping, + sqlTableEventsRules, sqlPlaceholders[0]) +} + +func getUpdateRulesTimestampQuery() string { + return fmt.Sprintf(`UPDATE %s SET updated_at=%s WHERE id IN (SELECT rule_id FROM %s WHERE action_id = (SELECT id from %s WHERE name = %s))`, + sqlTableEventsRules, sqlPlaceholders[0], sqlTableRulesActionsMapping, sqlTableEventsActions, 
sqlPlaceholders[1]) +} + +func getRelatedActionsForRulesQuery(rules []EventRule) string { + var sb strings.Builder + for _, r := range rules { + if sb.Len() == 0 { + sb.WriteString("(") + } else { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(r.ID, 10)) + } + if sb.Len() > 0 { + sb.WriteString(")") + } + return fmt.Sprintf(`SELECT a.id,a.name,a.description,a.type,a.options,am.options,am.%s, + am.rule_id FROM %s a INNER JOIN %s am ON a.id = am.action_id WHERE am.rule_id IN %s ORDER BY am.%s ASC`, + getSQLQuotedName("order"), sqlTableEventsActions, sqlTableRulesActionsMapping, sb.String(), + getSQLQuotedName("order")) +} + +func getAddRuleActionMappingQuery() string { + return fmt.Sprintf(`INSERT INTO %s (rule_id,action_id,%s,options) VALUES ((SELECT id FROM %s WHERE name = %s), + (SELECT id FROM %s WHERE name = %s),%s,%s)`, + sqlTableRulesActionsMapping, getSQLQuotedName("order"), sqlTableEventsRules, sqlPlaceholders[0], + sqlTableEventsActions, sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getTaskByNameQuery() string { + return fmt.Sprintf(`SELECT updated_at,version FROM %s WHERE name = %s`, sqlTableTasks, sqlPlaceholders[0]) +} + +func getAddTaskQuery() string { + return fmt.Sprintf(`INSERT INTO %s (name,updated_at,version) VALUES (%s,%s,0)`, + sqlTableTasks, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getUpdateTaskQuery() string { + return fmt.Sprintf(`UPDATE %s SET updated_at=%s,version = version + 1 WHERE name = %s AND version = %s`, + sqlTableTasks, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2]) +} + +func getUpdateTaskTimestampQuery() string { + return fmt.Sprintf(`UPDATE %s SET updated_at=%s WHERE name = %s`, + sqlTableTasks, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getDeleteTaskQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, sqlTableTasks, sqlPlaceholders[0]) +} + +func getAddNodeQuery() string { + if config.Driver == MySQLDataProviderName { + return 
fmt.Sprintf("INSERT INTO %s (`name`,`data`,created_at,`updated_at`) VALUES (%s,%s,%s,%s) ON DUPLICATE KEY UPDATE "+ + "`data`=VALUES(`data`), `created_at`=VALUES(`created_at`), `updated_at`=VALUES(`updated_at`)", + sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) + } + return fmt.Sprintf(`INSERT INTO %s (name,data,created_at,updated_at) VALUES (%s,%s,%s,%s) ON CONFLICT(name) + DO UPDATE SET data=EXCLUDED.data, created_at=EXCLUDED.created_at, updated_at=EXCLUDED.updated_at`, + sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3]) +} + +func getUpdateNodeTimestampQuery() string { + return fmt.Sprintf(`UPDATE %s SET updated_at=%s WHERE name = %s`, + sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getNodeByNameQuery() string { + return fmt.Sprintf(`SELECT name,data,created_at,updated_at FROM %s WHERE name = %s AND updated_at > %s`, + sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getNodesQuery() string { + return fmt.Sprintf(`SELECT name,data,created_at,updated_at FROM %s WHERE name != %s AND updated_at > %s`, + sqlTableNodes, sqlPlaceholders[0], sqlPlaceholders[1]) +} + +func getCleanupNodesQuery() string { + return fmt.Sprintf(`DELETE FROM %s WHERE updated_at < %s`, sqlTableNodes, sqlPlaceholders[0]) +} + +func getDatabaseVersionQuery() string { + return fmt.Sprintf("SELECT version from %s LIMIT 1", sqlTableSchemaVersion) +} + +func getUpdateDBVersionQuery() string { + return fmt.Sprintf(`UPDATE %s SET version=%s`, sqlTableSchemaVersion, sqlPlaceholders[0]) +} diff --git a/internal/dataprovider/unixcrypt.go b/internal/dataprovider/unixcrypt.go new file mode 100644 index 00000000..26c8f646 --- /dev/null +++ b/internal/dataprovider/unixcrypt.go @@ -0,0 +1,38 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as 
published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build unixcrypt && cgo + +package dataprovider + +import ( + "strings" + + "github.com/amoghe/go-crypt" + + "github.com/drakkan/sftpgo/v2/internal/version" +) + +func init() { + version.AddFeature("+unixcrypt") +} + +func compareYescryptPassword(hashedPwd, plainPwd string) (bool, error) { + lastIdx := strings.LastIndex(hashedPwd, "$") + pwd, err := crypt.Crypt(plainPwd, hashedPwd[:lastIdx+1]) + if err != nil { + return false, err + } + return pwd == hashedPwd, nil +} diff --git a/internal/dataprovider/unixcrypt_disabled.go b/internal/dataprovider/unixcrypt_disabled.go new file mode 100644 index 00000000..3e865110 --- /dev/null +++ b/internal/dataprovider/unixcrypt_disabled.go @@ -0,0 +1,31 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +//go:build !unixcrypt || !cgo + +package dataprovider + +import ( + "errors" + + "github.com/drakkan/sftpgo/v2/internal/version" +) + +func init() { + version.AddFeature("-unixcrypt") +} + +func compareYescryptPassword(_, _ string) (bool, error) { + return false, errors.New("yescrypt hash format is not supported or disabled") +} diff --git a/internal/dataprovider/user.go b/internal/dataprovider/user.go new file mode 100644 index 00000000..8cea928d --- /dev/null +++ b/internal/dataprovider/user.go @@ -0,0 +1,1871 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package dataprovider + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "net" + "os" + "path" + "path/filepath" + "slices" + "strconv" + "strings" + "time" + + "github.com/rs/xid" + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +// Available permissions for SFTPGo users +const ( + // All permissions are granted + PermAny = "*" + // List items such as files and directories is allowed + PermListItems = "list" + // download files is allowed + PermDownload = "download" + // upload files is allowed + PermUpload = "upload" + // overwrite an existing file, while uploading, is allowed + // upload permission is required to allow file overwrite + PermOverwrite = "overwrite" + // delete files or directories is allowed + PermDelete = "delete" + // delete files is allowed + PermDeleteFiles = "delete_files" + // delete directories is allowed + PermDeleteDirs = "delete_dirs" + // rename files or directories is allowed + PermRename = "rename" + // rename files is allowed + PermRenameFiles = "rename_files" + // rename directories is allowed + PermRenameDirs = "rename_dirs" + // create directories is allowed + PermCreateDirs = "create_dirs" + // create symbolic links is allowed + PermCreateSymlinks = "create_symlinks" + // changing file or directory permissions is allowed + PermChmod = "chmod" + // changing file or directory owner and group is allowed + PermChown = "chown" + // changing file or directory access and modification time is allowed + PermChtimes = "chtimes" + // copying files or directories is allowed + PermCopy = "copy" +) + +// Available login methods +const ( + LoginMethodNoAuthTried = "no_auth_tried" + LoginMethodPassword = "password" + SSHLoginMethodPassword = "password-over-SSH" + 
SSHLoginMethodPublicKey = "publickey" + SSHLoginMethodKeyboardInteractive = "keyboard-interactive" + SSHLoginMethodKeyAndPassword = "publickey+password" + SSHLoginMethodKeyAndKeyboardInt = "publickey+keyboard-interactive" + LoginMethodTLSCertificate = "TLSCertificate" + LoginMethodTLSCertificateAndPwd = "TLSCertificate+password" + LoginMethodIDP = "IDP" +) + +var ( + errNoMatchingVirtualFolder = errors.New("no matching virtual folder found") + permsRenameAny = []string{PermRename, PermRenameDirs, PermRenameFiles} + permsDeleteAny = []string{PermDelete, PermDeleteDirs, PermDeleteFiles} +) + +// RecoveryCode defines a 2FA recovery code +type RecoveryCode struct { + Secret *kms.Secret `json:"secret"` + Used bool `json:"used,omitempty"` +} + +// UserTOTPConfig defines the time-based one time password configuration +type UserTOTPConfig struct { + Enabled bool `json:"enabled,omitempty"` + ConfigName string `json:"config_name,omitempty"` + Secret *kms.Secret `json:"secret,omitempty"` + // TOTP will be required for the specified protocols. + // SSH protocol (SFTP/SCP/SSH commands) will ask for the TOTP passcode if the client uses keyboard interactive + // authentication. + // FTP have no standard way to support two factor authentication, if you + // enable the support for this protocol you have to add the TOTP passcode after the password. + // For example if your password is "password" and your one time passcode is + // "123456" you have to use "password123456" as password. + Protocols []string `json:"protocols,omitempty"` +} + +// UserFilters defines additional restrictions for a user +// TODO: rename to UserOptions in v3 +type UserFilters struct { + sdk.BaseUserFilters + // User must change password from WebClient/REST API at next login. 
+ RequirePasswordChange bool `json:"require_password_change,omitempty"` + // AdditionalEmails defines additional email addresses + AdditionalEmails []string `json:"additional_emails,omitempty"` + // Time-based one time passwords configuration + TOTPConfig UserTOTPConfig `json:"totp_config,omitempty"` + // Recovery codes to use if the user loses access to their second factor auth device. + // Each code can only be used once, you should use these codes to login and disable or + // reset 2FA for your account + RecoveryCodes []RecoveryCode `json:"recovery_codes,omitempty"` +} + +// User defines a SFTPGo user +type User struct { + sdk.BaseUser + // Additional restrictions + Filters UserFilters `json:"filters"` + // Mapping between virtual paths and virtual folders + VirtualFolders []vfs.VirtualFolder `json:"virtual_folders,omitempty"` + // Filesystem configuration details + FsConfig vfs.Filesystem `json:"filesystem"` + // groups associated with this user + Groups []sdk.GroupMapping `json:"groups,omitempty"` + // we store the filesystem here using the base path as key. 
+ fsCache map[string]vfs.Fs `json:"-"` + // true if group settings are already applied for this user + groupSettingsApplied bool `json:"-"` + // in multi node setups we mark the user as deleted to be able to update the webdav cache + DeletedAt int64 `json:"-"` +} + +// GetFilesystem returns the base filesystem for this user +func (u *User) GetFilesystem(connectionID string) (fs vfs.Fs, err error) { + return u.GetFilesystemForPath("/", connectionID) +} + +func (u *User) getRootFs(connectionID string) (fs vfs.Fs, err error) { + switch u.FsConfig.Provider { + case sdk.S3FilesystemProvider: + return vfs.NewS3Fs(connectionID, u.GetHomeDir(), "", u.FsConfig.S3Config) + case sdk.GCSFilesystemProvider: + return vfs.NewGCSFs(connectionID, u.GetHomeDir(), "", u.FsConfig.GCSConfig) + case sdk.AzureBlobFilesystemProvider: + return vfs.NewAzBlobFs(connectionID, u.GetHomeDir(), "", u.FsConfig.AzBlobConfig) + case sdk.CryptedFilesystemProvider: + return vfs.NewCryptFs(connectionID, u.GetHomeDir(), "", u.FsConfig.CryptConfig) + case sdk.SFTPFilesystemProvider: + forbiddenSelfUsers, err := u.getForbiddenSFTPSelfUsers(u.FsConfig.SFTPConfig.Username) + if err != nil { + return nil, err + } + forbiddenSelfUsers = append(forbiddenSelfUsers, u.Username) + return vfs.NewSFTPFs(connectionID, "", u.GetHomeDir(), forbiddenSelfUsers, u.FsConfig.SFTPConfig) + case sdk.HTTPFilesystemProvider: + return vfs.NewHTTPFs(connectionID, u.GetHomeDir(), "", u.FsConfig.HTTPConfig) + default: + return vfs.NewOsFs(connectionID, u.GetHomeDir(), "", &u.FsConfig.OSConfig), nil + } +} + +func (u *User) checkDirWithParents(virtualDirPath, connectionID string) error { + dirs := util.GetDirsForVirtualPath(virtualDirPath) + for idx := len(dirs) - 1; idx >= 0; idx-- { + vPath := dirs[idx] + if vPath == "/" { + continue + } + fs, err := u.GetFilesystemForPath(vPath, connectionID) + if err != nil { + return fmt.Errorf("unable to get fs for path %q: %w", vPath, err) + } + if fs.HasVirtualFolders() { + continue + } + 
fsPath, err := fs.ResolvePath(vPath) + if err != nil { + return fmt.Errorf("unable to resolve path %q: %w", vPath, err) + } + _, err = fs.Stat(fsPath) + if err == nil { + continue + } + if fs.IsNotExist(err) { + err = fs.Mkdir(fsPath) + if err != nil { + return err + } + vfs.SetPathPermissions(fs, fsPath, u.GetUID(), u.GetGID()) + } else { + return fmt.Errorf("unable to stat path %q: %w", vPath, err) + } + } + + return nil +} + +func (u *User) checkLocalHomeDir(connectionID string) { + switch u.FsConfig.Provider { + case sdk.LocalFilesystemProvider, sdk.CryptedFilesystemProvider: + return + default: + osFs := vfs.NewOsFs(connectionID, u.GetHomeDir(), "", nil) + osFs.CheckRootPath(u.Username, u.GetUID(), u.GetGID()) + } +} + +func (u *User) checkRootPath(connectionID string) error { + fs, err := u.GetFilesystemForPath("/", connectionID) + if err != nil { + logger.Warn(logSender, connectionID, "could not create main filesystem for user %q err: %v", u.Username, err) + return fmt.Errorf("could not create root filesystem: %w", err) + } + fs.CheckRootPath(u.Username, u.GetUID(), u.GetGID()) + return nil +} + +// CheckFsRoot check the root directory for the main fs and the virtual folders. 
+// It returns an error if the main filesystem cannot be created +func (u *User) CheckFsRoot(connectionID string) error { + if u.Filters.DisableFsChecks { + return nil + } + delay := lastLoginMinDelay + if u.Filters.ExternalAuthCacheTime > 0 { + cacheTime := time.Duration(u.Filters.ExternalAuthCacheTime) * time.Second + if cacheTime > delay { + delay = cacheTime + } + } + if isLastActivityRecent(u.LastLogin, delay) { + if u.LastLogin > u.UpdatedAt { + if config.IsShared == 1 { + u.checkLocalHomeDir(connectionID) + } + return nil + } + } + err := u.checkRootPath(connectionID) + if err != nil { + return err + } + if u.Filters.StartDirectory != "" { + err = u.checkDirWithParents(u.Filters.StartDirectory, connectionID) + if err != nil { + logger.Warn(logSender, connectionID, "could not create start directory %q, err: %v", + u.Filters.StartDirectory, err) + } + } + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + fs, err := u.GetFilesystemForPath(v.VirtualPath, connectionID) + if err == nil { + fs.CheckRootPath(u.Username, u.GetUID(), u.GetGID()) + } + // now check intermediary folders + err = u.checkDirWithParents(path.Dir(v.VirtualPath), connectionID) + if err != nil { + logger.Warn(logSender, connectionID, "could not create intermediary dir to %q, err: %v", v.VirtualPath, err) + } + } + return nil +} + +// GetCleanedPath returns a clean POSIX absolute path using the user start directory as base +// if the provided rawVirtualPath is relative +func (u *User) GetCleanedPath(rawVirtualPath string) string { + if u.Filters.StartDirectory != "" { + if !path.IsAbs(rawVirtualPath) { + var b strings.Builder + + b.Grow(len(u.Filters.StartDirectory) + 1 + len(rawVirtualPath)) + b.WriteString(u.Filters.StartDirectory) + b.WriteString("/") + b.WriteString(rawVirtualPath) + return util.CleanPath(b.String()) + } + } + return util.CleanPath(rawVirtualPath) +} + +// isFsEqual returns true if the filesystem configurations are the same +func (u *User) isFsEqual(other 
*User) bool { + if u.FsConfig.Provider == sdk.LocalFilesystemProvider && u.GetHomeDir() != other.GetHomeDir() { + return false + } + if !u.FsConfig.IsEqual(other.FsConfig) { + return false + } + if u.Filters.StartDirectory != other.Filters.StartDirectory { + return false + } + if len(u.VirtualFolders) != len(other.VirtualFolders) { + return false + } + for idx := range u.VirtualFolders { + f := &u.VirtualFolders[idx] + found := false + for idx1 := range other.VirtualFolders { + f1 := &other.VirtualFolders[idx1] + if f.VirtualPath == f1.VirtualPath { + found = true + if f.FsConfig.Provider == sdk.LocalFilesystemProvider && f.MappedPath != f1.MappedPath { + return false + } + if !f.FsConfig.IsEqual(f1.FsConfig) { + return false + } + } + } + if !found { + return false + } + } + return true +} + +func (u *User) isTimeBasedAccessAllowed(when time.Time) bool { + if len(u.Filters.AccessTime) == 0 { + return true + } + if when.IsZero() { + when = time.Now() + } + if UseLocalTime() { + when = when.Local() + } else { + when = when.UTC() + } + weekDay := when.Weekday() + hhMM := when.Format("15:04") + for _, p := range u.Filters.AccessTime { + if p.DayOfWeek == int(weekDay) { + if hhMM >= p.From && hhMM <= p.To { + return true + } + } + } + return false +} + +// CheckLoginConditions checks user access restrictions +func (u *User) CheckLoginConditions() error { + if u.Status < 1 { + return fmt.Errorf("user %q is disabled", u.Username) + } + if u.ExpirationDate > 0 && u.ExpirationDate < util.GetTimeAsMsSinceEpoch(time.Now()) { + return fmt.Errorf("user %q is expired, expiration timestamp: %v current timestamp: %v", u.Username, + u.ExpirationDate, util.GetTimeAsMsSinceEpoch(time.Now())) + } + if u.isTimeBasedAccessAllowed(time.Now()) { + return nil + } + return errors.New("access is not allowed at this time") +} + +// hideConfidentialData hides user confidential data +func (u *User) hideConfidentialData() { + u.Password = "" + u.FsConfig.HideConfidentialData() + if 
u.Filters.TOTPConfig.Secret != nil { + u.Filters.TOTPConfig.Secret.Hide() + } + for _, code := range u.Filters.RecoveryCodes { + if code.Secret != nil { + code.Secret.Hide() + } + } +} + +// CheckMaxShareExpiration returns an error if the share expiration exceed the +// maximum allowed date. +func (u *User) CheckMaxShareExpiration(expiresAt time.Time) error { + if u.Filters.MaxSharesExpiration == 0 { + return nil + } + maxAllowedExpiration := time.Now().Add(24 * time.Hour * time.Duration(u.Filters.MaxSharesExpiration+1)) + maxAllowedExpiration = time.Date(maxAllowedExpiration.Year(), maxAllowedExpiration.Month(), + maxAllowedExpiration.Day(), 0, 0, 0, 0, maxAllowedExpiration.Location()) + if util.GetTimeAsMsSinceEpoch(expiresAt) == 0 || expiresAt.After(maxAllowedExpiration) { + return util.NewValidationError(fmt.Sprintf("the share must expire before %s", maxAllowedExpiration.Format(time.DateOnly))) + } + return nil +} + +// GetEmailAddresses returns all the email addresses. +func (u *User) GetEmailAddresses() []string { + var res []string + if u.Email != "" { + res = append(res, u.Email) + } + return slices.Concat(res, u.Filters.AdditionalEmails) +} + +// GetSubDirPermissions returns permissions for sub directories +func (u *User) GetSubDirPermissions() []sdk.DirectoryPermissions { + var result []sdk.DirectoryPermissions + for k, v := range u.Permissions { + if k == "/" { + continue + } + dirPerms := sdk.DirectoryPermissions{ + Path: k, + Permissions: v, + } + result = append(result, dirPerms) + } + return result +} + +func (u *User) setAnonymousSettings() { + for k := range u.Permissions { + u.Permissions[k] = []string{PermListItems, PermDownload} + } + u.Filters.DeniedProtocols = append(u.Filters.DeniedProtocols, protocolSSH, protocolHTTP) + u.Filters.DeniedProtocols = util.RemoveDuplicates(u.Filters.DeniedProtocols, false) + for _, method := range ValidLoginMethods { + if method != LoginMethodPassword { + u.Filters.DeniedLoginMethods = 
append(u.Filters.DeniedLoginMethods, method) + } + } + u.Filters.DeniedLoginMethods = util.RemoveDuplicates(u.Filters.DeniedLoginMethods, false) +} + +// RenderAsJSON implements the renderer interface used within plugins +func (u *User) RenderAsJSON(reload bool) ([]byte, error) { + if reload { + user, err := provider.userExists(u.Username, "") + if err != nil { + providerLog(logger.LevelError, "unable to reload user before rendering as json: %v", err) + return nil, err + } + user.PrepareForRendering() + return json.Marshal(user) + } + u.PrepareForRendering() + return json.Marshal(u) +} + +// PrepareForRendering prepares a user for rendering. +// It hides confidential data and set to nil the empty secrets +// so they are not serialized +func (u *User) PrepareForRendering() { + u.hideConfidentialData() + u.FsConfig.SetNilSecretsIfEmpty() + for idx := range u.VirtualFolders { + folder := &u.VirtualFolders[idx] + folder.PrepareForRendering() + } +} + +// HasRedactedSecret returns true if the user has a redacted secret +func (u *User) hasRedactedSecret() bool { + if u.FsConfig.HasRedactedSecret() { + return true + } + + for idx := range u.VirtualFolders { + folder := &u.VirtualFolders[idx] + if folder.HasRedactedSecret() { + return true + } + } + + return u.Filters.TOTPConfig.Secret.IsRedacted() +} + +// CloseFs closes the underlying filesystems +func (u *User) CloseFs() error { + if u.fsCache == nil { + return nil + } + + var err error + for _, fs := range u.fsCache { + errClose := fs.Close() + if err == nil { + err = errClose + } + } + return err +} + +// IsPasswordHashed returns true if the password is hashed +func (u *User) IsPasswordHashed() bool { + return util.IsStringPrefixInSlice(u.Password, hashPwdPrefixes) +} + +// IsTLSVerificationEnabled returns true if we need to check the TLS authentication +func (u *User) IsTLSVerificationEnabled() bool { + if len(u.Filters.TLSCerts) > 0 { + return true + } + if u.Filters.TLSUsername != "" { + return 
u.Filters.TLSUsername != sdk.TLSUsernameNone + } + return false +} + +// SetEmptySecrets sets to empty any user secret +func (u *User) SetEmptySecrets() { + u.FsConfig.SetEmptySecrets() + for idx := range u.VirtualFolders { + folder := &u.VirtualFolders[idx] + folder.FsConfig.SetEmptySecrets() + } + u.Filters.TOTPConfig.Secret = kms.NewEmptySecret() +} + +// GetPermissionsForPath returns the permissions for the given path. +// The path must be a SFTPGo virtual path +func (u *User) GetPermissionsForPath(p string) []string { + permissions := []string{} + if perms, ok := u.Permissions["/"]; ok { + // if only root permissions are defined returns them unconditionally + if len(u.Permissions) == 1 { + return perms + } + // fallback permissions + permissions = perms + } + dirsForPath := util.GetDirsForVirtualPath(p) + // dirsForPath contains all the dirs for a given path in reverse order + // for example if the path is: /1/2/3/4 it contains: + // [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ] + // so the first match is the one we are interested to + for idx := range dirsForPath { + if perms, ok := u.Permissions[dirsForPath[idx]]; ok { + return perms + } + for dir, perms := range u.Permissions { + if match, err := path.Match(dir, dirsForPath[idx]); err == nil && match { + return perms + } + } + } + return permissions +} + +func (u *User) getForbiddenSFTPSelfUsers(username string) ([]string, error) { + if allowSelfConnections == 0 { + return nil, nil + } + sftpUser, err := UserExists(username, "") + if err == nil { + err = sftpUser.LoadAndApplyGroupSettings() + } + if err == nil { + // we don't allow local nested SFTP folders + var forbiddens []string + if sftpUser.FsConfig.Provider == sdk.SFTPFilesystemProvider { + forbiddens = append(forbiddens, sftpUser.Username) + return forbiddens, nil + } + for idx := range sftpUser.VirtualFolders { + v := &sftpUser.VirtualFolders[idx] + if v.FsConfig.Provider == sdk.SFTPFilesystemProvider { + forbiddens = append(forbiddens, 
sftpUser.Username) + return forbiddens, nil + } + } + return forbiddens, nil + } + if !errors.Is(err, util.ErrNotFound) { + return nil, err + } + + return nil, nil +} + +// GetFsConfigForPath returns the file system configuration for the specified virtual path +func (u *User) GetFsConfigForPath(virtualPath string) vfs.Filesystem { + if virtualPath != "" && virtualPath != "/" && len(u.VirtualFolders) > 0 { + folder, err := u.GetVirtualFolderForPath(virtualPath) + if err == nil { + return folder.FsConfig + } + } + + return u.FsConfig +} + +// GetFilesystemForPath returns the filesystem for the given path +func (u *User) GetFilesystemForPath(virtualPath, connectionID string) (vfs.Fs, error) { + if u.fsCache == nil { + u.fsCache = make(map[string]vfs.Fs) + } + // allow to override the `/` path with a virtual folder + if len(u.VirtualFolders) > 0 { + folder, err := u.GetVirtualFolderForPath(virtualPath) + if err == nil { + if fs, ok := u.fsCache[folder.VirtualPath]; ok { + return fs, nil + } + forbiddenSelfUsers := []string{u.Username} + if folder.FsConfig.Provider == sdk.SFTPFilesystemProvider { + forbiddens, err := u.getForbiddenSFTPSelfUsers(folder.FsConfig.SFTPConfig.Username) + if err != nil { + return nil, err + } + forbiddenSelfUsers = append(forbiddenSelfUsers, forbiddens...) + } + fs, err := folder.GetFilesystem(connectionID, forbiddenSelfUsers) + if err == nil { + u.fsCache[folder.VirtualPath] = fs + } + return fs, err + } + } + + if val, ok := u.fsCache["/"]; ok { + return val, nil + } + fs, err := u.getRootFs(connectionID) + if err != nil { + return fs, err + } + u.fsCache["/"] = fs + return fs, err +} + +// GetVirtualFolderForPath returns the virtual folder containing the specified virtual path. 
+// If the path is not inside a virtual folder an error is returned +func (u *User) GetVirtualFolderForPath(virtualPath string) (vfs.VirtualFolder, error) { + var folder vfs.VirtualFolder + if len(u.VirtualFolders) == 0 { + return folder, errNoMatchingVirtualFolder + } + dirsForPath := util.GetDirsForVirtualPath(virtualPath) + for index := range dirsForPath { + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + if v.VirtualPath == dirsForPath[index] { + return *v, nil + } + } + } + return folder, errNoMatchingVirtualFolder +} + +// ScanQuota scans the user home dir and virtual folders, included in its quota, +// and returns the number of files and their size +func (u *User) ScanQuota() (int, int64, error) { + fs, err := u.getRootFs(xid.New().String()) + if err != nil { + return 0, 0, err + } + defer fs.Close() + + numFiles, size, err := fs.ScanRootDirContents() + if err != nil { + return numFiles, size, err + } + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + if !v.IsIncludedInUserQuota() { + continue + } + num, s, err := v.ScanQuota() + if err != nil { + return numFiles, size, err + } + numFiles += num + size += s + } + + return numFiles, size, nil +} + +// GetVirtualFoldersInPath returns the virtual folders inside virtualPath including +// any parents +func (u *User) GetVirtualFoldersInPath(virtualPath string) map[string]bool { + result := make(map[string]bool) + + for idx := range u.VirtualFolders { + dirsForPath := util.GetDirsForVirtualPath(u.VirtualFolders[idx].VirtualPath) + for index := range dirsForPath { + d := dirsForPath[index] + if d == "/" { + continue + } + if path.Dir(d) == virtualPath { + result[d] = true + } + } + } + + if u.Filters.StartDirectory != "" { + dirsForPath := util.GetDirsForVirtualPath(u.Filters.StartDirectory) + for index := range dirsForPath { + d := dirsForPath[index] + if d == "/" { + continue + } + if path.Dir(d) == virtualPath { + result[d] = true + } + } + } + + return result +} + +func (u 
*User) hasVirtualDirs() bool { + if u.Filters.StartDirectory != "" { + return true + } + numFolders := len(u.VirtualFolders) + if numFolders == 1 { + return u.VirtualFolders[0].VirtualPath != "/" + } + return numFolders > 0 +} + +// GetVirtualFoldersInfo returns []os.FileInfo for virtual folders +func (u *User) GetVirtualFoldersInfo(virtualPath string) []os.FileInfo { + filter := u.getPatternsFilterForPath(virtualPath) + if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide { + return nil + } + vdirs := u.GetVirtualFoldersInPath(virtualPath) + result := make([]os.FileInfo, 0, len(vdirs)) + + for dir := range u.GetVirtualFoldersInPath(virtualPath) { + dirName := path.Base(dir) + if filter.DenyPolicy == sdk.DenyPolicyHide { + if !filter.CheckAllowed(dirName) { + continue + } + } + result = append(result, vfs.NewFileInfo(dirName, true, 0, time.Unix(0, 0), false)) + } + + return result +} + +// FilterListDir removes hidden items from the given files list +func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os.FileInfo { + filter := u.getPatternsFilterForPath(virtualPath) + if !u.hasVirtualDirs() && filter.DenyPolicy != sdk.DenyPolicyHide { + return dirContents + } + vdirs := make(map[string]bool) + for dir := range u.GetVirtualFoldersInPath(virtualPath) { + dirName := path.Base(dir) + if filter.DenyPolicy == sdk.DenyPolicyHide { + if !filter.CheckAllowed(dirName) { + continue + } + } + vdirs[dirName] = true + } + + validIdx := 0 + for idx := range dirContents { + fi := dirContents[idx] + + if fi.Name() != "." && fi.Name() != ".." { + if _, ok := vdirs[fi.Name()]; ok { + continue + } + if filter.DenyPolicy == sdk.DenyPolicyHide { + if !filter.CheckAllowed(fi.Name()) { + continue + } + } + } + dirContents[validIdx] = fi + validIdx++ + } + + return dirContents[:validIdx] +} + +// IsMappedPath returns true if the specified filesystem path has a virtual folder mapping. 
+// The filesystem path must be cleaned before calling this method +func (u *User) IsMappedPath(fsPath string) bool { + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + if fsPath == v.MappedPath { + return true + } + } + return false +} + +// IsVirtualFolder returns true if the specified virtual path is a virtual folder +func (u *User) IsVirtualFolder(virtualPath string) bool { + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + if virtualPath == v.VirtualPath { + return true + } + } + return false +} + +// HasVirtualFoldersInside returns true if there are virtual folders inside the +// specified virtual path. We assume that path are cleaned +func (u *User) HasVirtualFoldersInside(virtualPath string) bool { + if virtualPath == "/" && len(u.VirtualFolders) > 0 { + return true + } + for idx := range u.VirtualFolders { + v := &u.VirtualFolders[idx] + if len(v.VirtualPath) > len(virtualPath) { + if strings.HasPrefix(v.VirtualPath, virtualPath+"/") { + return true + } + } + } + return false +} + +// HasPermissionsInside returns true if the specified virtualPath has no permissions itself and +// no subdirs with defined permissions +func (u *User) HasPermissionsInside(virtualPath string) bool { + for dir, perms := range u.Permissions { + if len(perms) == 1 && perms[0] == PermAny { + continue + } + if dir == virtualPath { + return true + } else if len(dir) > len(virtualPath) { + if strings.HasPrefix(dir, virtualPath+"/") { + return true + } + } + } + return false +} + +// HasPerm returns true if the user has the given permission or any permission +func (u *User) HasPerm(permission, path string) bool { + perms := u.GetPermissionsForPath(path) + if slices.Contains(perms, PermAny) { + return true + } + return slices.Contains(perms, permission) +} + +// HasAnyPerm returns true if the user has at least one of the given permissions +func (u *User) HasAnyPerm(permissions []string, path string) bool { + perms := u.GetPermissionsForPath(path) + 
if slices.Contains(perms, PermAny) { + return true + } + for _, permission := range permissions { + if slices.Contains(perms, permission) { + return true + } + } + return false +} + +// HasPerms returns true if the user has all the given permissions +func (u *User) HasPerms(permissions []string, path string) bool { + perms := u.GetPermissionsForPath(path) + if slices.Contains(perms, PermAny) { + return true + } + for _, permission := range permissions { + if !slices.Contains(perms, permission) { + return false + } + } + return true +} + +// HasPermsDeleteAll returns true if the user can delete both files and directories +// for the given path +func (u *User) HasPermsDeleteAll(path string) bool { + perms := u.GetPermissionsForPath(path) + canDeleteFiles := false + canDeleteDirs := false + for _, permission := range perms { + if permission == PermAny || permission == PermDelete { + return true + } + if permission == PermDeleteFiles { + canDeleteFiles = true + } + if permission == PermDeleteDirs { + canDeleteDirs = true + } + } + return canDeleteFiles && canDeleteDirs +} + +// HasPermsRenameAll returns true if the user can rename both files and directories +// for the given path +func (u *User) HasPermsRenameAll(path string) bool { + perms := u.GetPermissionsForPath(path) + canRenameFiles := false + canRenameDirs := false + for _, permission := range perms { + if permission == PermAny || permission == PermRename { + return true + } + if permission == PermRenameFiles { + canRenameFiles = true + } + if permission == PermRenameDirs { + canRenameDirs = true + } + } + return canRenameFiles && canRenameDirs +} + +// HasNoQuotaRestrictions returns true if no quota restrictions need to be applyed +func (u *User) HasNoQuotaRestrictions(checkFiles bool) bool { + if u.QuotaSize == 0 && (!checkFiles || u.QuotaFiles == 0) { + return true + } + return false +} + +// IsLoginMethodAllowed returns true if the specified login method is allowed +func (u *User) 
IsLoginMethodAllowed(loginMethod, protocol string) bool { + if len(u.Filters.DeniedLoginMethods) == 0 { + return true + } + if slices.Contains(u.Filters.DeniedLoginMethods, loginMethod) { + return false + } + if protocol == protocolSSH && loginMethod == LoginMethodPassword { + if slices.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) { + return false + } + } + return true +} + +// GetNextAuthMethods returns the list of authentications methods that can +// continue for multi-step authentication. We call this method after a +// successful public key authentication. +func (u *User) GetNextAuthMethods() []string { + var methods []string + for _, method := range u.GetAllowedLoginMethods() { + if method == SSHLoginMethodKeyAndPassword { + methods = append(methods, LoginMethodPassword) + } + if method == SSHLoginMethodKeyAndKeyboardInt { + methods = append(methods, SSHLoginMethodKeyboardInteractive) + } + } + return methods +} + +// IsPartialAuth returns true if the specified login method is a step for +// a multi-step Authentication. +// We support publickey+password and publickey+keyboard-interactive, so +// only publickey can returns partial success. 
+// We can have partial success if only multi-step Auth methods are enabled +func (u *User) IsPartialAuth() bool { + for _, method := range u.GetAllowedLoginMethods() { + if method == LoginMethodTLSCertificate || method == LoginMethodTLSCertificateAndPwd || + method == SSHLoginMethodPassword { + continue + } + if method == LoginMethodPassword && slices.Contains(u.Filters.DeniedLoginMethods, SSHLoginMethodPassword) { + continue + } + if !slices.Contains(SSHMultiStepsLoginMethods, method) { + return false + } + } + return true +} + +// GetAllowedLoginMethods returns the allowed login methods +func (u *User) GetAllowedLoginMethods() []string { + var allowedMethods []string + for _, method := range ValidLoginMethods { + if method == SSHLoginMethodPassword { + continue + } + if !slices.Contains(u.Filters.DeniedLoginMethods, method) { + allowedMethods = append(allowedMethods, method) + } + } + return allowedMethods +} + +func (u *User) getPatternsFilterForPath(virtualPath string) sdk.PatternsFilter { + var filter sdk.PatternsFilter + if len(u.Filters.FilePatterns) == 0 { + return filter + } + dirsForPath := util.GetDirsForVirtualPath(virtualPath) + for idx, dir := range dirsForPath { + for _, f := range u.Filters.FilePatterns { + if f.Path == dir { + if idx > 0 && len(f.AllowedPatterns) > 0 && len(f.DeniedPatterns) > 0 && f.DeniedPatterns[0] == "*" { + if f.CheckAllowed(path.Base(dirsForPath[idx-1])) { + return filter + } + } + filter = f + break + } + } + if filter.Path != "" { + break + } + } + return filter +} + +func (u *User) isDirHidden(virtualPath string) bool { + if len(u.Filters.FilePatterns) == 0 { + return false + } + for _, dirPath := range util.GetDirsForVirtualPath(virtualPath) { + if dirPath == "/" { + return false + } + filter := u.getPatternsFilterForPath(dirPath) + if filter.DenyPolicy == sdk.DenyPolicyHide && filter.Path != dirPath { + if !filter.CheckAllowed(path.Base(dirPath)) { + return true + } + } + } + return false +} + +func (u *User) 
getMinPasswordEntropy() float64 { + if u.Filters.PasswordStrength > 0 { + return float64(u.Filters.PasswordStrength) + } + return config.PasswordValidation.Users.MinEntropy +} + +// IsFileAllowed returns true if the specified file is allowed by the file restrictions filters. +// The second parameter returned is the deny policy +func (u *User) IsFileAllowed(virtualPath string) (bool, int) { + dirPath := path.Dir(virtualPath) + if u.isDirHidden(dirPath) { + return false, sdk.DenyPolicyHide + } + filter := u.getPatternsFilterForPath(dirPath) + return filter.CheckAllowed(path.Base(virtualPath)), filter.DenyPolicy +} + +// CanManageMFA returns true if the user can add a multi-factor authentication configuration +func (u *User) CanManageMFA() bool { + if slices.Contains(u.Filters.WebClient, sdk.WebClientMFADisabled) { + return false + } + return len(mfa.GetAvailableTOTPConfigs()) > 0 +} + +func (u *User) skipExternalAuth() bool { + if u.Filters.Hooks.ExternalAuthDisabled { + return true + } + if u.ID <= 0 { + return false + } + if u.Filters.ExternalAuthCacheTime <= 0 { + return false + } + return isLastActivityRecent(u.LastLogin, time.Duration(u.Filters.ExternalAuthCacheTime)*time.Second) +} + +// CanManageShares returns true if the user can add, update and list shares +func (u *User) CanManageShares() bool { + return !slices.Contains(u.Filters.WebClient, sdk.WebClientSharesDisabled) +} + +// CanResetPassword returns true if this user is allowed to reset its password +func (u *User) CanResetPassword() bool { + return !slices.Contains(u.Filters.WebClient, sdk.WebClientPasswordResetDisabled) +} + +// CanChangePassword returns true if this user is allowed to change its password +func (u *User) CanChangePassword() bool { + return !slices.Contains(u.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) +} + +// CanChangeAPIKeyAuth returns true if this user is allowed to enable/disable API key authentication +func (u *User) CanChangeAPIKeyAuth() bool { + return 
!slices.Contains(u.Filters.WebClient, sdk.WebClientAPIKeyAuthChangeDisabled) +} + +// CanChangeInfo returns true if this user is allowed to change its info such as email and description +func (u *User) CanChangeInfo() bool { + return !slices.Contains(u.Filters.WebClient, sdk.WebClientInfoChangeDisabled) +} + +// CanManagePublicKeys returns true if this user is allowed to manage public keys +// from the WebClient. Used in WebClient UI +func (u *User) CanManagePublicKeys() bool { + return !slices.Contains(u.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled) +} + +// CanManageTLSCerts returns true if this user is allowed to manage TLS certificates +// from the WebClient. Used in WebClient UI +func (u *User) CanManageTLSCerts() bool { + return !slices.Contains(u.Filters.WebClient, sdk.WebClientTLSCertChangeDisabled) +} + +// CanUpdateProfile returns true if the user is allowed to update the profile. +// Used in WebClient UI +func (u *User) CanUpdateProfile() bool { + return u.CanManagePublicKeys() || u.CanChangeAPIKeyAuth() || u.CanChangeInfo() || u.CanManageTLSCerts() +} + +// CanAddFilesFromWeb returns true if the client can add files from the web UI. +// The specified target is the directory where the files must be uploaded +func (u *User) CanAddFilesFromWeb(target string) bool { + if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { + return false + } + return u.HasPerm(PermUpload, target) || u.HasPerm(PermOverwrite, target) +} + +// CanAddDirsFromWeb returns true if the client can add directories from the web UI. +// The specified target is the directory where the new directory must be created +func (u *User) CanAddDirsFromWeb(target string) bool { + if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { + return false + } + return u.HasPerm(PermCreateDirs, target) +} + +// CanRenameFromWeb returns true if the client can rename objects from the web UI. 
+// The specified src and dest are the source and target directories for the rename. +func (u *User) CanRenameFromWeb(src, dest string) bool { + if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { + return false + } + return u.HasAnyPerm(permsRenameAny, src) && u.HasAnyPerm(permsRenameAny, dest) +} + +// CanDeleteFromWeb returns true if the client can delete objects from the web UI. +// The specified target is the parent directory for the object to delete +func (u *User) CanDeleteFromWeb(target string) bool { + if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { + return false + } + return u.HasAnyPerm(permsDeleteAny, target) +} + +// CanCopyFromWeb returns true if the client can copy objects from the web UI. +// The specified src and dest are the source and target directories for the copy. +func (u *User) CanCopyFromWeb(src, dest string) bool { + if slices.Contains(u.Filters.WebClient, sdk.WebClientWriteDisabled) { + return false + } + if !u.HasPerm(PermListItems, src) { + return false + } + if !u.HasPerm(PermDownload, src) { + return false + } + return u.HasPerm(PermCopy, src) && u.HasPerm(PermCopy, dest) +} + +// InactivityDays returns the number of days of inactivity +func (u *User) InactivityDays(when time.Time) int { + if when.IsZero() { + when = time.Now() + } + lastActivity := u.LastLogin + if lastActivity == 0 { + lastActivity = u.CreatedAt + } + if lastActivity == 0 { + // unable to determine inactivity + return 0 + } + return int(float64(when.Sub(util.GetTimeFromMsecSinceEpoch(lastActivity))) / float64(24*time.Hour)) +} + +// PasswordExpiresIn returns the number of days before the password expires. +// The returned value is negative if the password is expired. 
+// The caller must ensure that a PasswordExpiration is set +func (u *User) PasswordExpiresIn() int { + lastPwdChange := util.GetTimeFromMsecSinceEpoch(u.LastPasswordChange) + pwdExpiration := lastPwdChange.Add(time.Duration(u.Filters.PasswordExpiration) * 24 * time.Hour) + res := int(math.Round(float64(time.Until(pwdExpiration)) / float64(24*time.Hour))) + if res == 0 && pwdExpiration.After(time.Now()) { + res = 1 + } + return res +} + +// MustChangePassword returns true if the user must change the password +func (u *User) MustChangePassword() bool { + if u.Filters.RequirePasswordChange { + return true + } + if u.Filters.PasswordExpiration == 0 { + return false + } + lastPwdChange := util.GetTimeFromMsecSinceEpoch(u.LastPasswordChange) + return lastPwdChange.Add(time.Duration(u.Filters.PasswordExpiration) * 24 * time.Hour).Before(time.Now()) +} + +// MustSetSecondFactor returns true if the user must set a second factor authentication +func (u *User) MustSetSecondFactor() bool { + if len(u.Filters.TwoFactorAuthProtocols) > 0 { + if !u.Filters.TOTPConfig.Enabled { + return true + } + for _, p := range u.Filters.TwoFactorAuthProtocols { + if !slices.Contains(u.Filters.TOTPConfig.Protocols, p) { + return true + } + } + } + return false +} + +// MustSetSecondFactorForProtocol returns true if the user must set a second factor authentication +// for the specified protocol +func (u *User) MustSetSecondFactorForProtocol(protocol string) bool { + if slices.Contains(u.Filters.TwoFactorAuthProtocols, protocol) { + if !u.Filters.TOTPConfig.Enabled { + return true + } + if !slices.Contains(u.Filters.TOTPConfig.Protocols, protocol) { + return true + } + } + return false +} + +// GetSignature returns a signature for this user. 
+// It will change after an update +func (u *User) GetSignature() string { + return strconv.FormatInt(u.UpdatedAt, 10) +} + +// GetBandwidthForIP returns the upload and download bandwidth for the specified IP +func (u *User) GetBandwidthForIP(clientIP, connectionID string) (int64, int64) { + if len(u.Filters.BandwidthLimits) > 0 { + ip := net.ParseIP(clientIP) + if ip != nil { + for _, bwLimit := range u.Filters.BandwidthLimits { + for _, source := range bwLimit.Sources { + _, ipNet, err := net.ParseCIDR(source) + if err == nil { + if ipNet.Contains(ip) { + logger.Debug(logSender, connectionID, "override bandwidth limit for ip %q, upload limit: %v KB/s, download limit: %v KB/s", + clientIP, bwLimit.UploadBandwidth, bwLimit.DownloadBandwidth) + return bwLimit.UploadBandwidth, bwLimit.DownloadBandwidth + } + } + } + } + } + } + return u.UploadBandwidth, u.DownloadBandwidth +} + +// IsLoginFromAddrAllowed returns true if the login is allowed from the specified remoteAddr. +// If AllowedIP is defined only the specified IP/Mask can login. +// If DeniedIP is defined the specified IP/Mask cannot login. +// If an IP is both allowed and denied then login will be allowed +func (u *User) IsLoginFromAddrAllowed(remoteAddr string) bool { + if len(u.Filters.AllowedIP) == 0 && len(u.Filters.DeniedIP) == 0 { + return true + } + remoteIP := net.ParseIP(util.GetIPFromRemoteAddress(remoteAddr)) + // if remoteIP is invalid we allow login, this should never happen + if remoteIP == nil { + logger.Warn(logSender, "", "login allowed for invalid IP. 
remote address: %q", remoteAddr) + return true + } + for _, IPMask := range u.Filters.AllowedIP { + _, IPNet, err := net.ParseCIDR(IPMask) + if err != nil { + return false + } + if IPNet.Contains(remoteIP) { + return true + } + } + for _, IPMask := range u.Filters.DeniedIP { + _, IPNet, err := net.ParseCIDR(IPMask) + if err != nil { + return false + } + if IPNet.Contains(remoteIP) { + return false + } + } + return len(u.Filters.AllowedIP) == 0 +} + +// GetPermissionsAsJSON returns the permissions as json byte array +func (u *User) GetPermissionsAsJSON() ([]byte, error) { + return json.Marshal(u.Permissions) +} + +// GetPublicKeysAsJSON returns the public keys as json byte array +func (u *User) GetPublicKeysAsJSON() ([]byte, error) { + return json.Marshal(u.PublicKeys) +} + +// GetFiltersAsJSON returns the filters as json byte array +func (u *User) GetFiltersAsJSON() ([]byte, error) { + return json.Marshal(u.Filters) +} + +// GetFsConfigAsJSON returns the filesystem config as json byte array +func (u *User) GetFsConfigAsJSON() ([]byte, error) { + return json.Marshal(u.FsConfig) +} + +// GetUID returns a validate uid, suitable for use with os.Chown +func (u *User) GetUID() int { + if u.UID <= 0 || u.UID > math.MaxInt32 { + return -1 + } + return u.UID +} + +// GetGID returns a validate gid, suitable for use with os.Chown +func (u *User) GetGID() int { + if u.GID <= 0 || u.GID > math.MaxInt32 { + return -1 + } + return u.GID +} + +// GetHomeDir returns the shortest path name equivalent to the user's home directory +func (u *User) GetHomeDir() string { + return u.HomeDir +} + +// HasRecentActivity returns true if the last user login is recent and so we can skip some expensive checks +func (u *User) HasRecentActivity() bool { + return isLastActivityRecent(u.LastLogin, lastLoginMinDelay) +} + +// HasQuotaRestrictions returns true if there are any disk quota restrictions +func (u *User) HasQuotaRestrictions() bool { + return u.QuotaFiles > 0 || u.QuotaSize > 0 +} + +// 
HasTransferQuotaRestrictions returns true if there are any data transfer restrictions +func (u *User) HasTransferQuotaRestrictions() bool { + return u.UploadDataTransfer > 0 || u.TotalDataTransfer > 0 || u.DownloadDataTransfer > 0 +} + +// GetDataTransferLimits returns upload, download and total data transfer limits +func (u *User) GetDataTransferLimits() (int64, int64, int64) { + var total, ul, dl int64 + if u.TotalDataTransfer > 0 { + total = u.TotalDataTransfer * 1048576 + } + if u.DownloadDataTransfer > 0 { + dl = u.DownloadDataTransfer * 1048576 + } + if u.UploadDataTransfer > 0 { + ul = u.UploadDataTransfer * 1048576 + } + return ul, dl, total +} + +// GetAllowedIPAsString returns the allowed IP as comma separated string +func (u *User) GetAllowedIPAsString() string { + return strings.Join(u.Filters.AllowedIP, ",") +} + +// GetDeniedIPAsString returns the denied IP as comma separated string +func (u *User) GetDeniedIPAsString() string { + return strings.Join(u.Filters.DeniedIP, ",") +} + +// HasExternalAuth returns true if the external authentication is globally enabled +// and it is not disabled for this user +func (u *User) HasExternalAuth() bool { + if u.Filters.Hooks.ExternalAuthDisabled { + return false + } + if config.ExternalAuthHook != "" { + return true + } + return plugin.Handler.HasAuthenticators() +} + +// CountUnusedRecoveryCodes returns the number of unused recovery codes +func (u *User) CountUnusedRecoveryCodes() int { + unused := 0 + for _, code := range u.Filters.RecoveryCodes { + if !code.Used { + unused++ + } + } + return unused +} + +// SetEmptySecretsIfNil sets the secrets to empty if nil +func (u *User) SetEmptySecretsIfNil() { + u.HasPassword = u.Password != "" + u.FsConfig.SetEmptySecretsIfNil() + for idx := range u.VirtualFolders { + vfolder := &u.VirtualFolders[idx] + vfolder.FsConfig.SetEmptySecretsIfNil() + } + if u.Filters.TOTPConfig.Secret == nil { + u.Filters.TOTPConfig.Secret = kms.NewEmptySecret() + } +} + +func (u *User) 
hasMainDataTransferLimits() bool { + return u.UploadDataTransfer > 0 || u.DownloadDataTransfer > 0 || u.TotalDataTransfer > 0 +} + +// HasPrimaryGroup returns true if the user has the specified primary group +func (u *User) HasPrimaryGroup(name string) bool { + for _, g := range u.Groups { + if g.Name == name { + return g.Type == sdk.GroupTypePrimary + } + } + return false +} + +// HasSecondaryGroup returns true if the user has the specified secondary group +func (u *User) HasSecondaryGroup(name string) bool { + for _, g := range u.Groups { + if g.Name == name { + return g.Type == sdk.GroupTypeSecondary + } + } + return false +} + +// HasMembershipGroup returns true if the user has the specified membership group +func (u *User) HasMembershipGroup(name string) bool { + for _, g := range u.Groups { + if g.Name == name { + return g.Type == sdk.GroupTypeMembership + } + } + return false +} + +func (u *User) hasSettingsFromGroups() bool { + for _, g := range u.Groups { + if g.Type != sdk.GroupTypeMembership { + return true + } + } + return false +} + +func (u *User) applyGroupSettings(groupsMapping map[string]Group) { + if !u.hasSettingsFromGroups() { + return + } + if u.groupSettingsApplied { + return + } + replacer := u.getGroupPlacehodersReplacer() + for _, g := range u.Groups { + if g.Type == sdk.GroupTypePrimary { + if group, ok := groupsMapping[g.Name]; ok { + u.mergeWithPrimaryGroup(&group, replacer) + } else { + providerLog(logger.LevelError, "mapping not found for user %s, group %s", u.Username, g.Name) + } + break + } + } + for _, g := range u.Groups { + if g.Type == sdk.GroupTypeSecondary { + if group, ok := groupsMapping[g.Name]; ok { + u.mergeAdditiveProperties(&group, sdk.GroupTypeSecondary, replacer) + } else { + providerLog(logger.LevelError, "mapping not found for user %s, group %s", u.Username, g.Name) + } + } + } + u.removeDuplicatesAfterGroupMerge() +} + +// LoadAndApplyGroupSettings update the user by loading and applying the group settings +func (u 
*User) LoadAndApplyGroupSettings() error { + if !u.hasSettingsFromGroups() { + return nil + } + if u.groupSettingsApplied { + return nil + } + names := make([]string, 0, len(u.Groups)) + var primaryGroupName string + for _, g := range u.Groups { + if g.Type == sdk.GroupTypePrimary { + primaryGroupName = g.Name + } + if g.Type != sdk.GroupTypeMembership { + names = append(names, g.Name) + } + } + groups, err := provider.getGroupsWithNames(names) + if err != nil { + return fmt.Errorf("unable to get groups: %w", err) + } + replacer := u.getGroupPlacehodersReplacer() + // make sure to always merge with the primary group first + for idx := range groups { + g := groups[idx] + if g.Name == primaryGroupName { + u.mergeWithPrimaryGroup(&g, replacer) + lastIdx := len(groups) - 1 + groups[idx] = groups[lastIdx] + groups = groups[:lastIdx] + break + } + } + for idx := range groups { + g := groups[idx] + u.mergeAdditiveProperties(&g, sdk.GroupTypeSecondary, replacer) + } + u.removeDuplicatesAfterGroupMerge() + return nil +} + +func (u *User) getGroupPlacehodersReplacer() *strings.Replacer { + return strings.NewReplacer("%username%", u.Username, "%role%", u.Role) +} + +func (u *User) replacePlaceholder(value string, replacer *strings.Replacer) string { + if value == "" { + return value + } + return replacer.Replace(value) +} + +func (u *User) replaceFsConfigPlaceholders(fsConfig vfs.Filesystem, replacer *strings.Replacer) vfs.Filesystem { + switch fsConfig.Provider { + case sdk.S3FilesystemProvider: + fsConfig.S3Config.KeyPrefix = u.replacePlaceholder(fsConfig.S3Config.KeyPrefix, replacer) + case sdk.GCSFilesystemProvider: + fsConfig.GCSConfig.KeyPrefix = u.replacePlaceholder(fsConfig.GCSConfig.KeyPrefix, replacer) + case sdk.AzureBlobFilesystemProvider: + fsConfig.AzBlobConfig.KeyPrefix = u.replacePlaceholder(fsConfig.AzBlobConfig.KeyPrefix, replacer) + case sdk.SFTPFilesystemProvider: + fsConfig.SFTPConfig.Username = u.replacePlaceholder(fsConfig.SFTPConfig.Username, replacer) 
+ fsConfig.SFTPConfig.Prefix = u.replacePlaceholder(fsConfig.SFTPConfig.Prefix, replacer) + case sdk.HTTPFilesystemProvider: + fsConfig.HTTPConfig.Username = u.replacePlaceholder(fsConfig.HTTPConfig.Username, replacer) + } + return fsConfig +} + +func (u *User) mergeCryptFsConfig(group *Group) { + if group.UserSettings.FsConfig.Provider == sdk.CryptedFilesystemProvider { + if u.FsConfig.CryptConfig.ReadBufferSize == 0 { + u.FsConfig.CryptConfig.ReadBufferSize = group.UserSettings.FsConfig.CryptConfig.ReadBufferSize + } + if u.FsConfig.CryptConfig.WriteBufferSize == 0 { + u.FsConfig.CryptConfig.WriteBufferSize = group.UserSettings.FsConfig.CryptConfig.WriteBufferSize + } + } +} + +func (u *User) mergeWithPrimaryGroup(group *Group, replacer *strings.Replacer) { + if group.UserSettings.HomeDir != "" { + u.HomeDir = filepath.Clean(u.replacePlaceholder(group.UserSettings.HomeDir, replacer)) + } + if group.UserSettings.FsConfig.Provider != 0 { + u.FsConfig = u.replaceFsConfigPlaceholders(group.UserSettings.FsConfig, replacer) + u.mergeCryptFsConfig(group) + } else { + if u.FsConfig.OSConfig.ReadBufferSize == 0 { + u.FsConfig.OSConfig.ReadBufferSize = group.UserSettings.FsConfig.OSConfig.ReadBufferSize + } + if u.FsConfig.OSConfig.WriteBufferSize == 0 { + u.FsConfig.OSConfig.WriteBufferSize = group.UserSettings.FsConfig.OSConfig.WriteBufferSize + } + } + if u.MaxSessions == 0 { + u.MaxSessions = group.UserSettings.MaxSessions + } + if u.QuotaSize == 0 { + u.QuotaSize = group.UserSettings.QuotaSize + } + if u.QuotaFiles == 0 { + u.QuotaFiles = group.UserSettings.QuotaFiles + } + if u.UploadBandwidth == 0 { + u.UploadBandwidth = group.UserSettings.UploadBandwidth + } + if u.DownloadBandwidth == 0 { + u.DownloadBandwidth = group.UserSettings.DownloadBandwidth + } + if !u.hasMainDataTransferLimits() { + u.UploadDataTransfer = group.UserSettings.UploadDataTransfer + u.DownloadDataTransfer = group.UserSettings.DownloadDataTransfer + u.TotalDataTransfer = 
group.UserSettings.TotalDataTransfer + } + if u.ExpirationDate == 0 && group.UserSettings.ExpiresIn > 0 { + u.ExpirationDate = u.CreatedAt + int64(group.UserSettings.ExpiresIn)*86400000 + } + u.mergePrimaryGroupFilters(&group.UserSettings.Filters, replacer) + u.mergeAdditiveProperties(group, sdk.GroupTypePrimary, replacer) +} + +func (u *User) mergePrimaryGroupFilters(filters *sdk.BaseUserFilters, replacer *strings.Replacer) { //nolint:gocyclo + if u.Filters.MaxUploadFileSize == 0 { + u.Filters.MaxUploadFileSize = filters.MaxUploadFileSize + } + if !u.IsTLSVerificationEnabled() { + u.Filters.TLSUsername = filters.TLSUsername + } + if !u.Filters.Hooks.CheckPasswordDisabled { + u.Filters.Hooks.CheckPasswordDisabled = filters.Hooks.CheckPasswordDisabled + } + if !u.Filters.Hooks.PreLoginDisabled { + u.Filters.Hooks.PreLoginDisabled = filters.Hooks.PreLoginDisabled + } + if !u.Filters.Hooks.ExternalAuthDisabled { + u.Filters.Hooks.ExternalAuthDisabled = filters.Hooks.ExternalAuthDisabled + } + if !u.Filters.DisableFsChecks { + u.Filters.DisableFsChecks = filters.DisableFsChecks + } + if !u.Filters.AllowAPIKeyAuth { + u.Filters.AllowAPIKeyAuth = filters.AllowAPIKeyAuth + } + if !u.Filters.IsAnonymous { + u.Filters.IsAnonymous = filters.IsAnonymous + } + if u.Filters.ExternalAuthCacheTime == 0 { + u.Filters.ExternalAuthCacheTime = filters.ExternalAuthCacheTime + } + if u.Filters.FTPSecurity == 0 { + u.Filters.FTPSecurity = filters.FTPSecurity + } + if u.Filters.StartDirectory == "" { + u.Filters.StartDirectory = u.replacePlaceholder(filters.StartDirectory, replacer) + } + if u.Filters.DefaultSharesExpiration == 0 { + u.Filters.DefaultSharesExpiration = filters.DefaultSharesExpiration + } + if u.Filters.MaxSharesExpiration == 0 { + u.Filters.MaxSharesExpiration = filters.MaxSharesExpiration + } + if u.Filters.PasswordExpiration == 0 { + u.Filters.PasswordExpiration = filters.PasswordExpiration + } + if u.Filters.PasswordStrength == 0 { + u.Filters.PasswordStrength = 
filters.PasswordStrength + } +} + +func (u *User) mergeAdditiveProperties(group *Group, groupType int, replacer *strings.Replacer) { + u.mergeVirtualFolders(group, groupType, replacer) + u.mergePermissions(group, groupType, replacer) + u.mergeFilePatterns(group, groupType, replacer) + u.Filters.BandwidthLimits = append(u.Filters.BandwidthLimits, group.UserSettings.Filters.BandwidthLimits...) + u.Filters.AllowedIP = append(u.Filters.AllowedIP, group.UserSettings.Filters.AllowedIP...) + u.Filters.DeniedIP = append(u.Filters.DeniedIP, group.UserSettings.Filters.DeniedIP...) + u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, group.UserSettings.Filters.DeniedLoginMethods...) + u.Filters.DeniedProtocols = append(u.Filters.DeniedProtocols, group.UserSettings.Filters.DeniedProtocols...) + u.Filters.WebClient = append(u.Filters.WebClient, group.UserSettings.Filters.WebClient...) + u.Filters.TwoFactorAuthProtocols = append(u.Filters.TwoFactorAuthProtocols, group.UserSettings.Filters.TwoFactorAuthProtocols...) + u.Filters.AccessTime = append(u.Filters.AccessTime, group.UserSettings.Filters.AccessTime...) 
+} + +func (u *User) mergeVirtualFolders(group *Group, groupType int, replacer *strings.Replacer) { + if len(group.VirtualFolders) > 0 { + folderPaths := make(map[string]bool) + for _, folder := range u.VirtualFolders { + folderPaths[folder.VirtualPath] = true + } + for _, folder := range group.VirtualFolders { + if folder.VirtualPath == "/" && groupType != sdk.GroupTypePrimary { + continue + } + folder.VirtualPath = u.replacePlaceholder(folder.VirtualPath, replacer) + if _, ok := folderPaths[folder.VirtualPath]; !ok { + folder.MappedPath = u.replacePlaceholder(folder.MappedPath, replacer) + folder.FsConfig = u.replaceFsConfigPlaceholders(folder.FsConfig, replacer) + u.VirtualFolders = append(u.VirtualFolders, folder) + } + } + } +} + +func (u *User) mergePermissions(group *Group, groupType int, replacer *strings.Replacer) { + if u.Permissions == nil { + u.Permissions = make(map[string][]string) + } + for k, v := range group.UserSettings.Permissions { + if k == "/" { + if groupType == sdk.GroupTypePrimary { + u.Permissions[k] = v + } else { + continue + } + } + k = u.replacePlaceholder(k, replacer) + if _, ok := u.Permissions[k]; !ok { + u.Permissions[k] = v + } + } +} + +func (u *User) mergeFilePatterns(group *Group, groupType int, replacer *strings.Replacer) { + if len(group.UserSettings.Filters.FilePatterns) > 0 { + patternPaths := make(map[string]bool) + for _, pattern := range u.Filters.FilePatterns { + patternPaths[pattern.Path] = true + } + for _, pattern := range group.UserSettings.Filters.FilePatterns { + if pattern.Path == "/" && groupType != sdk.GroupTypePrimary { + continue + } + pattern.Path = u.replacePlaceholder(pattern.Path, replacer) + if _, ok := patternPaths[pattern.Path]; !ok { + u.Filters.FilePatterns = append(u.Filters.FilePatterns, pattern) + } + } + } +} + +func (u *User) removeDuplicatesAfterGroupMerge() { + u.Filters.AllowedIP = util.RemoveDuplicates(u.Filters.AllowedIP, false) + u.Filters.DeniedIP = 
util.RemoveDuplicates(u.Filters.DeniedIP, false) + u.Filters.DeniedLoginMethods = util.RemoveDuplicates(u.Filters.DeniedLoginMethods, false) + u.Filters.DeniedProtocols = util.RemoveDuplicates(u.Filters.DeniedProtocols, false) + u.Filters.WebClient = util.RemoveDuplicates(u.Filters.WebClient, false) + u.Filters.TwoFactorAuthProtocols = util.RemoveDuplicates(u.Filters.TwoFactorAuthProtocols, false) + u.SetEmptySecretsIfNil() + u.groupSettingsApplied = true +} + +func (u *User) hasRole(role string) bool { + if role == "" { + return true + } + return role == u.Role +} + +func (u *User) applyNamingRules() { + u.Username = config.convertName(u.Username) + u.Role = config.convertName(u.Role) + for idx := range u.Groups { + u.Groups[idx].Name = config.convertName(u.Groups[idx].Name) + } + for idx := range u.VirtualFolders { + u.VirtualFolders[idx].Name = config.convertName(u.VirtualFolders[idx].Name) + } +} + +func (u *User) getACopy() User { + u.SetEmptySecretsIfNil() + pubKeys := make([]string, len(u.PublicKeys)) + copy(pubKeys, u.PublicKeys) + virtualFolders := make([]vfs.VirtualFolder, 0, len(u.VirtualFolders)) + for idx := range u.VirtualFolders { + vfolder := u.VirtualFolders[idx].GetACopy() + virtualFolders = append(virtualFolders, vfolder) + } + groups := make([]sdk.GroupMapping, 0, len(u.Groups)) + for _, g := range u.Groups { + groups = append(groups, sdk.GroupMapping{ + Name: g.Name, + Type: g.Type, + }) + } + permissions := make(map[string][]string) + for k, v := range u.Permissions { + perms := make([]string, len(v)) + copy(perms, v) + permissions[k] = perms + } + filters := UserFilters{ + BaseUserFilters: copyBaseUserFilters(u.Filters.BaseUserFilters), + } + filters.RequirePasswordChange = u.Filters.RequirePasswordChange + filters.TOTPConfig.Enabled = u.Filters.TOTPConfig.Enabled + filters.TOTPConfig.ConfigName = u.Filters.TOTPConfig.ConfigName + filters.TOTPConfig.Secret = u.Filters.TOTPConfig.Secret.Clone() + filters.TOTPConfig.Protocols = make([]string, 
len(u.Filters.TOTPConfig.Protocols)) + copy(filters.TOTPConfig.Protocols, u.Filters.TOTPConfig.Protocols) + filters.AdditionalEmails = make([]string, len(u.Filters.AdditionalEmails)) + copy(filters.AdditionalEmails, u.Filters.AdditionalEmails) + filters.RecoveryCodes = make([]RecoveryCode, 0, len(u.Filters.RecoveryCodes)) + for _, code := range u.Filters.RecoveryCodes { + if code.Secret == nil { + code.Secret = kms.NewEmptySecret() + } + filters.RecoveryCodes = append(filters.RecoveryCodes, RecoveryCode{ + Secret: code.Secret.Clone(), + Used: code.Used, + }) + } + + return User{ + BaseUser: sdk.BaseUser{ + ID: u.ID, + Username: u.Username, + Email: u.Email, + Password: u.Password, + PublicKeys: pubKeys, + HasPassword: u.HasPassword, + HomeDir: u.HomeDir, + UID: u.UID, + GID: u.GID, + MaxSessions: u.MaxSessions, + QuotaSize: u.QuotaSize, + QuotaFiles: u.QuotaFiles, + Permissions: permissions, + UsedQuotaSize: u.UsedQuotaSize, + UsedQuotaFiles: u.UsedQuotaFiles, + LastQuotaUpdate: u.LastQuotaUpdate, + UploadBandwidth: u.UploadBandwidth, + DownloadBandwidth: u.DownloadBandwidth, + UploadDataTransfer: u.UploadDataTransfer, + DownloadDataTransfer: u.DownloadDataTransfer, + TotalDataTransfer: u.TotalDataTransfer, + UsedUploadDataTransfer: u.UsedUploadDataTransfer, + UsedDownloadDataTransfer: u.UsedDownloadDataTransfer, + Status: u.Status, + ExpirationDate: u.ExpirationDate, + LastLogin: u.LastLogin, + FirstDownload: u.FirstDownload, + FirstUpload: u.FirstUpload, + LastPasswordChange: u.LastPasswordChange, + AdditionalInfo: u.AdditionalInfo, + Description: u.Description, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, + Role: u.Role, + }, + Filters: filters, + VirtualFolders: virtualFolders, + Groups: groups, + FsConfig: u.FsConfig.GetACopy(), + groupSettingsApplied: u.groupSettingsApplied, + } +} + +// GetEncryptionAdditionalData returns the additional data to use for AEAD +func (u *User) GetEncryptionAdditionalData() string { + return u.Username +} diff --git 
a/internal/ftpd/cryptfs_test.go b/internal/ftpd/cryptfs_test.go new file mode 100644 index 00000000..23567fff --- /dev/null +++ b/internal/ftpd/cryptfs_test.go @@ -0,0 +1,341 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package ftpd_test + +import ( + "crypto/sha256" + "fmt" + "hash" + "io" + "net/http" + "os" + "path" + "path/filepath" + "testing" + "time" + + "github.com/minio/sio" + "github.com/sftpgo/sdk" + "github.com/stretchr/testify/assert" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpdtest" + "github.com/drakkan/sftpgo/v2/internal/kms" +) + +func TestBasicFTPHandlingCryptFs(t *testing.T) { + u := getTestUserWithCryptFs() + u.QuotaSize = 6553600 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + assert.Len(t, common.Connections.GetStats(""), 1) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + encryptedFileSize, err := getEncryptedFileSize(testFileSize) + assert.NoError(t, err) + expectedQuotaSize := encryptedFileSize + expectedQuotaFiles := 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, 
path.Join("/missing_dir", testFileName), testFileSize, client, 0) + assert.Error(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + // overwrite an existing file + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + info, err := os.Stat(localDownloadPath) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, info.Size()) + } + list, err := client.List(".") + if assert.NoError(t, err) { + if assert.Len(t, list, 1) { + assert.Equal(t, testFileSize, int64(list[0].Size)) + } + } + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + err = client.Rename(testFileName, testFileName+"1") + assert.NoError(t, err) + err = client.Delete(testFileName) + assert.Error(t, err) + err = client.Delete(testFileName + "1") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize-encryptedFileSize, user.UsedQuotaSize) + curDir, err := client.CurrentDir() + if assert.NoError(t, err) { + assert.Equal(t, "/", curDir) + } + testDir := "testDir" + err = client.MakeDir(testDir) + assert.NoError(t, err) + err = client.ChangeDir(testDir) + assert.NoError(t, err) + curDir, err = client.CurrentDir() + if assert.NoError(t, err) { + assert.Equal(t, path.Join("/", testDir), curDir) + } + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + size, err := client.FileSize(path.Join("/", testDir, testFileName)) + assert.NoError(t, err) + 
assert.Equal(t, testFileSize, size) + err = client.ChangeDirToParent() + assert.NoError(t, err) + curDir, err = client.CurrentDir() + if assert.NoError(t, err) { + assert.Equal(t, "/", curDir) + } + err = client.Delete(path.Join("/", testDir, testFileName)) + assert.NoError(t, err) + err = client.Delete(testDir) + assert.Error(t, err) + err = client.RemoveDir(testDir) + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) + assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, + 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestBufferedCryptFs(t *testing.T) { + u := getTestUserWithCryptFs() + u.FsConfig.CryptConfig.OSFsConfig = sdk.OSFsConfig{ + ReadBufferSize: 1, + WriteBufferSize: 1, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + // overwrite an existing file + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + 
assert.NoError(t, err) + info, err := os.Stat(localDownloadPath) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, info.Size()) + } + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) + assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, + 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestZeroBytesTransfersCryptFs(t *testing.T) { + u := getTestUserWithCryptFs() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + testFileName := "testfilename" + err = checkBasicFTP(client) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, "emptydownload") + err = os.WriteFile(localDownloadPath, []byte(""), os.ModePerm) + assert.NoError(t, err) + err = ftpUploadFile(localDownloadPath, testFileName, 0, client, 0) + assert.NoError(t, err) + size, err := client.FileSize(testFileName) + assert.NoError(t, err) + assert.Equal(t, int64(0), size) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + assert.NoFileExists(t, localDownloadPath) + err = ftpDownloadFile(testFileName, localDownloadPath, 0, client, 0) + assert.NoError(t, err) + info, err := os.Stat(localDownloadPath) + if assert.NoError(t, err) { + assert.Equal(t, int64(0), info.Size()) + } + err = client.Quit() + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err 
= os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestResumeCryptFs(t *testing.T) { + u := getTestUserWithCryptFs() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + data := []byte("test data") + err = os.WriteFile(testFilePath, data, os.ModePerm) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0) + assert.NoError(t, err) + // resuming uploads is not supported + err = ftpUploadFile(testFilePath, testFileName, int64(len(data)+5), client, 5) + assert.Error(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, int64(4), client, 5) + assert.NoError(t, err) + readed, err := os.ReadFile(localDownloadPath) + assert.NoError(t, err) + assert.Equal(t, data[5:], readed) + err = ftpDownloadFile(testFileName, localDownloadPath, int64(8), client, 1) + assert.NoError(t, err) + readed, err = os.ReadFile(localDownloadPath) + assert.NoError(t, err) + assert.Equal(t, data[1:], readed) + err = ftpDownloadFile(testFileName, localDownloadPath, int64(0), client, 9) + assert.NoError(t, err) + err = client.Delete(testFileName) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0) + assert.NoError(t, err) + // now append to a file + srcFile, err := os.Open(testFilePath) + if assert.NoError(t, err) { + err = client.Append(testFileName, srcFile) + assert.Error(t, err) + err = srcFile.Close() + assert.NoError(t, err) + size, err := client.FileSize(testFileName) + assert.NoError(t, err) + assert.Equal(t, int64(len(data)), size) + err = ftpDownloadFile(testFileName, localDownloadPath, int64(len(data)), client, 0) + assert.NoError(t, err) + readed, err = os.ReadFile(localDownloadPath) + assert.NoError(t, err) + assert.Equal(t, 
data, readed) + } + // now test a download resume using a bigger file + testFileSize := int64(655352) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + initialHash, err := computeHashForFile(sha256.New(), testFilePath) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + downloadHash, err := computeHashForFile(sha256.New(), localDownloadPath) + assert.NoError(t, err) + assert.Equal(t, initialHash, downloadHash) + err = os.Truncate(localDownloadPath, 32767) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath+"_partial", testFileSize-32767, client, 32767) //nolint:goconst + assert.NoError(t, err) + file, err := os.OpenFile(localDownloadPath, os.O_APPEND|os.O_WRONLY, os.ModePerm) + assert.NoError(t, err) + file1, err := os.Open(localDownloadPath + "_partial") //nolint:goconst + assert.NoError(t, err) + _, err = io.Copy(file, file1) + assert.NoError(t, err) + err = file.Close() + assert.NoError(t, err) + err = file1.Close() + assert.NoError(t, err) + downloadHash, err = computeHashForFile(sha256.New(), localDownloadPath) + assert.NoError(t, err) + assert.Equal(t, initialHash, downloadHash) + + err = client.Quit() + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath + "_partial") + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func getTestUserWithCryptFs() dataprovider.User { + user := getTestUser() + user.FsConfig.Provider = sdk.CryptedFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("testPassphrase") + return user +} + +func getEncryptedFileSize(size 
int64) (int64, error) { + encSize, err := sio.EncryptedSize(uint64(size)) + return int64(encSize) + 33, err +} + +func computeHashForFile(hasher hash.Hash, path string) (string, error) { + hash := "" + f, err := os.Open(path) + if err != nil { + return hash, err + } + defer f.Close() + _, err = io.Copy(hasher, f) + if err == nil { + hash = fmt.Sprintf("%x", hasher.Sum(nil)) + } + return hash, err +} diff --git a/internal/ftpd/ftpd.go b/internal/ftpd/ftpd.go new file mode 100644 index 00000000..9fdfacf0 --- /dev/null +++ b/internal/ftpd/ftpd.go @@ -0,0 +1,459 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +// Package ftpd implements the FTP protocol +package ftpd + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "os" + "path/filepath" + "strings" + "time" + + ftpserver "github.com/fclairamb/ftpserverlib" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + logSender = "ftpd" +) + +var ( + certMgr *common.CertManager + serviceStatus ServiceStatus +) + +// PassiveIPOverride defines an exception for the configured passive IP +type PassiveIPOverride struct { + Networks []string `json:"networks" mapstructure:"networks"` + // if empty the local address will be returned + IP string `json:"ip" mapstructure:"ip"` + parsedNetworks []func(net.IP) bool +} + +// GetNetworksAsString returns the configured networks as string +func (p *PassiveIPOverride) GetNetworksAsString() string { + return strings.Join(p.Networks, ", ") +} + +// Binding defines the configuration for a network listener +type Binding struct { + // The address to listen on. A blank value means listen on all available network interfaces. + Address string `json:"address" mapstructure:"address"` + // The port used for serving requests + Port int `json:"port" mapstructure:"port"` + // Apply the proxy configuration, if any, for this binding + ApplyProxyConfig bool `json:"apply_proxy_config" mapstructure:"apply_proxy_config"` + // Set to 1 to require TLS for both data and control connection. + // Set to 2 to enable implicit TLS + TLSMode int `json:"tls_mode" mapstructure:"tls_mode"` + // Certificate and matching private key for this specific binding, if empty the global + // ones will be used, if any + CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` + CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` + // Defines the minimum TLS version. 
13 means TLS 1.3, default is TLS 1.2 + MinTLSVersion int `json:"min_tls_version" mapstructure:"min_tls_version"` + // External IP address for passive connections. + ForcePassiveIP string `json:"force_passive_ip" mapstructure:"force_passive_ip"` + // PassiveIPOverrides allows to define different IP addresses for passive connections + // based on the client IP address + PassiveIPOverrides []PassiveIPOverride `json:"passive_ip_overrides" mapstructure:"passive_ip_overrides"` + // Hostname for passive connections. This hostname will be resolved each time a passive + // connection is requested and this can, depending on the DNS configuration, take a noticeable + // amount of time. Enable this setting only if you have a dynamic IP address + PassiveHost string `json:"passive_host" mapstructure:"passive_host"` + // Set to 1 to require client certificate authentication. + // Set to 2 to require a client certificate and verify it if given. In this mode + // the client is allowed not to send a certificate. + // You need to define at least a certificate authority for this to work + ClientAuthType int `json:"client_auth_type" mapstructure:"client_auth_type"` + // TLSCipherSuites is a list of supported cipher suites for TLS version 1.2. + // If CipherSuites is nil/empty, a default list of secure cipher suites + // is used, with a preference order based on hardware performance. + // Note that TLS 1.3 ciphersuites are not configurable. + // The supported ciphersuites names are defined here: + // + // https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53 + // + // any invalid name will be silently ignored. + // The order matters, the ciphers listed first will be the preferred ones. + TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"` + // PassiveConnectionsSecurity defines the security checks for passive data connections. + // Supported values: + // - 0 require matching peer IP addresses of control and data connection.
This is the default + // - 1 disable any checks + PassiveConnectionsSecurity int `json:"passive_connections_security" mapstructure:"passive_connections_security"` + // ActiveConnectionsSecurity defines the security checks for active data connections. + // The supported values are the same as described for PassiveConnectionsSecurity. + // Please note that disabling the security checks you will make the FTP service vulnerable to bounce attacks + // on active data connections, so change the default value only if you are on a trusted/internal network + ActiveConnectionsSecurity int `json:"active_connections_security" mapstructure:"active_connections_security"` + // Debug enables the FTP debug mode. In debug mode, every FTP command will be logged + Debug bool `json:"debug" mapstructure:"debug"` + ciphers []uint16 +} + +func (b *Binding) setCiphers() { + b.ciphers = util.GetTLSCiphersFromNames(b.TLSCipherSuites) +} + +func (b *Binding) isMutualTLSEnabled() bool { + return b.ClientAuthType == 1 || b.ClientAuthType == 2 +} + +// GetAddress returns the binding address +func (b *Binding) GetAddress() string { + return fmt.Sprintf("%s:%d", b.Address, b.Port) +} + +// IsValid returns true if the binding port is > 0 +func (b *Binding) IsValid() bool { + return b.Port > 0 +} + +func (b *Binding) isTLSModeValid() bool { + return b.TLSMode >= 0 && b.TLSMode <= 2 +} + +func (b *Binding) checkSecuritySettings() error { + if b.PassiveConnectionsSecurity < 0 || b.PassiveConnectionsSecurity > 1 { + return fmt.Errorf("invalid passive_connections_security: %v", b.PassiveConnectionsSecurity) + } + if b.ActiveConnectionsSecurity < 0 || b.ActiveConnectionsSecurity > 1 { + return fmt.Errorf("invalid active_connections_security: %v", b.ActiveConnectionsSecurity) + } + return nil +} + +func (b *Binding) checkPassiveIP() error { + if b.ForcePassiveIP != "" { + ip, err := parsePassiveIP(b.ForcePassiveIP) + if err != nil { + return err + } + b.ForcePassiveIP = ip + } + for idx, passiveOverride := 
range b.PassiveIPOverrides { + var ip string + + if passiveOverride.IP != "" { + var err error + ip, err = parsePassiveIP(passiveOverride.IP) + if err != nil { + return err + } + } + if len(passiveOverride.Networks) == 0 { + return errors.New("passive IP networks override cannot be empty") + } + checkFuncs, err := util.ParseAllowedIPAndRanges(passiveOverride.Networks) + if err != nil { + return fmt.Errorf("invalid passive IP networks override %+v: %w", passiveOverride.Networks, err) + } + b.PassiveIPOverrides[idx].IP = ip + b.PassiveIPOverrides[idx].parsedNetworks = checkFuncs + } + return nil +} + +func (b *Binding) getPassiveIP(cc ftpserver.ClientContext) (string, error) { + if b.ForcePassiveIP != "" { + return b.ForcePassiveIP, nil + } + if b.PassiveHost != "" { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + addrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", b.PassiveHost) + if err != nil { + logger.Error(logSender, "", "unable to resolve hostname %q: %v", b.PassiveHost, err) + return "", fmt.Errorf("unable to resolve hostname %q: %w", b.PassiveHost, err) + } + if len(addrs) > 0 { + return addrs[0].String(), nil + } + } + return strings.Split(cc.LocalAddr().String(), ":")[0], nil +} + +func (b *Binding) passiveIPResolver(cc ftpserver.ClientContext) (string, error) { + if len(b.PassiveIPOverrides) > 0 { + clientIP := net.ParseIP(util.GetIPFromRemoteAddress(cc.RemoteAddr().String())) + if clientIP != nil { + for _, override := range b.PassiveIPOverrides { + for _, fn := range override.parsedNetworks { + if fn(clientIP) { + if override.IP == "" { + return strings.Split(cc.LocalAddr().String(), ":")[0], nil + } + return override.IP, nil + } + } + } + } + } + return b.getPassiveIP(cc) +} + +// HasProxy returns true if the proxy protocol is active for this binding +func (b *Binding) HasProxy() bool { + return b.ApplyProxyConfig && common.Config.ProxyProtocol > 0 +} + +// GetTLSDescription returns the TLS mode as 
string +func (b *Binding) GetTLSDescription() string { + if certMgr == nil { + return util.I18nFTPTLSDisabled + } + switch b.TLSMode { + case 1: + return util.I18nFTPTLSExplicit + case 2: + return util.I18nFTPTLSImplicit + } + + if certMgr.HasCertificate(common.DefaultTLSKeyPaidID) || certMgr.HasCertificate(b.GetAddress()) { + return util.I18nFTPTLSMixed + } + return util.I18nFTPTLSDisabled +} + +// PortRange defines a port range +type PortRange struct { + // Range start + Start int `json:"start" mapstructure:"start"` + // Range end + End int `json:"end" mapstructure:"end"` +} + +// ServiceStatus defines the service status +type ServiceStatus struct { + IsActive bool `json:"is_active"` + Bindings []Binding `json:"bindings"` + PassivePortRange PortRange `json:"passive_port_range"` +} + +// Configuration defines the configuration for the ftp server +type Configuration struct { + // Addresses and ports to bind to + Bindings []Binding `json:"bindings" mapstructure:"bindings"` + // The contents of the specified file, if any, are displayed when someone connects to the server. + BannerFile string `json:"banner_file" mapstructure:"banner_file"` + // If files containing a certificate and matching private key for the server are provided the server will accept + // both plain FTP and explicit FTP over TLS. + // Certificate and key files can be reloaded on demand sending a "SIGHUP" signal on Unix based systems and a + // "paramchange" request to the running service on Windows. + CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` + CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` + // CACertificates defines the set of root certificate authorities to be used to verify client certificates.
+ CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"` + // CARevocationLists defines a set of revocation lists, one for each root CA, to be used to check + // if a client certificate has been revoked + CARevocationLists []string `json:"ca_revocation_lists" mapstructure:"ca_revocation_lists"` + // Do not impose the port 20 for active data transfer. Enabling this option allows to run SFTPGo with less privilege + ActiveTransfersPortNon20 bool `json:"active_transfers_port_non_20" mapstructure:"active_transfers_port_non_20"` + // Set to true to disable active FTP + DisableActiveMode bool `json:"disable_active_mode" mapstructure:"disable_active_mode"` + // Set to true to enable the FTP SITE command. + // We support chmod and symlink if SITE support is enabled + EnableSite bool `json:"enable_site" mapstructure:"enable_site"` + // Set to 1 to enable FTP commands that allow to calculate the hash value of files. + // These FTP commands will be enabled: HASH, XCRC, MD5/XMD5, XSHA/XSHA1, XSHA256, XSHA512. + // Please keep in mind that to calculate the hash we need to read the whole file, for + // remote backends this means downloading the file, for the encrypted backend this means + // decrypting the file + HASHSupport int `json:"hash_support" mapstructure:"hash_support"` + // Set to 1 to enable support for the non standard "COMB" FTP command. + // Combine is only supported for local filesystem, for cloud backends it has + // no advantage as it will download the partial files and will upload the + // combined one. Cloud backends natively support multipart uploads. + CombineSupport int `json:"combine_support" mapstructure:"combine_support"` + // Port Range for data connections.
Random if not specified + PassivePortRange PortRange `json:"passive_port_range" mapstructure:"passive_port_range"` + acmeDomain string +} + +// ShouldBind returns true if there is at least a valid binding +func (c *Configuration) ShouldBind() bool { + for _, binding := range c.Bindings { + if binding.IsValid() { + return true + } + } + + return false +} + +func (c *Configuration) getKeyPairs(configDir string) []common.TLSKeyPair { + var keyPairs []common.TLSKeyPair + + for _, binding := range c.Bindings { + certificateFile := getConfigPath(binding.CertificateFile, configDir) + certificateKeyFile := getConfigPath(binding.CertificateKeyFile, configDir) + if certificateFile != "" && certificateKeyFile != "" { + keyPairs = append(keyPairs, common.TLSKeyPair{ + Cert: certificateFile, + Key: certificateKeyFile, + ID: binding.GetAddress(), + }) + } + } + var certificateFile, certificateKeyFile string + if c.acmeDomain != "" { + certificateFile, certificateKeyFile = util.GetACMECertificateKeyPair(c.acmeDomain) + } else { + certificateFile = getConfigPath(c.CertificateFile, configDir) + certificateKeyFile = getConfigPath(c.CertificateKeyFile, configDir) + } + if certificateFile != "" && certificateKeyFile != "" { + keyPairs = append(keyPairs, common.TLSKeyPair{ + Cert: certificateFile, + Key: certificateKeyFile, + ID: common.DefaultTLSKeyPaidID, + }) + } + return keyPairs +} + +func (c *Configuration) loadFromProvider() error { + configs, err := dataprovider.GetConfigs() + if err != nil { + return fmt.Errorf("unable to load config from provider: %w", err) + } + configs.SetNilsToEmpty() + if configs.ACME.Domain == "" || !configs.ACME.HasProtocol(common.ProtocolFTP) { + return nil + } + crt, key := util.GetACMECertificateKeyPair(configs.ACME.Domain) + if crt != "" && key != "" { + if _, err := os.Stat(crt); err != nil { + logger.Error(logSender, "", "unable to load acme cert file %q: %v", crt, err) + return nil + } + if _, err := os.Stat(key); err != nil { + 
logger.Error(logSender, "", "unable to load acme key file %q: %v", key, err) + return nil + } + c.acmeDomain = configs.ACME.Domain + logger.Info(logSender, "", "acme domain set to %q", c.acmeDomain) + return nil + } + return nil +} + +// Initialize configures and starts the FTP server +func (c *Configuration) Initialize(configDir string) error { + if err := c.loadFromProvider(); err != nil { + return err + } + logger.Info(logSender, "", "initializing FTP server with config %+v", *c) + if !c.ShouldBind() { + return common.ErrNoBinding + } + + keyPairs := c.getKeyPairs(configDir) + if len(keyPairs) > 0 { + mgr, err := common.NewCertManager(keyPairs, configDir, logSender) + if err != nil { + return err + } + mgr.SetCACertificates(c.CACertificates) + if err := mgr.LoadRootCAs(); err != nil { + return err + } + mgr.SetCARevocationLists(c.CARevocationLists) + if err := mgr.LoadCRLs(); err != nil { + return err + } + certMgr = mgr + } + serviceStatus = ServiceStatus{ + Bindings: nil, + PassivePortRange: c.PassivePortRange, + } + + exitChannel := make(chan error, 1) + + for idx, binding := range c.Bindings { + if !binding.IsValid() { + continue + } + + server := NewServer(c, configDir, binding, idx) + + go func(s *Server) { + ftpLogger := logger.NewSlogAdapter("ftpserverlib", []slog.Attr{ + { + Key: "server_id", + Value: slog.StringValue(fmt.Sprintf("FTP_%d", s.ID)), + }, + }) + ftpServer := ftpserver.NewFtpServer(s) + ftpServer.Logger = slog.New(ftpLogger) + logger.Info(logSender, "", "starting FTP serving, binding: %v", s.binding.GetAddress()) + util.CheckTCP4Port(s.binding.Port) + exitChannel <- ftpServer.ListenAndServe() + }(server) + + serviceStatus.Bindings = append(serviceStatus.Bindings, binding) + } + + serviceStatus.IsActive = true + + return <-exitChannel +} + +// ReloadCertificateMgr reloads the certificate manager +func ReloadCertificateMgr() error { + if certMgr != nil { + return certMgr.Reload() + } + return nil +} + +// GetStatus returns the server status 
+func GetStatus() ServiceStatus { + return serviceStatus +} + +func parsePassiveIP(passiveIP string) (string, error) { + ip := net.ParseIP(passiveIP) + if ip == nil { + return "", fmt.Errorf("the provided passive IP %q is not valid", passiveIP) + } + ip = ip.To4() + if ip == nil { + return "", fmt.Errorf("the provided passive IP %q is not a valid IPv4 address", passiveIP) + } + return ip.String(), nil +} + +func getConfigPath(name, configDir string) string { + if !util.IsFileInputValid(name) { + return "" + } + if name != "" && !filepath.IsAbs(name) { + return filepath.Join(configDir, name) + } + return name +} diff --git a/internal/ftpd/ftpd_test.go b/internal/ftpd/ftpd_test.go new file mode 100644 index 00000000..aebf5f37 --- /dev/null +++ b/internal/ftpd/ftpd_test.go @@ -0,0 +1,4304 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package ftpd_test + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "net" + "net/http" + "os" + "os/exec" + "path" + "path/filepath" + "runtime" + "strconv" + "testing" + "time" + + ftpserver "github.com/fclairamb/ftpserverlib" + "github.com/jlaffaye/ftp" + "github.com/pkg/sftp" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + "github.com/rs/zerolog" + "github.com/sftpgo/sdk" + sdkkms "github.com/sftpgo/sdk/kms" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/ftpd" + "github.com/drakkan/sftpgo/v2/internal/httpdtest" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + logSender = "ftpdTesting" + ftpServerAddr = "127.0.0.1:2121" + sftpServerAddr = "127.0.0.1:2122" + ftpSrvAddrTLS = "127.0.0.1:2124" // ftp server with implicit tls + ftpSrvAddrTLSResumption = "127.0.0.1:2126" // ftp server with implicit tls + defaultUsername = "test_user_ftp" + defaultPassword = "test_password" + osWindows = "windows" + ftpsCert = `-----BEGIN CERTIFICATE----- +MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw +RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu +dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw +OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD +VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA +IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA +NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM 
+3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME +GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG +SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY +/8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI +dV4vKmHUzwK/eIx+8Ay3neE= +-----END CERTIFICATE-----` + ftpsKey = `-----BEGIN EC PARAMETERS----- +BgUrgQQAIg== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 +UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq +WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV +CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= +-----END EC PRIVATE KEY-----` + caCRT = `-----BEGIN CERTIFICATE----- +MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 +QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT +CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m +fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 +v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh +yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL +CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U +4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb +KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo +NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb +S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i +M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY +/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy +OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD +AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu +XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F +sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki +pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE +jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX 
+y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy +WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV +oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV +L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 +DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E +eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli +SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w +kK8u1YiiVenmoQ== +-----END CERTIFICATE-----` + caCRL = `-----BEGIN X509 CRL----- +MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN +MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg +ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r +rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV ++jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 +XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE +jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF +DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD +MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 +iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE +ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 +hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 +cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 +sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn +w7bXJGfJcIMrrKs= +-----END X509 CRL-----` + client1Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 +HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE +OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 
+Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf +XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO +c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo +Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU +OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC +sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I +l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb +BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz +MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 +QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz +J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ +offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX +G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr +kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au +MU3Bo0A= +-----END CERTIFICATE-----` + client1Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki +J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi +85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW +eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 +ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy +bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO +hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw +eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl +X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ +CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL +j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU +NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf +sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB 
+Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB +LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g +mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 +T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn +RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa +Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW +asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ +DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 +0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls +ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY +nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe +yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== +-----END RSA PRIVATE KEY-----` + // client 2 crt is revoked + client2Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b +Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua +eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ +cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT +A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 +6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f +Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz +PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY +xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D +EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB +4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO 
+78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 +7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e +qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N +f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv +/ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S +ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR +E0+r2+4= +-----END CERTIFICATE-----` + client2Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV +jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP +kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK +h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P +QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu +4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq +op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM +qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt +Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL +VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO +qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy +VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl +NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 +4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl +/owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd +aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu +khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz +3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ +eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i +vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB +GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb +Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D +8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl 
+RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA
+ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl
+-----END RSA PRIVATE KEY-----`
+	testFileName          = "test_file_ftp.dat"
+	testDLFileName        = "test_download_ftp.dat"
+	tlsClient1Username    = "client1"
+	tlsClient2Username    = "client2"
+	httpFsPort            = 23456
+	defaultHTTPFsUsername = "httpfs_ftp_user"
+	emptyPwdPlaceholder   = "empty"
+)
+
+var (
+	configDir       = filepath.Join(".", "..", "..")
+	allPerms        = []string{dataprovider.PermAny}
+	homeBasePath    string
+	hookCmdPath     string
+	extAuthPath     string
+	preLoginPath    string
+	postConnectPath string
+	preDownloadPath string
+	preUploadPath   string
+	logFilePath     string
+	caCrtPath       string
+	caCRLPath       string
+)
+
+func TestMain(m *testing.M) { //nolint:gocyclo
+	logFilePath = filepath.Join(configDir, "sftpgo_ftpd_test.log")
+	bannerFileName := "banner_file"
+	bannerFile := filepath.Join(configDir, bannerFileName)
+	logger.InitLogger(logFilePath, 5, 1, 28, false, false, zerolog.DebugLevel)
+	err := os.WriteFile(bannerFile, []byte("SFTPGo test ready\nsimple banner line\n"), os.ModePerm)
+	if err != nil {
+		logger.ErrorToConsole("error creating banner file: %v", err)
+		os.Exit(1)
+	}
+	// we run the test cases with UploadMode atomic and resume support. The non atomic code path
+	// simply does not execute some code so if it works in atomic mode it will
+	// work in non atomic mode too
+	os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2")
+	os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1")
+	os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1")
+	os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin")
+	os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password")
+	err = config.LoadConfig(configDir, "")
+	if err != nil {
+		logger.ErrorToConsole("error loading configuration: %v", err)
+		os.Exit(1)
+	}
+	providerConf := config.GetProviderConf()
+	logger.InfoToConsole("Starting FTPD tests, provider: %v", providerConf.Driver)
+
+	commonConf := config.GetCommonConfig()
+	homeBasePath = os.TempDir()
+	if runtime.GOOS != osWindows {
+		commonConf.Actions.ExecuteOn = []string{"download", "upload", "rename", "delete"}
+		hookCmdPath, err = exec.LookPath("true")
+		if err != nil {
+			logger.Warn(logSender, "", "unable to get hook command: %v", err)
+			logger.WarnToConsole("unable to get hook command: %v", err)
+		}
+		commonConf.Actions.Hook = hookCmdPath // assign after LookPath so the resolved path is used
+	}
+
+	certPath := filepath.Join(os.TempDir(), "test_ftpd.crt")
+	keyPath := filepath.Join(os.TempDir(), "test_ftpd.key")
+	caCrtPath = filepath.Join(os.TempDir(), "test_ftpd_ca.crt")
+	caCRLPath = filepath.Join(os.TempDir(), "test_ftpd_crl.crt")
+	err = writeCerts(certPath, keyPath, caCrtPath, caCRLPath)
+	if err != nil {
+		os.Exit(1)
+	}
+
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	if err != nil {
+		logger.ErrorToConsole("error initializing data provider: %v", err)
+		os.Exit(1)
+	}
+	err = common.Initialize(commonConf, 0)
+	if err != nil {
+		logger.WarnToConsole("error initializing common: %v", err)
+		os.Exit(1)
+	}
+
+	httpConfig := config.GetHTTPConfig()
+	httpConfig.Initialize(configDir) //nolint:errcheck
+
+	kmsConfig := config.GetKMSConfig()
+	err = kmsConfig.Initialize()
+	if err != nil {
+		logger.ErrorToConsole("error initializing kms: %v", err)
+		os.Exit(1)
+	}
+	mfaConfig := config.GetMFAConfig()
+	err = mfaConfig.Initialize()
+	if err != nil {
+		logger.ErrorToConsole("error initializing MFA: %v", err)
+		os.Exit(1)
+	}
+
+	httpdConf := config.GetHTTPDConfig()
+	httpdConf.Bindings[0].Port = 8079
+	httpdtest.SetBaseURL("http://127.0.0.1:8079")
+
+	ftpdConf := config.GetFTPDConfig()
+	ftpdConf.Bindings = []ftpd.Binding{
+		{
+			Port:               2121,
+			ClientAuthType:     2,
+			CertificateFile:    certPath,
+			CertificateKeyFile: keyPath,
+		},
+	}
+	ftpdConf.PassivePortRange.Start = 0
+	ftpdConf.PassivePortRange.End = 0
+	ftpdConf.BannerFile = bannerFileName
+	ftpdConf.CACertificates = []string{caCrtPath}
+	ftpdConf.CARevocationLists = []string{caCRLPath}
+	ftpdConf.EnableSite = true
+
+	// required to test sftpfs
+	sftpdConf := config.GetSFTPDConfig()
+	sftpdConf.Bindings = []sftpd.Binding{
+		{
+			Port: 2122,
+		},
+	}
+	hostKeyPath := filepath.Join(os.TempDir(), "id_ed25519")
+	sftpdConf.HostKeys = []string{hostKeyPath}
+
+	extAuthPath = filepath.Join(homeBasePath, "extauth.sh")
+	preLoginPath = filepath.Join(homeBasePath, "prelogin.sh")
+	postConnectPath = filepath.Join(homeBasePath, "postconnect.sh")
+	preDownloadPath = filepath.Join(homeBasePath, "predownload.sh")
+	preUploadPath = filepath.Join(homeBasePath, "preupload.sh")
+
+	status := ftpd.GetStatus()
+	if status.IsActive {
+		logger.ErrorToConsole("ftpd is already active")
+		os.Exit(1)
+	}
+
+	go func() {
+		logger.Debug(logSender, "", "initializing FTP server with config %+v", ftpdConf)
+		if err := ftpdConf.Initialize(configDir); err != nil {
+			logger.ErrorToConsole("could not start FTP server: %v", err)
+			os.Exit(1)
+		}
+	}()
+
+	go func() {
+		logger.Debug(logSender, "", "initializing SFTP server with config %+v", sftpdConf)
+		if err := sftpdConf.Initialize(configDir); err != nil {
+			logger.ErrorToConsole("could not start SFTP server: %v", err)
+			os.Exit(1)
+		}
+	}()
+
+	go func() {
+		if err := httpdConf.Initialize(configDir, 0); err != nil {
+			logger.ErrorToConsole("could not start HTTP server: %v", err)
+			os.Exit(1)
+		}
+	}()
+
+	waitTCPListening(ftpdConf.Bindings[0].GetAddress())
+	waitTCPListening(httpdConf.Bindings[0].GetAddress())
+	waitTCPListening(sftpdConf.Bindings[0].GetAddress())
+	ftpd.ReloadCertificateMgr() //nolint:errcheck
+
+	ftpdConf = config.GetFTPDConfig()
+	ftpdConf.Bindings = []ftpd.Binding{
+		{
+			Port:    2124,
+			TLSMode: 2,
+		},
+	}
+	ftpdConf.CertificateFile = certPath
+	ftpdConf.CertificateKeyFile = keyPath
+	ftpdConf.CACertificates = []string{caCrtPath}
+	ftpdConf.CARevocationLists = []string{caCRLPath}
+	ftpdConf.EnableSite = false
+	ftpdConf.DisableActiveMode = true
+	ftpdConf.CombineSupport = 1
+	ftpdConf.HASHSupport = 1
+
+	go func() {
+		logger.Debug(logSender, "", "initializing FTP server with config %+v", ftpdConf)
+		if err := ftpdConf.Initialize(configDir); err != nil {
+			logger.ErrorToConsole("could not start FTP server: %v", err)
+			os.Exit(1)
+		}
+	}()
+
+	waitTCPListening(ftpdConf.Bindings[0].GetAddress())
+
+	ftpdConf = config.GetFTPDConfig()
+	ftpdConf.Bindings = []ftpd.Binding{
+		{
+			Port:               2126,
+			CertificateFile:    certPath,
+			CertificateKeyFile: keyPath,
+			TLSMode:            1,
+			ClientAuthType:     2,
+		},
+	}
+	ftpdConf.CACertificates = []string{caCrtPath}
+	ftpdConf.CARevocationLists = []string{caCRLPath}
+
+	go func() {
+		logger.Debug(logSender, "", "initializing FTP server with config %+v", ftpdConf)
+		if err := ftpdConf.Initialize(configDir); err != nil {
+			logger.ErrorToConsole("could not start FTP server: %v", err)
+			os.Exit(1)
+		}
+	}()
+
+	waitTCPListening(ftpdConf.Bindings[0].GetAddress())
+
+	waitNoConnections()
+	startHTTPFs()
+
+	exitCode := m.Run()
+	os.Remove(logFilePath)
+	os.Remove(bannerFile)
+	os.Remove(extAuthPath)
+	os.Remove(preLoginPath)
+	os.Remove(postConnectPath)
+	os.Remove(preDownloadPath)
+	os.Remove(preUploadPath)
+	os.Remove(certPath)
+	os.Remove(keyPath)
+	os.Remove(caCrtPath)
+	os.Remove(caCRLPath)
+	os.Remove(hostKeyPath)
+	os.Remove(hostKeyPath + ".pub")
+	os.Exit(exitCode)
+}
+ +func TestInitializationFailure(t *testing.T) { + ftpdConf := config.GetFTPDConfig() + ftpdConf.Bindings = []ftpd.Binding{} + ftpdConf.CertificateFile = filepath.Join(os.TempDir(), "test_ftpd.crt") + ftpdConf.CertificateKeyFile = filepath.Join(os.TempDir(), "test_ftpd.key") + err := ftpdConf.Initialize(configDir) + require.EqualError(t, err, common.ErrNoBinding.Error()) + ftpdConf.Bindings = []ftpd.Binding{ + { + Port: 0, + }, + { + Port: 2121, + }, + } + ftpdConf.BannerFile = "a-missing-file" + err = ftpdConf.Initialize(configDir) + require.Error(t, err) + + ftpdConf.BannerFile = "" + ftpdConf.Bindings[1].TLSMode = 10 + err = ftpdConf.Initialize(configDir) + require.Error(t, err) + + ftpdConf.CertificateFile = "" + ftpdConf.CertificateKeyFile = "" + ftpdConf.Bindings[1].TLSMode = 1 + err = ftpdConf.Initialize(configDir) + require.Error(t, err) + + certPath := filepath.Join(os.TempDir(), "test_ftpd.crt") + keyPath := filepath.Join(os.TempDir(), "test_ftpd.key") + ftpdConf.CertificateFile = certPath + ftpdConf.CertificateKeyFile = keyPath + ftpdConf.CACertificates = []string{"invalid ca cert"} + err = ftpdConf.Initialize(configDir) + require.Error(t, err) + + ftpdConf.CACertificates = nil + ftpdConf.CARevocationLists = []string{""} + err = ftpdConf.Initialize(configDir) + require.Error(t, err) + + ftpdConf.CACertificates = []string{caCrtPath} + ftpdConf.CARevocationLists = []string{caCRLPath} + ftpdConf.Bindings[1].ForcePassiveIP = "127001" + err = ftpdConf.Initialize(configDir) + require.Error(t, err) + require.Contains(t, err.Error(), "the provided passive IP \"127001\" is not valid") + ftpdConf.Bindings[1].ForcePassiveIP = "" + err = ftpdConf.Initialize(configDir) + require.Error(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = ftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to load config from provider") + } + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + 
providerConf := config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestBasicFTPHandling(t *testing.T) { + u := getTestUser() + u.QuotaSize = 6553600 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaSize = 6553600 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + for _, user := range []dataprovider.User{localUser, sftpUser} { + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + if user.Username == defaultUsername { + assert.Len(t, common.Connections.GetStats(""), 1) + } else { + assert.Len(t, common.Connections.GetStats(""), 2) + } + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + expectedQuotaSize := testFileSize + expectedQuotaFiles := 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client, 0) + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), user.FirstUpload) + assert.Equal(t, int64(0), user.FirstDownload) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.FirstUpload, int64(0)) + assert.Equal(t, int64(0), user.FirstDownload) + // overwrite an existing file + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + user, _, err = 
httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + assert.Greater(t, user.FirstUpload, int64(0)) + assert.Greater(t, user.FirstDownload, int64(0)) + err = client.Rename(testFileName, testFileName+"1") + assert.NoError(t, err) + err = client.Delete(testFileName) + assert.Error(t, err) + err = client.Delete(testFileName + "1") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize-testFileSize, user.UsedQuotaSize) + curDir, err := client.CurrentDir() + if assert.NoError(t, err) { + assert.Equal(t, "/", curDir) + } + testDir := "testDir" + err = client.MakeDir(testDir) + assert.NoError(t, err) + err = client.ChangeDir(testDir) + assert.NoError(t, err) + curDir, err = client.CurrentDir() + if assert.NoError(t, err) { + assert.Equal(t, path.Join("/", testDir), curDir) + } + res, err := client.List(path.Join("/", testDir)) + assert.NoError(t, err) + assert.Len(t, res, 0) + res, err = client.List(path.Join("/")) + assert.NoError(t, err) + if assert.Len(t, res, 1) { + assert.Equal(t, testDir, res[0].Name) + } + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + _, err = client.FileSize(path.Join("/", testDir)) + assert.Error(t, err) + size, err := client.FileSize(path.Join("/", testDir, testFileName)) + assert.NoError(t, err) + assert.Equal(t, testFileSize, size) + err = client.ChangeDirToParent() + assert.NoError(t, err) + curDir, err = client.CurrentDir() + if assert.NoError(t, err) { + assert.Equal(t, "/", curDir) + } + err = client.Delete(path.Join("/", testDir, testFileName)) + assert.NoError(t, err) + err = client.Delete(testDir) + assert.Error(t, err) + err = client.RemoveDir(testDir) + assert.NoError(t, 
err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) + assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, + 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestHTTPFs(t *testing.T) { + u := getTestUserWithHTTPFs() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + // test a download resume + data := []byte("test data") + err = os.WriteFile(testFilePath, data, os.ModePerm) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, int64(len(data)-5), client, 5) + assert.NoError(t, err) + readed, err := os.ReadFile(localDownloadPath) + assert.NoError(t, err) + assert.Equal(t, []byte("data"), readed, "readed data mismatch: %q", string(readed)) + err = 
os.Remove(testFilePath)
+		assert.NoError(t, err)
+		err = os.Remove(localDownloadPath)
+		assert.NoError(t, err)
+		err = client.Quit()
+		assert.NoError(t, err)
+	}
+	_, err = httpdtest.RemoveUser(user, http.StatusOK)
+	assert.NoError(t, err)
+	err = os.RemoveAll(user.GetHomeDir())
+	assert.NoError(t, err)
+	assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond)
+	assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond,
+		50*time.Millisecond)
+	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
+}
+
+func TestListDirWithWildcards(t *testing.T) {
+	localUser, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
+	assert.NoError(t, err)
+	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated)
+	assert.NoError(t, err)
+
+	defer func() {
+		_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
+		assert.NoError(t, err)
+		_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
+		assert.NoError(t, err)
+		err = os.RemoveAll(localUser.GetHomeDir())
+		assert.NoError(t, err)
+	}()
+
+	for _, user := range []dataprovider.User{localUser, sftpUser} {
+		client, err := getFTPClient(user, true, nil, ftp.DialWithDisabledMLSD(true))
+		if assert.NoError(t, err) {
+			dir1 := "test.dir"
+			dir2 := "test.dir1"
+			err = client.MakeDir(dir1)
+			assert.NoError(t, err)
+			err = client.MakeDir(dir2)
+			assert.NoError(t, err)
+			testFilePath := filepath.Join(homeBasePath, testFileName)
+			testFileSize := int64(65535)
+			err = createTestFile(testFilePath, testFileSize)
+			assert.NoError(t, err)
+			fileName := "file[a-z]e.dat"
+			err = ftpUploadFile(testFilePath, fileName, testFileSize, client, 0)
+			assert.NoError(t, err)
+			localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
+			err = ftpDownloadFile(fileName, localDownloadPath, testFileSize, client, 0)
+			assert.NoError(t, err)
+			entries, err := client.List(fileName)
+			require.NoError(t, err)
+			require.Len(t, entries, 1)
+			assert.Equal(t, fileName, entries[0].Name)
+			nListEntries, err := client.NameList(fileName)
+			require.NoError(t, err)
+			require.Len(t, nListEntries, 1)
+			assert.Contains(t, nListEntries, fileName)
+			entries, err = client.List(".")
+			require.NoError(t, err)
+			require.Len(t, entries, 3)
+			nListEntries, err = client.NameList(".")
+			require.NoError(t, err)
+			require.Len(t, nListEntries, 3)
+			entries, err = client.List("/test.*")
+			require.NoError(t, err)
+			require.Len(t, entries, 2)
+			found := 0
+			for _, e := range entries {
+				switch e.Name {
+				case dir1, dir2:
+					found++
+				}
+			}
+			assert.Equal(t, 2, found)
+			nListEntries, err = client.NameList("/test.*")
+			require.NoError(t, err)
+			require.Len(t, nListEntries, 2)
+			assert.Contains(t, nListEntries, dir1)
+			assert.Contains(t, nListEntries, dir2)
+			entries, err = client.List("/*.dir?")
+			require.NoError(t, err)
+			assert.Len(t, entries, 1)
+			assert.Equal(t, dir2, entries[0].Name)
+			nListEntries, err = client.NameList("/*.dir?")
+			require.NoError(t, err)
+			require.Len(t, nListEntries, 1)
+			assert.Contains(t, nListEntries, dir2)
+			entries, err = client.List("/test.???")
+			require.NoError(t, err)
+			require.Len(t, entries, 1)
+			assert.Equal(t, dir1, entries[0].Name)
+			nListEntries, err = client.NameList("/test.???")
+			require.NoError(t, err)
+			require.Len(t, nListEntries, 1)
+			assert.Contains(t, nListEntries, dir1)
+			_, err = client.NameList("/missingdir/test.*")
+			assert.Error(t, err)
+			_, err = client.List("/missingdir/test.*")
+			assert.Error(t, err)
+			_, err = client.NameList("test[-]")
+			if assert.Error(t, err) {
+				assert.Contains(t, err.Error(), path.ErrBadPattern.Error())
+			}
+			_, err = client.List("test[-]")
+			if assert.Error(t, err) {
+				assert.Contains(t, err.Error(), path.ErrBadPattern.Error())
+			}
+			subDir := path.Join(dir1, "sub.d")
+			err = client.MakeDir(subDir)
+			assert.NoError(t, err)
+			err = client.ChangeDir(path.Dir(subDir))
+			assert.NoError(t, err)
+			entries, err = client.List("sub.?")
+			require.NoError(t, err)
+			require.Len(t, entries, 1)
+			assert.Contains(t, path.Base(subDir), entries[0].Name)
+			nListEntries, err = client.NameList("sub.?")
+			require.NoError(t, err)
+			require.Len(t, nListEntries, 1)
+			assert.Contains(t, nListEntries, path.Base(subDir))
+			entries, err = client.List("../*.dir?")
+			require.NoError(t, err)
+			require.Len(t, entries, 1)
+			assert.Equal(t, path.Join("../", dir2), entries[0].Name)
+			nListEntries, err = client.NameList("../*.dir?")
+			require.NoError(t, err)
+			require.Len(t, nListEntries, 1)
+			assert.Contains(t, nListEntries, path.Join("../", dir2))
+
+			err = client.ChangeDir("/")
+			assert.NoError(t, err)
+			entries, err = client.List(path.Join(dir1, "sub.*"))
+			require.NoError(t, err)
+			require.Len(t, entries, 1)
+			assert.Equal(t, path.Join(dir1, "sub.d"), entries[0].Name)
+			nListEntries, err = client.NameList(path.Join(dir1, "sub.*"))
+			require.NoError(t, err)
+			require.Len(t, nListEntries, 1)
+			assert.Contains(t, nListEntries, path.Join(dir1, "sub.d"))
+			err = client.RemoveDir(subDir)
+			assert.NoError(t, err)
+			err = client.RemoveDir(dir1)
+			assert.NoError(t, err)
+			err = client.RemoveDir(dir2)
+			assert.NoError(t, err)
+			err = os.Remove(testFilePath)
+			assert.NoError(t, err)
+			err = os.Remove(localDownloadPath)
+			assert.NoError(t, err)
+			err = client.Quit()
+			assert.NoError(t, err)
+		}
+	}
+}
+
+func TestStartDirectory(t *testing.T) {
+	startDir := "/start/dir"
+	u := getTestUser()
+	u.Filters.StartDirectory = startDir
+	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
+	assert.NoError(t, err)
+	u = getTestSFTPUser()
+	u.Filters.StartDirectory = startDir
+	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
+	assert.NoError(t, err)
+	for _, user := range []dataprovider.User{localUser, sftpUser} {
+		client, err := getFTPClient(user, true, nil)
+		if assert.NoError(t, err) {
+			currentDir, err := client.CurrentDir()
+			assert.NoError(t, err)
+			assert.Equal(t, startDir, currentDir)
+
+			testFilePath
:= filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + entries, err := client.List(".") + assert.NoError(t, err) + if assert.Len(t, entries, 1) { + assert.Equal(t, testFileName, entries[0].Name) + } + entries, err = client.List("/") + assert.NoError(t, err) + if assert.Len(t, entries, 1) { + assert.Equal(t, "start", entries[0].Name) + } + err = client.ChangeDirToParent() + assert.NoError(t, err) + currentDir, err = client.CurrentDir() + assert.NoError(t, err) + assert.Equal(t, path.Dir(startDir), currentDir) + err = client.ChangeDirToParent() + assert.NoError(t, err) + currentDir, err = client.CurrentDir() + assert.NoError(t, err) + assert.Equal(t, "/", currentDir) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginEmptyPassword(t *testing.T) { + u := getTestUser() + u.Password = "" + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user.Password = emptyPwdPlaceholder + + _, err = getFTPClient(user, true, nil) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestAnonymousUser(t *testing.T) { + u := getTestUser() + u.Password = "" + u.Filters.IsAnonymous = true + 
_, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.Error(t, err) + user, _, err := httpdtest.GetUserByUsername(u.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.IsAnonymous) + assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"]) + assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols) + assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword, + dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate, + dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods) + + user.Password = emptyPwdPlaceholder + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission") + } + err = os.Rename(testFilePath, filepath.Join(user.GetHomeDir(), testFileName)) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = client.MakeDir("adir") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission") + } + + err = client.Quit() + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestAnonymousGroupInheritance(t *testing.T) { + g := 
getTestGroup() + g.UserSettings.Filters.IsAnonymous = true + g.UserSettings.Permissions = make(map[string][]string) + g.UserSettings.Permissions["/"] = allPerms + g.UserSettings.Permissions["/testsub"] = allPerms + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: group.Name, + Type: sdk.GroupTypePrimary, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + user.Password = emptyPwdPlaceholder + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission") + } + err = client.MakeDir("adir") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission") + } + err = client.MakeDir("/testsub/adir") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission") + } + err = os.Rename(testFilePath, filepath.Join(user.GetHomeDir(), testFileName)) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + + err = client.Quit() + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + user.Password = defaultPassword + client, err = getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = 
httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) +} + +func TestMultiFactorAuth(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Password = defaultPassword + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolFTP}, + } + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + + user.Password = defaultPassword + _, err = getFTPClient(user, true, nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), dataprovider.ErrInvalidCredentials.Error()) + } + passcode, err := generateTOTPPasscode(key.Secret(), otp.AlgorithmSHA1) + assert.NoError(t, err) + user.Password = defaultPassword + passcode + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + // reusing the same passcode should not work + _, err = getFTPClient(user, true, nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), dataprovider.ErrInvalidCredentials.Error()) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestMustChangePasswordRequirement(t *testing.T) { + u := getTestUser() + u.Filters.RequirePasswordChange = true + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + _, err = getFTPClient(user, true, nil) + assert.Error(t, err) + + err = dataprovider.UpdateUserPassword(user.Username, defaultPassword, "", "", "") + assert.NoError(t, err) + + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = 
checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSecondFactorRequirement(t *testing.T) { + u := getTestUser() + u.Filters.TwoFactorAuthProtocols = []string{common.ProtocolFTP} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + _, err = getFTPClient(user, true, nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "second factor authentication is not set") + } + + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Password = defaultPassword + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolFTP}, + } + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + passcode, err := generateTOTPPasscode(key.Secret(), otp.AlgorithmSHA1) + assert.NoError(t, err) + user.Password = defaultPassword + passcode + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginInvalidCredentials(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user.Username = "wrong username" + _, err = getFTPClient(user, false, nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), dataprovider.ErrInvalidCredentials.Error()) + } + user.Username = u.Username + user.Password = "wrong pwd" + _, err = getFTPClient(user, false, nil) + if assert.Error(t, err) 
{ + assert.Contains(t, err.Error(), dataprovider.ErrInvalidCredentials.Error()) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestLoginNonExistentUser(t *testing.T) { + user := getTestUser() + _, err := getFTPClient(user, false, nil) + assert.Error(t, err) +} + +func TestFTPSecurity(t *testing.T) { + u := getTestUser() + u.Filters.FTPSecurity = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + _, err = getFTPClient(user, false, nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "TLS is required") + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestGroupFTPSecurity(t *testing.T) { + g := getTestGroup() + g.UserSettings.Filters.FTPSecurity = 1 + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: group.Name, + Type: sdk.GroupTypePrimary, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + _, err = getFTPClient(user, false, nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "TLS is required") + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) +} + +func TestLoginExternalAuth(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + 
} + u := getTestUser() + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + g := getTestGroup() + g.UserSettings.Filters.DeniedProtocols = []string{common.ProtocolFTP} + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + + client, err := getFTPClient(u, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + u.Groups = []sdk.GroupMapping{ + { + Name: group.Name, + Type: sdk.GroupTypePrimary, + }, + } + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u), os.ModePerm) + assert.NoError(t, err) + _, err = getFTPClient(u, true, nil) + if !assert.Error(t, err) { + err := client.Quit() + assert.NoError(t, err) + } else { + assert.Contains(t, err.Error(), "protocol FTP is not allowed") + } + + u.Groups = nil + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u), os.ModePerm) + assert.NoError(t, err) + u.Username = defaultUsername + "1" + client, err = getFTPClient(u, true, nil) + if !assert.Error(t, err) { + err := client.Quit() + assert.NoError(t, err) + } else { + assert.Contains(t, err.Error(), "invalid credentials") + } + + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, defaultUsername, user.Username) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err 
= config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestPreLoginHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + u := getTestUser() + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + providerConf.PreLoginHook = preLoginPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + _, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound) + assert.NoError(t, err) + client, err := getFTPClient(u, false, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + + // test login with an existing user + client, err = getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) + assert.NoError(t, err) + client, err = getFTPClient(u, false, nil) + if !assert.Error(t, err) { + err := client.Quit() + assert.NoError(t, err) + } + user.Status = 0 + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm) + assert.NoError(t, err) + client, err = getFTPClient(u, false, nil) + if !assert.Error(t, err, "pre-login script returned a disabled user, login must fail") { + err := client.Quit() + assert.NoError(t, err) + } + user.Status = 0 + 
user.Filters.FTPSecurity = 1 + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm) + assert.NoError(t, err) + client, err = getFTPClient(u, true, nil) + if !assert.Error(t, err) { + err := client.Quit() + assert.NoError(t, err) + } + _, err = getFTPClient(user, false, nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "TLS is required") + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(preLoginPath) + assert.NoError(t, err) +} + +func TestPreLoginHookReturningAnonymousUser(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + u := getTestUser() + u.Filters.IsAnonymous = true + u.Filters.DeniedProtocols = []string{common.ProtocolSSH} + u.Password = "" + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + providerConf.PreLoginHook = preLoginPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + // the pre-login hook create the anonymous user + client, err := getFTPClient(u, false, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + err = client.MakeDir("tdiranonymous") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission") + } 
+ err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission") + } + err = os.Rename(testFilePath, filepath.Join(u.GetHomeDir(), testFileName)) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.IsAnonymous) + assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"]) + assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols) + assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword, + dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate, + dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods) + // now the same with an existing user + client, err = getFTPClient(u, false, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission") + } + err = os.Rename(testFilePath, filepath.Join(u.GetHomeDir(), testFileName)) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err := client.Quit() + 
assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(preLoginPath) + assert.NoError(t, err) +} + +func TestPreDownloadHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oldExecuteOn := common.Config.Actions.ExecuteOn + oldHook := common.Config.Actions.Hook + + common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} + common.Config.Actions.Hook = preDownloadPath + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + // now return an error from the pre-download hook + err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + client, err = getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + 
assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "permission denied") + } + err := client.Quit() + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + + common.Config.Actions.ExecuteOn = oldExecuteOn + common.Config.Actions.Hook = oldHook +} + +func TestPreUploadHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oldExecuteOn := common.Config.Actions.ExecuteOn + oldHook := common.Config.Actions.Hook + + common.Config.Actions.ExecuteOn = []string{common.OperationPreUpload} + common.Config.Actions.Hook = preUploadPath + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(preUploadPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + 
// now return an error from the pre-upload hook + err = os.WriteFile(preUploadPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + client, err = getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), ftpserver.ErrFileNameNotAllowed.Error()) + } + err = ftpUploadFile(testFilePath, testFileName+"1", testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), ftpserver.ErrFileNameNotAllowed.Error()) + } + err := client.Quit() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + + common.Config.Actions.ExecuteOn = oldExecuteOn + common.Config.Actions.Hook = oldHook +} + +func TestPostConnectHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + common.Config.PostConnectHook = postConnectPath + + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(postConnectPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + err = os.WriteFile(postConnectPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + client, err = getFTPClient(user, true, nil) + if !assert.Error(t, err) { + err := client.Quit() + assert.NoError(t, err) + } + + common.Config.PostConnectHook = "http://127.0.0.1:8079/healthz" + + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + 
// TestMaxConnections verifies that a second client connection is refused when
// the global MaxTotalConnections limit is set to 1. The limit is restored at
// the end.
//nolint:dupl
func TestMaxConnections(t *testing.T) {
	oldValue := common.Config.MaxTotalConnections
	common.Config.MaxTotalConnections = 1

	// wait for connections from previous tests to be released
	assert.Eventually(t, func() bool {
		return common.Connections.GetClientConnections() == 0
	}, 1000*time.Millisecond, 50*time.Millisecond)

	user := getTestUser()
	err := dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	user.Password = ""
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		// the limit is already consumed, this connection must fail
		_, err = getFTPClient(user, false, nil)
		assert.Error(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}
	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	common.Config.MaxTotalConnections = oldValue
}

// TestMaxPerHostConnections is the same check as TestMaxConnections but for
// the per-host connection limit.
//nolint:dupl
func TestMaxPerHostConnections(t *testing.T) {
	oldValue := common.Config.MaxPerHostConnections
	common.Config.MaxPerHostConnections = 1

	// wait for connections from previous tests to be released
	assert.Eventually(t, func() bool {
		return common.Connections.GetClientConnections() == 0
	}, 1000*time.Millisecond, 50*time.Millisecond)

	user := getTestUser()
	err := dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)
	user.Password = ""
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		// same source host: the second connection must be refused
		_, err = getFTPClient(user, false, nil)
		assert.Error(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}
	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	common.Config.MaxPerHostConnections = oldValue
}
dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxPerHostConnections = oldValue +} + +func TestMaxTransfers(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 2 + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + user := getTestUser() + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + user.Password = "" + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + conn, sftpClient, err := getSftpClient(user) + assert.NoError(t, err) + defer conn.Close() + defer sftpClient.Close() + + f1, err := sftpClient.Create("file1") + assert.NoError(t, err) + f2, err := sftpClient.Create("file2") + assert.NoError(t, err) + _, err = f1.Write([]byte(" ")) + assert.NoError(t, err) + _, err = f2.Write([]byte(" ")) + assert.NoError(t, err) + + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.Error(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.Error(t, err) + err := client.Quit() + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + err = f1.Close() + assert.NoError(t, err) + err = f2.Close() + assert.NoError(t, err) + + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxPerHostConnections = oldValue +} + +func TestRateLimiter(t *testing.T) { + oldConfig := 
// TestRateLimiter verifies that with a 1 connection/second source-based rate
// limiter for FTP the first login succeeds, subsequent logins are throttled
// and, since throttling generates defender events, the client IP is finally
// banned. The previous common configuration is restored at the end.
func TestRateLimiter(t *testing.T) {
	oldConfig := config.GetCommonConfig()

	cfg := config.GetCommonConfig()
	cfg.DefenderConfig.Enabled = true
	cfg.DefenderConfig.Threshold = 5
	cfg.DefenderConfig.ScoreLimitExceeded = 3
	cfg.RateLimitersConfig = []common.RateLimiterConfig{
		{
			Average:                1,
			Period:                 1000,
			Burst:                  1,
			Type:                   2,
			Protocols:              []string{common.ProtocolFTP},
			GenerateDefenderEvents: true,
			EntriesSoftLimit:       100,
			EntriesHardLimit:       150,
		},
	}

	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	// first connection fits in the burst, it must succeed
	client, err := getFTPClient(user, false, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}

	_, err = getFTPClient(user, true, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "rate limit exceed")
	}

	_, err = getFTPClient(user, false, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "rate limit exceed")
	}

	// the accumulated defender score must have banned the client IP by now
	_, err = getFTPClient(user, true, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "banned client IP")
	}

	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	err = common.Initialize(oldConfig, 0)
	assert.NoError(t, err)
}
// TestDefender verifies defender scoring: a no-auth connection scores 1, a
// failed login adds more, and once the threshold is reached even a valid
// login is refused because the client IP is banned. The previous common
// configuration is restored at the end.
func TestDefender(t *testing.T) {
	oldConfig := config.GetCommonConfig()

	cfg := config.GetCommonConfig()
	cfg.DefenderConfig.Enabled = true
	cfg.DefenderConfig.Threshold = 4
	cfg.DefenderConfig.ScoreLimitExceeded = 2
	cfg.DefenderConfig.ScoreNoAuth = 1
	cfg.DefenderConfig.ScoreValid = 1

	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, false, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}
	// just dial without login
	ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)}
	client, err = ftp.Dial(ftpServerAddr, ftpOptions...)
	assert.NoError(t, err)
	err = client.Quit()
	assert.NoError(t, err)
	hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		// the no-auth connection scored 1, no ban yet
		host := hosts[0]
		assert.Empty(t, host.GetBanTime())
		assert.Equal(t, 1, host.Score)
	}
	user.Password = "wrong_pwd"
	_, err = getFTPClient(user, false, nil)
	assert.Error(t, err)
	hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		host := hosts[0]
		assert.Empty(t, host.GetBanTime())
		assert.Equal(t, 2, host.Score)
	}

	// two more failed logins push the score past the threshold
	for i := 0; i < 2; i++ {
		_, err = getFTPClient(user, false, nil)
		assert.Error(t, err)
	}

	// even with the correct password the IP is now banned
	user.Password = defaultPassword
	_, err = getFTPClient(user, false, nil)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "banned client IP")
	}

	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	err = common.Initialize(oldConfig, 0)
	assert.NoError(t, err)
}

// TestMaxSessions verifies that a user with MaxSessions = 1 cannot open a
// second concurrent FTP session.
func TestMaxSessions(t *testing.T) {
	u := getTestUser()
	u.MaxSessions = 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		_, err = getFTPClient(user, false, nil)
		assert.Error(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
// TestZeroBytesTransfers verifies that empty files can be uploaded and
// downloaded, with and without TLS.
func TestZeroBytesTransfers(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, useTLS := range []bool{true, false} {
		client, err := getFTPClient(user, useTLS, nil)
		if assert.NoError(t, err) {
			// deliberately shadows the package-level testFileName so each
			// TLS/no-TLS iteration works on its own remote file name
			testFileName := "testfilename"
			err = checkBasicFTP(client)
			assert.NoError(t, err)
			localDownloadPath := filepath.Join(homeBasePath, "empty_download")
			err = os.WriteFile(localDownloadPath, []byte(""), os.ModePerm)
			assert.NoError(t, err)
			err = ftpUploadFile(localDownloadPath, testFileName, 0, client, 0)
			assert.NoError(t, err)
			size, err := client.FileSize(testFileName)
			assert.NoError(t, err)
			assert.Equal(t, int64(0), size)
			// remove the local file and verify an empty download recreates it
			err = os.Remove(localDownloadPath)
			assert.NoError(t, err)
			assert.NoFileExists(t, localDownloadPath)
			err = ftpDownloadFile(testFileName, localDownloadPath, 0, client, 0)
			assert.NoError(t, err)
			assert.FileExists(t, localDownloadPath)
			err = client.Quit()
			assert.NoError(t, err)
			err = os.Remove(localDownloadPath)
			assert.NoError(t, err)
		}
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
// TestDownloadErrors verifies that downloads fail for: a directory without
// the download permission (sub1), file names denied by pattern filters
// (*.zip and *.jpg in sub2) and a missing file. The test files are created
// directly on the filesystem to bypass upload checks.
func TestDownloadErrors(t *testing.T) {
	u := getTestUser()
	u.QuotaFiles = 1
	subDir1 := "sub1"
	subDir2 := "sub2"
	u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems}
	u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload,
		dataprovider.PermDelete, dataprovider.PermDownload}
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/sub2",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{"*.jpg", "*.zip"},
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		testFilePath1 := filepath.Join(user.HomeDir, subDir1, "file.zip")
		testFilePath2 := filepath.Join(user.HomeDir, subDir2, "file.zip")
		testFilePath3 := filepath.Join(user.HomeDir, subDir2, "file.jpg")
		err = os.MkdirAll(filepath.Dir(testFilePath1), os.ModePerm)
		assert.NoError(t, err)
		err = os.MkdirAll(filepath.Dir(testFilePath2), os.ModePerm)
		assert.NoError(t, err)
		err = os.WriteFile(testFilePath1, []byte("file1"), os.ModePerm)
		assert.NoError(t, err)
		err = os.WriteFile(testFilePath2, []byte("file2"), os.ModePerm)
		assert.NoError(t, err)
		err = os.WriteFile(testFilePath3, []byte("file3"), os.ModePerm)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		// no download permission in sub1
		err = ftpDownloadFile(path.Join("/", subDir1, "file.zip"), localDownloadPath, 5, client, 0)
		assert.Error(t, err)
		// denied by the *.zip / *.jpg patterns in sub2
		err = ftpDownloadFile(path.Join("/", subDir2, "file.zip"), localDownloadPath, 5, client, 0)
		assert.Error(t, err)
		err = ftpDownloadFile(path.Join("/", subDir2, "file.jpg"), localDownloadPath, 5, client, 0)
		assert.Error(t, err)
		// non-existent file
		err = ftpDownloadFile("/missing.zip", localDownloadPath, 5, client, 0)
		assert.Error(t, err)
		err = client.Quit()
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
// TestUploadErrors verifies that uploads fail for: a directory without the
// upload permission (sub1), file names denied by pattern filters (*.zip in
// sub2), overwriting when delete would be needed, uploading over a directory
// name, and exceeding the quota size.
func TestUploadErrors(t *testing.T) {
	u := getTestUser()
	u.QuotaSize = 65535
	subDir1 := "sub1"
	subDir2 := "sub2"
	u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems}
	u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload,
		dataprovider.PermDelete}
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/sub2",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{"*.zip"},
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(user, true, nil)
	if assert.NoError(t, err) {
		testFilePath := filepath.Join(homeBasePath, testFileName)
		// the file exactly fills the quota
		testFileSize := user.QuotaSize
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = client.MakeDir(subDir1)
		assert.NoError(t, err)
		err = client.MakeDir(subDir2)
		assert.NoError(t, err)
		err = client.ChangeDir(subDir1)
		assert.NoError(t, err)
		// no upload permission in sub1
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.Error(t, err)
		err = client.ChangeDirToParent()
		assert.NoError(t, err)
		err = client.ChangeDir(subDir2)
		assert.NoError(t, err)
		// denied by the *.zip pattern in sub2
		err = ftpUploadFile(testFilePath, testFileName+".zip", testFileSize, client, 0)
		assert.Error(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.Error(t, err)
		err = client.ChangeDir("/")
		assert.NoError(t, err)
		// uploading over an existing directory name must fail
		err = ftpUploadFile(testFilePath, subDir1, testFileSize, client, 0)
		assert.Error(t, err)
		// overquota
		err = ftpUploadFile(testFilePath, testFileName+"1", testFileSize, client, 0)
		assert.Error(t, err)
		err = client.Delete(path.Join("/", subDir2, testFileName))
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.Error(t, err)
		err = client.Quit()
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
// TestSFTPBuffered verifies FTP transfers through a buffered SFTP filesystem
// (BufferSize > 0): uploads/downloads and quota tracking work, while resume
// (non-zero offset) and append are rejected with "operation unsupported".
func TestSFTPBuffered(t *testing.T) {
	u := getTestUser()
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser()
	u.QuotaFiles = 100
	u.FsConfig.SFTPConfig.BufferSize = 2
	u.HomeDir = filepath.Join(os.TempDir(), u.Username)
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	client, err := getFTPClient(sftpUser, true, nil)
	if assert.NoError(t, err) {
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		expectedQuotaSize := testFileSize
		expectedQuotaFiles := 1
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = checkBasicFTP(client)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		// overwrite an existing file
		err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0)
		assert.NoError(t, err)
		user, _, err := httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)

		data := []byte("test data")
		err = os.WriteFile(testFilePath, data, os.ModePerm)
		assert.NoError(t, err)
		err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0)
		assert.NoError(t, err)
		// resume (offset 5) is not supported on a buffered filesystem
		err = ftpUploadFile(testFilePath, testFileName, int64(len(data)+5), client, 5)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "operation unsupported")
		}
		// a download from offset 5 returns the tail of "test data"
		err = ftpDownloadFile(testFileName, localDownloadPath, int64(4), client, 5)
		assert.NoError(t, err)
		readed, err := os.ReadFile(localDownloadPath)
		assert.NoError(t, err)
		assert.Equal(t, []byte("data"), readed)
		// append must fail too, and leave the file size unchanged
		srcFile, err := os.Open(testFilePath)
		if assert.NoError(t, err) {
			err = client.Append(testFileName, srcFile)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "operation unsupported")
			}
			err = srcFile.Close()
			assert.NoError(t, err)
			size, err := client.FileSize(testFileName)
			assert.NoError(t, err)
			assert.Equal(t, int64(len(data)), size)
			err = ftpDownloadFile(testFileName, localDownloadPath, int64(len(data)), client, 0)
			assert.NoError(t, err)
		}

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
		err = client.Quit()
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(sftpUser.GetHomeDir())
	assert.NoError(t, err)
}
client.Delete(testFileName) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, int64(len(data)), client, 0) + assert.NoError(t, err) + // now append to a file + srcFile, err := os.Open(testFilePath) + if assert.NoError(t, err) { + err = client.Append(testFileName, srcFile) + assert.NoError(t, err) + err = srcFile.Close() + assert.NoError(t, err) + size, err := client.FileSize(testFileName) + assert.NoError(t, err) + assert.Equal(t, int64(2*len(data)), size) + err = ftpDownloadFile(testFileName, localDownloadPath, int64(2*len(data)), client, 0) + assert.NoError(t, err) + readed, err = os.ReadFile(localDownloadPath) + assert.NoError(t, err) + expected := append(data, data...) + assert.Equal(t, expected, readed) + } + // append to a new file + srcFile, err = os.Open(testFilePath) + if assert.NoError(t, err) { + newFileName := testFileName + "_new" + err = client.Append(newFileName, srcFile) + assert.NoError(t, err) + err = srcFile.Close() + assert.NoError(t, err) + size, err := client.FileSize(newFileName) + assert.NoError(t, err) + assert.Equal(t, int64(len(data)), size) + err = ftpDownloadFile(newFileName, localDownloadPath, int64(len(data)), client, 0) + assert.NoError(t, err) + readed, err = os.ReadFile(localDownloadPath) + assert.NoError(t, err) + assert.Equal(t, data, readed) + } + err = client.Quit() + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) 
+ assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(bufferedUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(bufferedUser.GetHomeDir()) + assert.NoError(t, err) +} + +//nolint:dupl +func TestDeniedLoginMethod(t *testing.T) { + u := getTestUser() + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + _, err = getFTPClient(user, false, nil) + assert.Error(t, err) + user.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodKeyAndPassword} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + assert.NoError(t, checkBasicFTP(client)) + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +//nolint:dupl +func TestDeniedProtocols(t *testing.T) { + u := getTestUser() + u.Filters.DeniedProtocols = []string{common.ProtocolFTP} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + _, err = getFTPClient(user, false, nil) + assert.Error(t, err) + user.Filters.DeniedProtocols = []string{common.ProtocolSSH, common.ProtocolWebDAV} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + assert.NoError(t, checkBasicFTP(client)) + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestQuotaLimits(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 1 + localUser, _, err := httpdtest.AddUser(u, 
http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaFiles = 1 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + testFileSize := int64(65535) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + testFileSize1 := int64(131072) + testFileName1 := "test_file1.dat" + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + testFileSize2 := int64(32768) + testFileName2 := "test_file2.dat" + testFilePath2 := filepath.Join(homeBasePath, testFileName2) + err = createTestFile(testFilePath2, testFileSize2) + assert.NoError(t, err) + // test quota files + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = ftpUploadFile(testFilePath, testFileName+".quota", testFileSize, client, 0) //nolint:goconst + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName+".quota1", testFileSize, client, 0) + assert.Error(t, err) + err = client.Rename(testFileName+".quota", testFileName) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + // test quota size + user.QuotaSize = testFileSize - 1 + user.QuotaFiles = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = ftpUploadFile(testFilePath, testFileName+".quota", testFileSize, client, 0) + assert.Error(t, err) + err = client.Rename(testFileName, testFileName+".quota") + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + // now test quota limits while uploading the current file, we have 1 bytes remaining + user.QuotaSize = testFileSize + 1 + user.QuotaFiles = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + 
assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = ftpUploadFile(testFilePath1, testFileName1, testFileSize1, client, 0) + assert.Error(t, err) + _, err = client.FileSize(testFileName1) + assert.Error(t, err) + err = client.Rename(testFileName+".quota", testFileName) + assert.NoError(t, err) + // overwriting an existing file will work if the resulting size is lesser or equal than the current one + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath2, testFileName, testFileSize2, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath1, testFileName, testFileSize1, client, 0) + assert.Error(t, err) + err = ftpUploadFile(testFilePath1, testFileName, testFileSize1, client, 10) + assert.Error(t, err) + err = ftpUploadFile(testFilePath2, testFileName, testFileSize2, client, 0) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + err = os.Remove(testFilePath2) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + user.QuotaFiles = 0 + user.QuotaSize = 0 + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.QuotaSize = 0 + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUploadMaxSize(t *testing.T) { + testFileSize := int64(65535) + u := getTestUser() + u.Filters.MaxUploadFileSize = testFileSize + 1 + 
localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.Filters.MaxUploadFileSize = testFileSize + 1 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + testFileSize1 := int64(131072) + testFileName1 := "test_file1.dat" + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = ftpUploadFile(testFilePath1, testFileName1, testFileSize1, client, 0) + assert.Error(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + // now test overwrite an existing file with a size bigger than the allowed one + err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName1), testFileSize1) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath1, testFileName1, testFileSize1, client, 0) + assert.Error(t, err) + err = client.Quit() + assert.NoError(t, err) + } + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.Filters.MaxUploadFileSize = 65536000 + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = 
os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginWithIPilters(t *testing.T) { + u := getTestUser() + u.Filters.DeniedIP = []string{"192.167.0.0/24", "172.18.0.0/16"} + u.Filters.AllowedIP = []string{"172.19.0.0/16"} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if !assert.Error(t, err) { + err = client.Quit() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginWithDatabaseCredentials(t *testing.T) { + u := getTestUser() + u.FsConfig.Provider = sdk.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "test" + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account", "private_key": " ", "client_email": "example@iam.gserviceaccount.com" }`) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.GCSConfig.Credentials.GetStatus()) + assert.NotEmpty(t, user.FsConfig.GCSConfig.Credentials.GetPayload()) + assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetAdditionalData()) + assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetKey()) + + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = client.Quit() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginInvalidFs(t *testing.T) { + u := getTestUser() + u.FsConfig.Provider = sdk.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "test" + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + client, err := getFTPClient(user, false, nil) + if 
!assert.Error(t, err) { + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestClientClose(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + stats := common.Connections.GetStats("") + if assert.Len(t, stats, 1) { + common.Connections.Close(stats[0].ConnectionID, "") + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 50*time.Millisecond) + } + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestRename(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + testDir := "adir" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = client.MakeDir(testDir) + assert.NoError(t, err) + err = client.Rename(testFileName, path.Join("missing", testFileName)) + assert.Error(t, err) + err = client.Rename(testFileName, path.Join(testDir, testFileName)) + assert.NoError(t, err) + size, err := client.FileSize(path.Join(testDir, testFileName)) + assert.NoError(t, err) + assert.Equal(t, 
testFileSize, size) + if runtime.GOOS != osWindows { + otherDir := "dir" + err = client.MakeDir(otherDir) + assert.NoError(t, err) + err = client.MakeDir(path.Join(otherDir, testDir)) + assert.NoError(t, err) + code, response, err := client.SendCommand("SITE CHMOD 0001 %v", otherDir) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + assert.Equal(t, "SITE CHMOD command successful", response) + err = client.Rename(testDir, path.Join(otherDir, testDir)) + assert.Error(t, err) + + code, response, err = client.SendCommand("SITE CHMOD 755 %v", otherDir) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + assert.Equal(t, "SITE CHMOD command successful", response) + } + err = client.Quit() + assert.NoError(t, err) + } + user.Permissions[path.Join("/", testDir)] = []string{dataprovider.PermListItems} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = client.Rename(path.Join(testDir, testFileName), testFileName) + assert.Error(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + err = os.Remove(testFilePath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Permissions = make(map[string][]string) + user.Permissions["/"] = allPerms + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSymlink(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, 
http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + for _, user := range []dataprovider.User{localUser, sftpUser} { + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + code, _, err := client.SendCommand("SITE SYMLINK %v %v", testFileName, testFileName+".link") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + + if runtime.GOOS != osWindows { + testDir := "adir" + otherDir := "dir" + err = client.MakeDir(otherDir) + assert.NoError(t, err) + err = client.MakeDir(path.Join(otherDir, testDir)) + assert.NoError(t, err) + code, response, err := client.SendCommand("SITE CHMOD 0001 %v", otherDir) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + assert.Equal(t, "SITE CHMOD command successful", response) + code, _, err = client.SendCommand("SITE SYMLINK %v %v", testDir, path.Join(otherDir, testDir)) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFileUnavailable, code) + + code, response, err = client.SendCommand("SITE CHMOD 755 %v", otherDir) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + assert.Equal(t, "SITE CHMOD command successful", response) + } + err = client.Quit() + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + err = 
os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestStat(t *testing.T) { + u := getTestUser() + u.Permissions["/subdir"] = []string{dataprovider.PermUpload} + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + + for _, user := range []dataprovider.User{localUser, sftpUser} { + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + subDir := "subdir" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = client.MakeDir(subDir) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join("/", subDir, testFileName), testFileSize, client, 0) + assert.Error(t, err) + size, err := client.FileSize(testFileName) + assert.NoError(t, err) + assert.Equal(t, testFileSize, size) + _, err = client.FileSize(path.Join("/", subDir, testFileName)) + assert.Error(t, err) + _, err = client.FileSize("missing file") + assert.Error(t, err) + err = client.Quit() + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + + _, err = httpdtest.RemoveUser(sftpUser, 
http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUploadOverwriteVfolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 1000 + vdir := "/vdir" + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdir, + QuotaSize: -1, + QuotaFiles: -1, + }) + err = os.MkdirAll(mappedPath, os.ModePerm) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join(vdir, testFileName), testFileSize, client, 0) + assert.NoError(t, err) + folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + assert.Equal(t, 0, folder.UsedQuotaFiles) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + assert.Equal(t, 1, user.UsedQuotaFiles) + + err = ftpUploadFile(testFilePath, path.Join(vdir, testFileName), testFileSize, client, 0) + assert.NoError(t, err) + folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + assert.Equal(t, 0, folder.UsedQuotaFiles) + user, _, err = 
httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + assert.Equal(t, 1, user.UsedQuotaFiles) + + err = client.Quit() + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestTransferQuotaLimits(t *testing.T) { + u := getTestUser() + u.DownloadDataTransfer = 1 + u.UploadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(524288) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), ftpserver.ErrStorageExceeded.Error()) + } + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + err = client.Quit() + assert.NoError(t, err) + } + + testFileSize = int64(600000) + err = 
createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + user.DownloadDataTransfer = 2 + user.UploadDataTransfer = 2 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.Error(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.Error(t, err) + + err = client.Quit() + assert.NoError(t, err) + } + + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestAllocateAvailable(t *testing.T) { + u := getTestUser() + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: "/vdir", + QuotaSize: 110, + }) + err = os.MkdirAll(mappedPath, os.ModePerm) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("allo 2000000") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + assert.Equal(t, "Done !", response) + + code, response, err = client.SendCommand("AVBL /vdir") + 
assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "110", response) + + code, _, err = client.SendCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + + err = client.Quit() + assert.NoError(t, err) + } + user.QuotaSize = 100 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := user.QuotaSize - 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + code, response, err := client.SendCommand("allo 1000") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + assert.Equal(t, "Done !", response) + + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + + code, response, err = client.SendCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "1", response) + + err = client.Quit() + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + user.TotalDataTransfer = 1 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "1", response) + + err = client.Quit() + assert.NoError(t, err) + } + + user.TotalDataTransfer = 0 + user.UploadDataTransfer = 5 + user.QuotaSize = 6 * 1024 * 1024 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "5242880", response) + + err = client.Quit() + 
assert.NoError(t, err) + } + + user.TotalDataTransfer = 0 + user.UploadDataTransfer = 5 + user.QuotaSize = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "5242880", response) + + err = client.Quit() + assert.NoError(t, err) + } + + user.Filters.MaxUploadFileSize = 100 + user.QuotaSize = 0 + user.TotalDataTransfer = 0 + user.UploadDataTransfer = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("allo 10000") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + assert.Equal(t, "Done !", response) + + code, response, err = client.SendCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "100", response) + + err = client.Quit() + assert.NoError(t, err) + } + + user.QuotaSize = 50 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "0", response) + } + + user.QuotaSize = 1000 + user.Filters.MaxUploadFileSize = 1 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("AVBL") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Equal(t, "1", response) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = 
httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestAvailableSFTPFs(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(sftpUser, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("AVBL /") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + avblSize, err := strconv.ParseInt(response, 10, 64) + assert.NoError(t, err) + assert.Greater(t, avblSize, int64(0)) + + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestChtimes(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + + for _, user := range []dataprovider.User{localUser, sftpUser} { + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + + mtime := time.Now().Format("20060102150405") + code, response, err := client.SendCommand("MFMT %v %v", mtime, testFileName) + assert.NoError(t, err) + assert.Equal(t, 
ftp.StatusFile, code) + assert.Equal(t, fmt.Sprintf("Modify=%v; %v", mtime, testFileName), response) + err = client.Quit() + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestMODEType(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("MODE s") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusNotImplementedParameter, code) + assert.Equal(t, "Unsupported mode", response) + code, response, err = client.SendCommand("MODE S") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + assert.Equal(t, "Using stream mode", response) + + code, _, err = client.SendCommand("MODE Z") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusNotImplementedParameter, code) + + code, _, err = client.SendCommand("MODE SS") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusNotImplementedParameter, code) + + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSTAT(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + client, err := 
getFTPClient(user, false, nil) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(131072) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + testDir := "testdir" + err = client.MakeDir(testDir) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join(testDir, testFileName), testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join(testDir, testFileName+"_1"), testFileSize, client, 0) + assert.NoError(t, err) + code, response, err := client.SendCommand("STAT %s", testDir) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusDirectory, code) + assert.Contains(t, response, fmt.Sprintf("STAT %s", testDir)) + assert.Contains(t, response, testFileName) + assert.Contains(t, response, testFileName+"_1") + assert.Contains(t, response, "End") + + err = client.Quit() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestChown(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("chown is not supported on Windows") + } + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(131072) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + code, response, err := client.SendCommand("SITE CHOWN 1000:1000 %v", testFileName) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFileUnavailable, code) + assert.Equal(t, "Couldn't chown: operation unsupported", response) + err = client.Quit() + assert.NoError(t, err) + + err = 
os.Remove(testFilePath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestChmod(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("chmod is partially supported on Windows") + } + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(131072) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + + code, response, err := client.SendCommand("SITE CHMOD 600 %v", testFileName) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusCommandOK, code) + assert.Equal(t, "SITE CHMOD command successful", response) + + fi, err := os.Stat(filepath.Join(user.HomeDir, testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, os.FileMode(0600), fi.Mode().Perm()) + } + err = client.Quit() + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + 
assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestCombineDisabled(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + client, err := getFTPClient(user, true, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + + code, response, err := client.SendCommand("COMB file file.1 file.2") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusNotImplemented, code) + assert.Equal(t, "COMB support is disabled", response) + + err = client.Quit() + assert.NoError(t, err) + } + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestActiveModeDisabled(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClientImplicitTLS(user) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + code, response, err := client.SendCommand("PORT 10,2,0,2,4,31") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusNotAvailable, code) + assert.Equal(t, "PORT command is disabled", response) + + code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusNotAvailable, code) + assert.Equal(t, "EPRT command is disabled", response) + + err = client.Quit() + assert.NoError(t, err) + } + + client, err = getFTPClient(user, false, nil) + if assert.NoError(t, err) { + code, response, err := client.SendCommand("PORT 10,2,0,2,4,31") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusBadArguments, 
code) + assert.Equal(t, "Your request does not meet the configured security requirements", response) + + code, response, err = client.SendCommand("EPRT |1|132.235.1.2|6275|") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusBadArguments, code) + assert.Equal(t, "Your request does not meet the configured security requirements", response) + + err = client.Quit() + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSITEDisabled(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClientImplicitTLS(user) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + + code, response, err := client.SendCommand("SITE CHMOD 600 afile.txt") + assert.NoError(t, err) + assert.Equal(t, ftp.StatusBadCommand, code) + assert.Equal(t, "SITE support is disabled", response) + + err = client.Quit() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestHASH(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + u = getTestUserWithCryptFs() + u.Username += "_crypt" + cryptUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser, cryptUser} { + client, err := getFTPClientImplicitTLS(user) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(131072) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = checkBasicFTP(client) + assert.NoError(t, err) + err = 
ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + + h := sha256.New() + f, err := os.Open(testFilePath) + assert.NoError(t, err) + _, err = io.Copy(h, f) + assert.NoError(t, err) + hash := hex.EncodeToString(h.Sum(nil)) + err = f.Close() + assert.NoError(t, err) + + code, response, err := client.SendCommand("XSHA256 %v", testFileName) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusRequestedFileActionOK, code) + assert.Contains(t, response, hash) + + code, response, err = client.SendCommand("HASH %v", testFileName) + assert.NoError(t, err) + assert.Equal(t, ftp.StatusFile, code) + assert.Contains(t, response, hash) + + err = client.Quit() + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(cryptUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(cryptUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestCombine(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(), http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + client, err := getFTPClientImplicitTLS(user) + if assert.NoError(t, err) { + testFilePath := filepath.Join(homeBasePath, testFileName) + 
testFileSize := int64(131072) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = checkBasicFTP(client) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName+".1", testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, testFileName+".2", testFileSize, client, 0) + assert.NoError(t, err) + + code, response, err := client.SendCommand("COMB %v %v %v", testFileName, testFileName+".1", testFileName+".2") + assert.NoError(t, err) + if user.Username == defaultUsername { + assert.Equal(t, ftp.StatusRequestedFileActionOK, code) + assert.Equal(t, "COMB succeeded!", response) + } else { + assert.Equal(t, ftp.StatusFileUnavailable, code) + assert.Contains(t, response, "COMB is not supported for this filesystem") + } + + err = client.Quit() + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestClientCertificateAuthRevokedCert(t *testing.T) { + u := getTestUser() + u.Username = tlsClient2Username + u.Filters.TLSUsername = sdk.TLSUsernameCN + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + ClientSessionCache: tls.NewLRUClientSessionCache(0), + } + tlsCert, err := 
tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + _, err = getFTPClientWithSessionReuse(user, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "bad certificate") + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestClientCertificateAuth(t *testing.T) { + u := getTestUser() + u.Username = tlsClient1Username + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + // TLS username is not enabled, mutual TLS should fail + _, err = getFTPClient(user, true, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "login method password is not allowed") + } + + user.Filters.TLSUsername = sdk.TLSUsernameCN + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client, err := getFTPClient(user, true, tlsConfig) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + + // now use a valid certificate with a CN different from username + u = getTestUser() + u.Username = tlsClient2Username + u.Filters.TLSUsername = sdk.TLSUsernameCN + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} + user2, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + _, err = getFTPClient(user2, true, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "does not 
match username") + } + // add the certs to the user + user2.Filters.TLSUsername = sdk.TLSUsernameNone + user2.Filters.TLSCerts = []string{client2Crt, client1Crt} + user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") + assert.NoError(t, err) + client, err = getFTPClient(user2, true, tlsConfig) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + user2.Filters.TLSCerts = []string{client2Crt} + user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") + assert.NoError(t, err) + _, err = getFTPClient(user2, true, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "TLS certificate is not valid") + } + + // now disable certificate authentication + user.Filters.DeniedLoginMethods = append(user.Filters.DeniedLoginMethods, dataprovider.LoginMethodTLSCertificate, + dataprovider.LoginMethodTLSCertificateAndPwd) + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + _, err = getFTPClient(user, true, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "login method TLSCertificate+password is not allowed") + } + + // disable FTP protocol + user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} + user.Filters.DeniedProtocols = append(user.Filters.DeniedProtocols, common.ProtocolFTP) + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + _, err = getFTPClient(user, true, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "protocol FTP is not allowed") + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user2, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user2.GetHomeDir()) + assert.NoError(t, err) + + _, err = getFTPClient(user, true, tlsConfig) + assert.Error(t, err) +} + +func 
TestClientCertificateAndPwdAuth(t *testing.T) { + u := getTestUser() + u.Username = tlsClient1Username + u.Filters.TLSUsername = sdk.TLSUsernameCN + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword, dataprovider.LoginMethodTLSCertificate} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + client, err := getFTPClient(user, true, tlsConfig) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err = client.Quit() + assert.NoError(t, err) + } + + _, err = getFTPClient(user, true, nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "login method password is not allowed") + } + user.Password = defaultPassword + "1" + _, err = getFTPClient(user, true, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid credentials") + } + + tlsCert, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) + assert.NoError(t, err) + tlsConfig.Certificates = []tls.Certificate{tlsCert} + _, err = getFTPClient(user, true, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "bad certificate") + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestExternalAuthWithClientCert(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + u := getTestUser() + u.Username = tlsClient1Username + u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, dataprovider.LoginMethodPassword) + u.Filters.TLSUsername = sdk.TLSUsernameCN + err := dataprovider.Close() + 
assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 8 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + // external auth not called, auth scope is 8 + _, err = getFTPClient(u, true, nil) + assert.Error(t, err) + _, _, err = httpdtest.GetUserByUsername(u.Username, http.StatusNotFound) + assert.NoError(t, err) + + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + client, err := getFTPClient(u, true, tlsConfig) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + user, _, err := httpdtest.GetUserByUsername(u.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, u.Username, user.Username) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + u.Username = tlsClient2Username + _, err = getFTPClient(u, true, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid credentials") + } + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestPreLoginHookWithClientCert(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not 
available on Windows") + } + u := getTestUser() + u.Username = tlsClient1Username + u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, dataprovider.LoginMethodPassword) + u.Filters.TLSUsername = sdk.TLSUsernameCN + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + providerConf.PreLoginHook = preLoginPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + _, _, err = httpdtest.GetUserByUsername(tlsClient1Username, http.StatusNotFound) + assert.NoError(t, err) + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + client, err := getFTPClient(u, true, tlsConfig) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + user, _, err := httpdtest.GetUserByUsername(tlsClient1Username, http.StatusOK) + assert.NoError(t, err) + + // test login with an existing user + client, err = getFTPClient(user, true, tlsConfig) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + err := client.Quit() + assert.NoError(t, err) + } + + u.Username = tlsClient2Username + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + _, err = getFTPClient(u, true, tlsConfig) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "does not match username") + } + + user2, _, err := httpdtest.GetUserByUsername(tlsClient2Username, http.StatusOK) + assert.NoError(t, err) + _, err = 
httpdtest.RemoveUser(user2, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user2.GetHomeDir()) + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(preLoginPath) + assert.NoError(t, err) +} + +func TestNestedVirtualFolders(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") + folderNameCrypt := filepath.Base(mappedPathCrypt) + vdirCryptPath := "/vdir/crypt" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + }, + VirtualPath: vdirCryptPath, + }) + mappedPath := filepath.Join(os.TempDir(), "local") + folderName := filepath.Base(mappedPath) + vdirPath := "/vdir/local" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + }) + mappedPathNested := filepath.Join(os.TempDir(), "nested") + folderNameNested := filepath.Base(mappedPathNested) + vdirNestedPath := "/vdir/crypt/nested" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameNested, + }, + VirtualPath: vdirNestedPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + f1 := vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + MappedPath: mappedPathCrypt, + } 
+ _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + f3 := vfs.BaseVirtualFolder{ + Name: folderNameNested, + MappedPath: mappedPathNested, + } + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client, err := getFTPClient(sftpUser, false, nil) + if assert.NoError(t, err) { + err = checkBasicFTP(client) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + + err = ftpUploadFile(testFilePath, testFileName, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(testFileName, localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join("/vdir", testFileName), testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(path.Join("/vdir", testFileName), localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join(vdirCryptPath, testFileName), testFileSize, client, 0) + assert.NoError(t, err) + err = ftpDownloadFile(path.Join(vdirCryptPath, testFileName), localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + err = ftpUploadFile(testFilePath, path.Join(vdirNestedPath, testFileName), testFileSize, client, 0) + assert.NoError(t, err) + err = 
ftpDownloadFile(path.Join(vdirNestedPath, testFileName), localDownloadPath, testFileSize, client, 0) + assert.NoError(t, err) + + err = client.Quit() + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathCrypt) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathNested) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 1*time.Second, 50*time.Millisecond) + assert.Eventually(t, func() bool { return common.Connections.GetClientConnections() == 0 }, 1000*time.Millisecond, + 50*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func checkBasicFTP(client *ftp.ServerConn) error { + _, err := client.CurrentDir() + if err != nil { + return err + } + err = client.NoOp() + if err != nil { + return err + } + _, err = client.List(".") + if err != nil { + return err + } + return nil +} + +func ftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *ftp.ServerConn, offset uint64) error { + srcFile, err := os.Open(localSourcePath) + if err != nil { + return err + } + defer srcFile.Close() + if offset > 0 { + err = client.StorFrom(remoteDestPath, srcFile, offset) + } else { + err 
= client.Stor(remoteDestPath, srcFile) + } + if err != nil { + return err + } + if expectedSize > 0 { + size, err := client.FileSize(remoteDestPath) + if err != nil { + return err + } + if size != expectedSize { + return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", size, expectedSize) + } + } + return nil +} + +func ftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *ftp.ServerConn, offset uint64) error { + downloadDest, err := os.Create(localDestPath) + if err != nil { + return err + } + defer downloadDest.Close() + var r *ftp.Response + if offset > 0 { + r, err = client.RetrFrom(remoteSourcePath, offset) + } else { + r, err = client.Retr(remoteSourcePath) + } + if err != nil { + return err + } + defer r.Close() + + written, err := io.Copy(downloadDest, r) + if err != nil { + return err + } + if written != expectedSize { + return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", written, expectedSize) + } + return nil +} + +func getFTPClientImplicitTLS(user dataprovider.User) (*ftp.ServerConn, error) { + ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)} + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + ftpOptions = append(ftpOptions, ftp.DialWithTLS(tlsConfig)) + ftpOptions = append(ftpOptions, ftp.DialWithDisabledEPSV(true)) + client, err := ftp.Dial(ftpSrvAddrTLS, ftpOptions...) 
+ if err != nil { + return nil, err + } + pwd := defaultPassword + if user.Password != "" { + pwd = user.Password + } + err = client.Login(user.Username, pwd) + if err != nil { + return nil, err + } + return client, err +} + +func getFTPClientWithSessionReuse(user dataprovider.User, tlsConfig *tls.Config, dialOptions ...ftp.DialOption, +) (*ftp.ServerConn, error) { + ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)} + ftpOptions = append(ftpOptions, dialOptions...) + if tlsConfig == nil { + tlsConfig = &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + ClientSessionCache: tls.NewLRUClientSessionCache(0), + } + } + ftpOptions = append(ftpOptions, ftp.DialWithExplicitTLS(tlsConfig)) + client, err := ftp.Dial(ftpSrvAddrTLSResumption, ftpOptions...) + if err != nil { + return nil, err + } + pwd := defaultPassword + if user.Password != "" { + if user.Password == emptyPwdPlaceholder { + pwd = "" + } else { + pwd = user.Password + } + } + err = client.Login(user.Username, pwd) + if err != nil { + return nil, err + } + return client, err +} + +func getFTPClient(user dataprovider.User, useTLS bool, tlsConfig *tls.Config, dialOptions ...ftp.DialOption, +) (*ftp.ServerConn, error) { + ftpOptions := []ftp.DialOption{ftp.DialWithTimeout(5 * time.Second)} + ftpOptions = append(ftpOptions, dialOptions...) + if useTLS { + if tlsConfig == nil { + tlsConfig = &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + } + ftpOptions = append(ftpOptions, ftp.DialWithExplicitTLS(tlsConfig)) + } + client, err := ftp.Dial(ftpServerAddr, ftpOptions...) 
+	if err != nil {
+		return nil, err
+	}
+	pwd := defaultPassword
+	if user.Password != "" {
+		if user.Password == emptyPwdPlaceholder {
+			pwd = ""
+		} else {
+			pwd = user.Password
+		}
+	}
+	err = client.Login(user.Username, pwd)
+	if err != nil {
+		return nil, err
+	}
+	return client, err
+}
+
+// waitTCPListening blocks until a TCP server accepts connections on address,
+// polling every 100ms and logging each failed attempt.
+func waitTCPListening(address string) {
+	for {
+		conn, err := net.Dial("tcp", address)
+		if err == nil {
+			logger.InfoToConsole("tcp server %v now listening", address)
+			conn.Close()
+			return
+		}
+		logger.WarnToConsole("tcp server %v not listening: %v", address, err)
+		time.Sleep(100 * time.Millisecond)
+	}
+}
+
+// waitNoConnections polls until the shared connections tracker reports no
+// active connections.
+func waitNoConnections() {
+	for {
+		time.Sleep(50 * time.Millisecond)
+		if len(common.Connections.GetStats("")) == 0 {
+			return
+		}
+	}
+}
+
+// getTestGroup returns the group definition shared by the test cases.
+func getTestGroup() dataprovider.Group {
+	group := dataprovider.Group{
+		BaseGroup: sdk.BaseGroup{
+			Name:        "test_group",
+			Description: "test group description",
+		},
+	}
+	return group
+}
+
+// getTestUser returns an enabled local-filesystem user with every permission
+// granted on "/".
+func getTestUser() dataprovider.User {
+	u := dataprovider.User{
+		BaseUser: sdk.BaseUser{
+			Username:       defaultUsername,
+			Password:       defaultPassword,
+			HomeDir:        filepath.Join(homeBasePath, defaultUsername),
+			Status:         1,
+			ExpirationDate: 0,
+		},
+	}
+	u.Permissions = map[string][]string{
+		"/": allPerms,
+	}
+	return u
+}
+
+// getTestSFTPUser returns a user backed by the SFTP filesystem provider,
+// pointing at the local test SFTP server with password authentication.
+func getTestSFTPUser() dataprovider.User {
+	u := getTestUser()
+	u.Username += "_sftp"
+	u.FsConfig.Provider = sdk.SFTPFilesystemProvider
+	u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr
+	u.FsConfig.SFTPConfig.Username = defaultUsername
+	u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
+	return u
+}
+
+// getTestUserWithHTTPFs returns a user backed by the HTTP filesystem
+// provider, pointing at the local HTTPFs test server.
+func getTestUserWithHTTPFs() dataprovider.User {
+	u := getTestUser()
+	u.FsConfig.Provider = sdk.HTTPFilesystemProvider
+	u.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
+		BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
+			Endpoint: fmt.Sprintf("http://127.0.0.1:%d/api/v1", httpFsPort),
+			Username: defaultHTTPFsUsername,
+		},
+	}
+	return u
+}
+
+func 
getExtAuthScriptContent(user dataprovider.User) []byte { + extAuthContent := []byte("#!/bin/sh\n\n") + extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("if test \"$SFTPGO_AUTHD_USERNAME\" = \"%v\"; then\n", user.Username))...) + u, _ := json.Marshal(user) + extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...) + extAuthContent = append(extAuthContent, []byte("else\n")...) + extAuthContent = append(extAuthContent, []byte("echo '{\"username\":\"\"}'\n")...) + extAuthContent = append(extAuthContent, []byte("fi\n")...) + return extAuthContent +} + +func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte { + content := []byte("#!/bin/sh\n\n") + if nonJSONResponse { + content = append(content, []byte("echo 'text response'\n")...) + return content + } + if len(user.Username) > 0 { + u, _ := json.Marshal(user) + content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...) + } + return content +} + +func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if user.Password != "" { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return conn, sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + +func getExitCodeScriptContent(exitCode int) []byte { + content := []byte("#!/bin/sh\n\n") + content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...) 
+ return content +} + +func createTestFile(path string, size int64) error { + baseDir := filepath.Dir(path) + if _, err := os.Stat(baseDir); errors.Is(err, fs.ErrNotExist) { + err = os.MkdirAll(baseDir, os.ModePerm) + if err != nil { + return err + } + } + content := make([]byte, size) + _, err := rand.Read(content) + if err != nil { + return err + } + return os.WriteFile(path, content, os.ModePerm) +} + +func writeCerts(certPath, keyPath, caCrtPath, caCRLPath string) error { + err := os.WriteFile(certPath, []byte(ftpsCert), os.ModePerm) + if err != nil { + logger.ErrorToConsole("error writing FTPS certificate: %v", err) + return err + } + err = os.WriteFile(keyPath, []byte(ftpsKey), os.ModePerm) + if err != nil { + logger.ErrorToConsole("error writing FTPS private key: %v", err) + return err + } + err = os.WriteFile(caCrtPath, []byte(caCRT), os.ModePerm) + if err != nil { + logger.ErrorToConsole("error writing FTPS CA crt: %v", err) + return err + } + err = os.WriteFile(caCRLPath, []byte(caCRL), os.ModePerm) + if err != nil { + logger.ErrorToConsole("error writing FTPS CRL: %v", err) + return err + } + return nil +} + +func generateTOTPPasscode(secret string, algo otp.Algorithm) (string, error) { + return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: algo, + }) +} + +func startHTTPFs() { + go func() { + if err := httpdtest.StartTestHTTPFs(httpFsPort, nil); err != nil { + logger.ErrorToConsole("could not start HTTPfs test server: %v", err) + os.Exit(1) + } + }() + waitTCPListening(fmt.Sprintf(":%d", httpFsPort)) +} diff --git a/internal/ftpd/handler.go b/internal/ftpd/handler.go new file mode 100644 index 00000000..613eaa07 --- /dev/null +++ b/internal/ftpd/handler.go @@ -0,0 +1,625 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the 
Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package ftpd + +import ( + "errors" + "fmt" + "io" + "os" + "path" + "strings" + "time" + + ftpserver "github.com/fclairamb/ftpserverlib" + "github.com/spf13/afero" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +var ( + errNotImplemented = errors.New("not implemented") + errCOMBNotSupported = errors.New("COMB is not supported for this filesystem") +) + +// Connection details for an FTP connection. +// It implements common.ActiveConnection and ftpserver.ClientDriver interfaces +type Connection struct { + *common.BaseConnection + clientContext ftpserver.ClientContext + doWildcardListDir bool +} + +func (c *Connection) getFTPMode() string { + if c.clientContext == nil { + return "" + } + switch c.clientContext.GetLastDataChannel() { + case ftpserver.DataChannelActive: + return "active" + case ftpserver.DataChannelPassive: + return "passive" + } + return "" +} + +// GetClientVersion returns the connected client's version. 
+// It returns "Unknown" if the client does not advertise its +// version +func (c *Connection) GetClientVersion() string { + version := c.clientContext.GetClientVersion() + if len(version) > 0 { + return version + } + return "Unknown" +} + +// GetLocalAddress returns local connection address +func (c *Connection) GetLocalAddress() string { + return c.clientContext.LocalAddr().String() +} + +// GetRemoteAddress returns the connected client's address +func (c *Connection) GetRemoteAddress() string { + return c.clientContext.RemoteAddr().String() +} + +// Disconnect disconnects the client +func (c *Connection) Disconnect() error { + return c.clientContext.Close() +} + +// GetCommand returns the last received FTP command +func (c *Connection) GetCommand() string { + return c.clientContext.GetLastCommand() +} + +// Create is not implemented we use ClientDriverExtentionFileTransfer +func (c *Connection) Create(_ string) (afero.File, error) { + return nil, errNotImplemented +} + +// Mkdir creates a directory using the connection filesystem +func (c *Connection) Mkdir(name string, _ os.FileMode) error { + c.UpdateLastActivity() + name = util.CleanPath(name) + + return c.CreateDir(name, true) +} + +// MkdirAll is not implemented, we don't need it +func (c *Connection) MkdirAll(_ string, _ os.FileMode) error { + return errNotImplemented +} + +// Open is not implemented we use ClientDriverExtentionFileTransfer and ClientDriverExtensionFileList +func (c *Connection) Open(_ string) (afero.File, error) { + return nil, errNotImplemented +} + +// OpenFile is not implemented we use ClientDriverExtentionFileTransfer +func (c *Connection) OpenFile(_ string, _ int, _ os.FileMode) (afero.File, error) { + return nil, errNotImplemented +} + +// Remove removes a file. 
+// We implements ClientDriverExtensionRemoveDir for directories +func (c *Connection) Remove(name string) error { + c.UpdateLastActivity() + name = util.CleanPath(name) + + fs, p, err := c.GetFsAndResolvedPath(name) + if err != nil { + return err + } + + var fi os.FileInfo + if fi, err = fs.Lstat(p); err != nil { + c.Log(logger.LevelError, "failed to remove file %q: stat error: %+v", p, err) + return c.GetFsError(fs, err) + } + + if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 { + c.Log(logger.LevelError, "cannot remove %q is not a file/symlink", p) + return c.GetGenericError(nil) + } + return c.RemoveFile(fs, p, name, fi) +} + +// RemoveAll is not implemented, we don't need it +func (c *Connection) RemoveAll(_ string) error { + return errNotImplemented +} + +// Rename renames a file or a directory +func (c *Connection) Rename(oldname, newname string) error { + c.UpdateLastActivity() + oldname = util.CleanPath(oldname) + newname = util.CleanPath(newname) + + return c.BaseConnection.Rename(oldname, newname) +} + +// Stat returns a FileInfo describing the named file/directory, or an error, +// if any happens +func (c *Connection) Stat(name string) (os.FileInfo, error) { + c.UpdateLastActivity() + name = util.CleanPath(name) + c.doWildcardListDir = false + + if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(name)) { + return nil, c.GetPermissionDeniedError() + } + + fi, err := c.DoStat(name, 0, true) + if err != nil { + if c.isListDirWithWildcards(path.Base(name)) { + c.doWildcardListDir = true + return vfs.NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } + return nil, err + } + return fi, nil +} + +// Name returns the name of this connection +func (c *Connection) Name() string { + return c.GetID() +} + +// Chown changes the uid and gid of the named file +func (c *Connection) Chown(_ string, _, _ int) error { + c.UpdateLastActivity() + + return common.ErrOpUnsupported + /*p, err := c.Fs.ResolvePath(name) + if err != nil { + return c.GetFsError(err) + 
} + attrs := common.StatAttributes{ + Flags: common.StatAttrUIDGID, + UID: uid, + GID: gid, + } + + return c.SetStat(p, name, &attrs)*/ +} + +// Chmod changes the mode of the named file/directory +func (c *Connection) Chmod(name string, mode os.FileMode) error { + c.UpdateLastActivity() + name = util.CleanPath(name) + + attrs := common.StatAttributes{ + Flags: common.StatAttrPerms, + Mode: mode, + } + return c.SetStat(name, &attrs) +} + +// Chtimes changes the access and modification times of the named file +func (c *Connection) Chtimes(name string, atime time.Time, mtime time.Time) error { + c.UpdateLastActivity() + name = util.CleanPath(name) + + attrs := common.StatAttributes{ + Flags: common.StatAttrTimes, + Atime: atime, + Mtime: mtime, + } + return c.SetStat(name, &attrs) +} + +// GetAvailableSpace implements ClientDriverExtensionAvailableSpace interface +func (c *Connection) GetAvailableSpace(dirName string) (int64, error) { + c.UpdateLastActivity() + dirName = util.CleanPath(dirName) + + diskQuota, transferQuota := c.HasSpace(false, false, path.Join(dirName, "fakefile.txt")) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { + return 0, nil + } + + if diskQuota.AllowedSize == 0 && transferQuota.AllowedULSize == 0 && transferQuota.AllowedTotalSize == 0 { + // no quota restrictions + if c.User.Filters.MaxUploadFileSize > 0 { + return c.User.Filters.MaxUploadFileSize, nil + } + + fs, p, err := c.GetFsAndResolvedPath(dirName) + if err != nil { + return 0, err + } + + statVFS, err := fs.GetAvailableDiskSize(p) + if err != nil { + return 0, c.GetFsError(fs, err) + } + return int64(statVFS.FreeSpace()), nil + } + + allowedDiskSize := diskQuota.AllowedSize + allowedUploadSize := transferQuota.AllowedULSize + if transferQuota.AllowedTotalSize > 0 { + allowedUploadSize = transferQuota.AllowedTotalSize + } + allowedSize := allowedDiskSize + if allowedSize == 0 { + allowedSize = allowedUploadSize + } else { + if allowedUploadSize > 0 && allowedUploadSize < 
allowedSize { + allowedSize = allowedUploadSize + } + } + // the available space is the minimum between MaxUploadFileSize, if setted, + // and quota allowed size + if c.User.Filters.MaxUploadFileSize > 0 { + if c.User.Filters.MaxUploadFileSize < allowedSize { + return c.User.Filters.MaxUploadFileSize, nil + } + } + + return allowedSize, nil +} + +// AllocateSpace implements ClientDriverExtensionAllocate interface +func (c *Connection) AllocateSpace(_ int) error { + c.UpdateLastActivity() + // we treat ALLO as NOOP see RFC 959 + return nil +} + +// RemoveDir implements ClientDriverExtensionRemoveDir +func (c *Connection) RemoveDir(name string) error { + c.UpdateLastActivity() + name = util.CleanPath(name) + + return c.BaseConnection.RemoveDir(name) +} + +// Symlink implements ClientDriverExtensionSymlink +func (c *Connection) Symlink(oldname, newname string) error { + c.UpdateLastActivity() + oldname = util.CleanPath(oldname) + newname = util.CleanPath(newname) + + return c.CreateSymlink(oldname, newname) +} + +// ReadDir implements ClientDriverExtensionFilelist +func (c *Connection) ReadDir(name string) ([]os.FileInfo, error) { + c.UpdateLastActivity() + name = util.CleanPath(name) + + if c.doWildcardListDir { + c.doWildcardListDir = false + baseName := path.Base(name) + // we only support wildcards for the last path level, for example: + // - *.xml is supported + // - dir*/*.xml is not supported + name = path.Dir(name) + c.clientContext.SetListPath(name) + lister, err := c.ListDir(name) + if err != nil { + return nil, err + } + patternLister := &patternDirLister{ + DirLister: lister, + pattern: baseName, + lastCommand: c.clientContext.GetLastCommand(), + dirName: name, + connectionPath: util.CleanPath(c.clientContext.Path()), + } + return consumeDirLister(patternLister) + } + + lister, err := c.ListDir(name) + if err != nil { + return nil, err + } + return consumeDirLister(lister) +} + +// GetHandle implements ClientDriverExtentionFileTransfer +func (c 
*Connection) GetHandle(name string, flags int, offset int64) (ftpserver.FileTransfer, error) { + c.UpdateLastActivity() + name = util.CleanPath(name) + + fs, p, err := c.GetFsAndResolvedPath(name) + if err != nil { + return nil, err + } + + if c.GetCommand() == "COMB" && !vfs.IsLocalOsFs(fs) { + return nil, errCOMBNotSupported + } + + if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { + c.Log(logger.LevelInfo, "denying transfer due to count limits") + return nil, c.GetPermissionDeniedError() + } + + if flags&os.O_WRONLY != 0 { + return c.uploadFile(fs, p, name, flags) + } + return c.downloadFile(fs, p, name, offset) +} + +func (c *Connection) downloadFile(fs vfs.Fs, fsPath, ftpPath string, offset int64) (ftpserver.FileTransfer, error) { + if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(ftpPath)) { + return nil, c.GetPermissionDeniedError() + } + transferQuota := c.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + c.Log(logger.LevelInfo, "denying file read due to quota limits") + return nil, c.GetReadQuotaExceededError() + } + + if ok, policy := c.User.IsFileAllowed(ftpPath); !ok { + c.Log(logger.LevelWarn, "reading file %q is not allowed", ftpPath) + return nil, c.GetErrorForDeniedFile(policy) + } + + if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreDownload, fsPath, ftpPath, 0, 0); err != nil { + c.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", ftpPath, err) + return nil, c.GetPermissionDeniedError() + } + + file, r, cancelFn, err := fs.Open(fsPath, offset) + if err != nil { + c.Log(logger.LevelError, "could not open file %q for reading: %+v", fsPath, err) + return nil, c.GetFsError(fs, err) + } + + baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, fsPath, ftpPath, + common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota) + baseTransfer.SetFtpMode(c.getFTPMode()) + t := newTransfer(baseTransfer, nil, r, offset) + + return t, nil 
+} + +func (c *Connection) uploadFile(fs vfs.Fs, fsPath, ftpPath string, flags int) (ftpserver.FileTransfer, error) { + if ok, _ := c.User.IsFileAllowed(ftpPath); !ok { + c.Log(logger.LevelWarn, "writing file %q is not allowed", ftpPath) + return nil, ftpserver.ErrFileNameNotAllowed + } + + filePath := fsPath + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + filePath = fs.GetAtomicUploadPath(fsPath) + } + + stat, statErr := fs.Lstat(fsPath) + if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { + if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(ftpPath)) { + return nil, fmt.Errorf("%w, no upload permission", ftpserver.ErrFileNameNotAllowed) + } + return c.handleFTPUploadToNewFile(fs, flags, fsPath, filePath, ftpPath) + } + + if statErr != nil { + c.Log(logger.LevelError, "error performing file stat %q: %+v", fsPath, statErr) + return nil, c.GetFsError(fs, statErr) + } + + // This happen if we upload a file that has the same name of an existing directory + if stat.IsDir() { + c.Log(logger.LevelError, "attempted to open a directory for writing to: %q", fsPath) + return nil, c.GetOpUnsupportedError() + } + + if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(ftpPath)) { + return nil, fmt.Errorf("%w, no overwrite permission", ftpserver.ErrFileNameNotAllowed) + } + + return c.handleFTPUploadToExistingFile(fs, flags, fsPath, filePath, stat.Size(), ftpPath) +} + +func (c *Connection) handleFTPUploadToNewFile(fs vfs.Fs, flags int, resolvedPath, filePath, requestPath string) (ftpserver.FileTransfer, error) { + diskQuota, transferQuota := c.HasSpace(true, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { + c.Log(logger.LevelInfo, "denying file write due to quota limits") + return nil, ftpserver.ErrStorageExceeded + } + if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, 0, 0); err != nil { + c.Log(logger.LevelDebug, 
"upload for file %q denied by pre action: %v", requestPath, err) + return nil, ftpserver.ErrFileNameNotAllowed + } + file, w, cancelFn, err := fs.Create(filePath, flags, c.GetCreateChecks(requestPath, true, false)) + if err != nil { + c.Log(logger.LevelError, "error creating file %q, flags %v: %+v", resolvedPath, flags, err) + return nil, c.GetFsError(fs, err) + } + + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) + + // we can get an error only for resume + maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, 0, fs.IsUploadResumeSupported()) + + baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, + common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs, transferQuota) + baseTransfer.SetFtpMode(c.getFTPMode()) + t := newTransfer(baseTransfer, w, nil, 0) + + return t, nil +} + +func (c *Connection) handleFTPUploadToExistingFile(fs vfs.Fs, flags int, resolvedPath, filePath string, fileSize int64, + requestPath string) (ftpserver.FileTransfer, error) { + var err error + diskQuota, transferQuota := c.HasSpace(false, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { + c.Log(logger.LevelInfo, "denying file write due to quota limits") + return nil, ftpserver.ErrStorageExceeded + } + minWriteOffset := int64(0) + // ftpserverlib sets: + // - os.O_WRONLY | os.O_APPEND for APPE and COMB + // - os.O_WRONLY | os.O_CREATE for REST. + // - os.O_WRONLY | os.O_CREATE | os.O_TRUNC if the command is not APPE and REST = 0 + // so if we don't have O_TRUNC is a resume. 
+ isResume := flags&os.O_TRUNC == 0 + // if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace + // will return false in this case and we deny the upload before + maxWriteSize, err := c.GetMaxWriteSize(diskQuota, isResume, fileSize, vfs.IsUploadResumeSupported(fs, fileSize)) + if err != nil { + c.Log(logger.LevelDebug, "unable to get max write size: %v", err) + return nil, err + } + if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, fileSize, flags); err != nil { + c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err) + return nil, ftpserver.ErrFileNameNotAllowed + } + + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + _, _, err = fs.Rename(resolvedPath, filePath, 0) + if err != nil { + c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v", + resolvedPath, filePath, err) + return nil, c.GetFsError(fs, err) + } + } + + file, w, cancelFn, err := fs.Create(filePath, flags, c.GetCreateChecks(requestPath, false, isResume)) + if err != nil { + c.Log(logger.LevelError, "error opening existing file, flags: %v, source: %q, err: %+v", flags, filePath, err) + return nil, c.GetFsError(fs, err) + } + + initialSize := int64(0) + truncatedSize := int64(0) // bytes truncated and not included in quota + if isResume { + c.Log(logger.LevelDebug, "resuming upload requested, file path: %q initial size: %v", filePath, fileSize) + minWriteOffset = fileSize + initialSize = fileSize + if vfs.IsSFTPFs(fs) && fs.IsUploadResumeSupported() { + // we need this since we don't allow resume with wrong offset, we should fix this in pkg/sftp + file.Seek(initialSize, io.SeekStart) //nolint:errcheck // for sftp seek simply set the offset + } + } else { + if vfs.HasTruncateSupport(fs) { + vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) + if err == nil { + 
dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false) + } else { + dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck + } + } else { + initialSize = fileSize + truncatedSize = fileSize + } + } + + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) + + baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs, transferQuota) + baseTransfer.SetFtpMode(c.getFTPMode()) + t := newTransfer(baseTransfer, w, nil, minWriteOffset) + + return t, nil +} + +func (c *Connection) isListDirWithWildcards(name string) bool { + if strings.ContainsAny(name, "*?[]^") { + lastCommand := c.clientContext.GetLastCommand() + return lastCommand == "LIST" || lastCommand == "NLST" + } + return false +} + +func getPathRelativeTo(base, target string) string { + var sb strings.Builder + for { + if base == target { + return sb.String() + } + if !strings.HasSuffix(base, "/") { + base += "/" + } + if strings.HasPrefix(target, base) { + sb.WriteString(strings.TrimPrefix(target, base)) + return sb.String() + } + if base == "/" || base == "./" { + return target + } + sb.WriteString("../") + base = path.Dir(path.Clean(base)) + } +} + +type patternDirLister struct { + vfs.DirLister + pattern string + lastCommand string + dirName string + connectionPath string +} + +func (l *patternDirLister) Next(limit int) ([]os.FileInfo, error) { + for { + files, err := l.DirLister.Next(limit) + if len(files) == 0 { + return files, err + } + validIdx := 0 + var relativeBase string + if l.lastCommand != "NLST" { + relativeBase = getPathRelativeTo(l.connectionPath, l.dirName) + } + for _, fi := range files { + match, errMatch := path.Match(l.pattern, fi.Name()) + if errMatch != nil { + return nil, errMatch + } + if match { + files[validIdx] = vfs.NewFileInfo(path.Join(relativeBase, fi.Name()), fi.IsDir(), 
fi.Size(), + fi.ModTime(), true) + validIdx++ + } + } + files = files[:validIdx] + if err != nil || len(files) > 0 { + return files, err + } + } +} + +func consumeDirLister(lister vfs.DirLister) ([]os.FileInfo, error) { + defer lister.Close() + + var results []os.FileInfo + + for { + files, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + results = append(results, files...) + if err != nil && !finished { + return results, err + } + if finished { + lister.Close() + break + } + } + + return results, nil +} diff --git a/internal/ftpd/internal_test.go b/internal/ftpd/internal_test.go new file mode 100644 index 00000000..4e167524 --- /dev/null +++ b/internal/ftpd/internal_test.go @@ -0,0 +1,1217 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package ftpd + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/fs" + "net" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/eikenb/pipeat" + ftpserver "github.com/fclairamb/ftpserverlib" + "github.com/pires/go-proxyproto" + "github.com/sftpgo/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + ftpsCert = `-----BEGIN CERTIFICATE----- +MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw +RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu +dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw +OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD +VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA +IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA +NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM +3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME +GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG +SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY +/8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI +dV4vKmHUzwK/eIx+8Ay3neE= +-----END CERTIFICATE-----` + ftpsKey = `-----BEGIN EC PARAMETERS----- +BgUrgQQAIg== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 +UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq +WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV +CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= +-----END EC PRIVATE KEY-----` + caCRT = `-----BEGIN CERTIFICATE----- +MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 
+QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT +CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m +fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 +v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh +yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL +CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U +4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb +KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo +NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb +S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i +M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY +/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy +OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD +AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu +XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F +sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki +pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE +jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX +y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy +WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV +oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV +L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 +DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E +eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli +SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w +kK8u1YiiVenmoQ== +-----END CERTIFICATE-----` + caKey = `-----BEGIN RSA PRIVATE KEY----- +MIIJKgIBAAKCAgEA7WHW216mfi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7x +b64rkpdzx1aWetSiCrEyc3D1v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvD +OBUYZgtMqHZzpE6xRrqQ84zhyzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKz 
+n/2uEVt33qmO85WtN3RzbSqLCdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj +7B5P5MeamkkogwbExUjdHp3U4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZ +De67V/Q8iB2May1k7zBz1ZtbKF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmOb +cn8AIfH6smLQrn0C3cs7CYfoNlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLq +R/BJTbbyXUB0imne1u00fuzbS7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb7 +8x7ivdyXSF5LVQJ1JvhhWu6iM6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpB +P8d2jcRZVUVrSXGc2mAGuGOY/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2Z +xugCXULtRWJ9p4C9zUl40HEyOQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEA +AQKCAgEA4x0OoceG54ZrVxifqVaQd8qw3uRmUKUMIMdfuMlsdideeLO97ynmSlRY +00kGo/I4Lp6mNEjI9gUie9+uBrcUhri4YLcujHCH+YlNnCBDbGjwbe0ds9SLCWaa +KztZHMSlW5Q4Bqytgu+MpOnxSgqjlOk+vz9TcGFKVnUkHIkAcqKFJX8gOFxPZA/t +Ob1kJaz4kuv5W2Kur/ISKvQtvFvOtQeV0aJyZm8LqXnvS4cPI7yN4329NDU0HyDR +y/deqS2aqV4zII3FFqbz8zix/m1xtVQzWCugZGMKrz0iuJMfNeCABb8rRGc6GsZz ++465v/kobqgeyyneJ1s5rMFrLp2o+dwmnIVMNsFDUiN1lIZDHLvlgonaUO3IdTZc +9asamFWKFKUMgWqM4zB1vmUO12CKowLNIIKb0L+kf1ixaLLDRGf/f9vLtSHE+oyx +lATiS18VNA8+CGsHF6uXMRwf2auZdRI9+s6AAeyRISSbO1khyWKHo+bpOvmPAkDR +nknTjbYgkoZOV+mrsU5oxV8s6vMkuvA3rwFhT2gie8pokuACFcCRrZi9MVs4LmUQ +u0GYTHvp2WJUjMWBm6XX7Hk3g2HV842qpk/mdtTjNsXws81djtJPn4I/soIXSgXz +pY3SvKTuOckP9OZVF0yqKGeZXKpD288PKpC+MAg3GvEJaednagECggEBAPsfLwuP +L1kiDjXyMcRoKlrQ6Q/zBGyBmJbZ5uVGa02+XtYtDAzLoVupPESXL0E7+r8ZpZ39 +0dV4CEJKpbVS/BBtTEkPpTK5kz778Ib04TAyj+YLhsZjsnuja3T5bIBZXFDeDVDM +0ZaoFoKpIjTu2aO6pzngsgXs6EYbo2MTuJD3h0nkGZsICL7xvT9Mw0P1p2Ftt/hN ++jKk3vN220wTWUsq43AePi45VwK+PNP12ZXv9HpWDxlPo3j0nXtgYXittYNAT92u +BZbFAzldEIX9WKKZgsWtIzLaASjVRntpxDCTby/nlzQ5dw3DHU1DV3PIqxZS2+Oe +KV+7XFWgZ44YjYECggEBAPH+VDu3QSrqSahkZLkgBtGRkiZPkZFXYvU6kL8qf5wO +Z/uXMeqHtznAupLea8I4YZLfQim/NfC0v1cAcFa9Ckt9g3GwTSirVcN0AC1iOyv3 +/hMZCA1zIyIcuUplNr8qewoX71uPOvCNH0dix77423mKFkJmNwzy4Q+rV+qkRdLn +v+AAgh7g5N91pxNd6LQJjoyfi1Ka6rRP2yGXM5v7QOwD16eN4JmExUxX1YQ7uNuX +pVS+HRxnBquA+3/DB1LtBX6pa2cUa+LRUmE/NCPHMvJcyuNkYpJKlNTd9vnbfo0H +RNSJSWm+aGxDFMjuPjV3JLj2OdKMPwpnXdh2vBZCPpMCggEAM+yTvrEhmi2HgLIO 
+hkz/jP2rYyfdn04ArhhqPLgd0dpuI5z24+Jq/9fzZT9ZfwSW6VK1QwDLlXcXRhXH +Q8Hf6smev3CjuORURO61IkKaGWwrAucZPAY7ToNQ4cP9ImDXzMTNPgrLv3oMBYJR +V16X09nxX+9NABqnQG/QjdjzDc6Qw7+NZ9f2bvzvI5qMuY2eyW91XbtJ45ThoLfP +ymAp03gPxQwL0WT7z85kJ3OrROxzwaPvxU0JQSZbNbqNDPXmFTiECxNDhpRAAWlz +1DC5Vg2l05fkMkyPdtD6nOQWs/CYSfB5/EtxiX/xnBszhvZUIe6KFvuKFIhaJD5h +iykagQKCAQEAoBRm8k3KbTIo4ZzvyEq4V/+dF3zBRczx6FkCkYLygXBCNvsQiR2Y +BjtI8Ijz7bnQShEoOmeDriRTAqGGrspEuiVgQ1+l2wZkKHRe/aaij/Zv+4AuhH8q +uZEYvW7w5Uqbs9SbgQzhp2kjTNy6V8lVnjPLf8cQGZ+9Y9krwktC6T5m/i435WdN +38h7amNP4XEE/F86Eb3rDrZYtgLIoCF4E+iCyxMehU+AGH1uABhls9XAB6vvo+8/ +SUp8lEqWWLP0U5KNOtYWfCeOAEiIHDbUq+DYUc4BKtbtV1cx3pzlPTOWw6XBi5Lq +jttdL4HyYvnasAQpwe8GcMJqIRyCVZMiwwKCAQEAhQTTS3CC8PwcoYrpBdTjW1ck +vVFeF1YbfqPZfYxASCOtdx6wRnnEJ+bjqntagns9e88muxj9UhxSL6q9XaXQBD8+ +2AmKUxphCZQiYFZcTucjQEQEI2nN+nAKgRrUSMMGiR8Ekc2iFrcxBU0dnSohw+aB +PbMKVypQCREu9PcDFIp9rXQTeElbaNsIg1C1w/SQjODbmN/QFHTVbRODYqLeX1J/ +VcGsykSIq7hv6bjn7JGkr2JTdANbjk9LnMjMdJFsKRYxPKkOQfYred6Hiojp5Sor +PW5am8ejnNSPhIfqQp3uV3KhwPDKIeIpzvrB4uPfTjQWhekHCb8cKSWux3flqw== +-----END RSA PRIVATE KEY-----` + caCRL = `-----BEGIN X509 CRL----- +MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN +MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg +ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r +rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV ++jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 +XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE +jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF +DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD +MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 +iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE +ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 +hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 +cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 
+sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn +w7bXJGfJcIMrrKs= +-----END X509 CRL-----` + client1Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 +HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE +OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 +Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf +XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO +c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo +Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU +OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC +sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I +l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb +BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz +MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 +QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz +J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ +offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX +G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr +kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au +MU3Bo0A= +-----END CERTIFICATE-----` + client1Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki +J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi +85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW +eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 
+ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy +bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO +hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw +eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl +X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ +CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL +j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU +NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf +sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB +Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB +LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g +mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 +T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn +RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa +Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW +asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ +DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 +0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls +ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY +nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe +yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== +-----END RSA PRIVATE KEY-----` + // client 2 crt is revoked + client2Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b +Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua +eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ +cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT 
+A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 +6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f +Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz +PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY +xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D +EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB +4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO +78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 +7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e +qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N +f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv +/ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S +ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR +E0+r2+4= +-----END CERTIFICATE-----` + client2Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV +jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP +kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK +h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P +QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu +4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq +op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM +qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt +Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL +VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO +qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy +VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl +NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 +4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl 
+/owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd +aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu +khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz +3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ +eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i +vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB +GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb +Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D +8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl +RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA +ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl +-----END RSA PRIVATE KEY-----` +) + +var ( + configDir = filepath.Join(".", "..", "..") +) + +type mockFTPClientContext struct { + lastDataChannel ftpserver.DataChannel + remoteIP string + localIP string + extra any +} + +func (cc *mockFTPClientContext) Path() string { + return "" +} + +func (cc *mockFTPClientContext) SetPath(_ string) {} + +func (cc *mockFTPClientContext) SetListPath(_ string) {} + +func (cc *mockFTPClientContext) SetDebug(_ bool) {} + +func (cc *mockFTPClientContext) Debug() bool { + return false +} + +func (cc *mockFTPClientContext) ID() uint32 { + return 1 +} + +func (cc *mockFTPClientContext) RemoteAddr() net.Addr { + ip := "127.0.0.1" + if cc.remoteIP != "" { + ip = cc.remoteIP + } + return &net.IPAddr{IP: net.ParseIP(ip)} +} + +func (cc *mockFTPClientContext) LocalAddr() net.Addr { + ip := "127.0.0.1" + if cc.localIP != "" { + ip = cc.localIP + } + return &net.IPAddr{IP: net.ParseIP(ip)} +} + +func (cc *mockFTPClientContext) GetClientVersion() string { + return "mock version" +} + +func (cc *mockFTPClientContext) Close() error { + return nil +} + +func (cc *mockFTPClientContext) HasTLSForControl() bool { + return false +} + +func (cc *mockFTPClientContext) HasTLSForTransfers() bool { + return false +} + +func (cc *mockFTPClientContext) 
SetTLSRequirement(_ ftpserver.TLSRequirement) error { + return nil +} + +func (cc *mockFTPClientContext) GetLastCommand() string { + return "" +} + +func (cc *mockFTPClientContext) GetLastDataChannel() ftpserver.DataChannel { + return cc.lastDataChannel +} + +func (cc *mockFTPClientContext) SetExtra(extra any) { + cc.extra = extra +} + +func (cc *mockFTPClientContext) Extra() any { + return cc.extra +} + +// MockOsFs mockable OsFs +type MockOsFs struct { + vfs.Fs + err error + statErr error + isAtomicUploadSupported bool +} + +// Name returns the name for the Fs implementation +func (fs MockOsFs) Name() string { + return "mockOsFs" +} + +// IsUploadResumeSupported returns true if resuming uploads is supported +func (MockOsFs) IsUploadResumeSupported() bool { + return false +} + +// IsConditionalUploadResumeSupported returns if resuming uploads is supported +// for the specified size +func (MockOsFs) IsConditionalUploadResumeSupported(_ int64) bool { + return false +} + +// IsAtomicUploadSupported returns true if atomic upload is supported +func (fs MockOsFs) IsAtomicUploadSupported() bool { + return fs.isAtomicUploadSupported +} + +// Stat returns a FileInfo describing the named file +func (fs MockOsFs) Stat(name string) (os.FileInfo, error) { + if fs.statErr != nil { + return nil, fs.statErr + } + return os.Stat(name) +} + +// Lstat returns a FileInfo describing the named file +func (fs MockOsFs) Lstat(name string) (os.FileInfo, error) { + if fs.statErr != nil { + return nil, fs.statErr + } + return os.Lstat(name) +} + +// Remove removes the named file or (empty) directory. 
+func (fs MockOsFs) Remove(name string, _ bool) error { + if fs.err != nil { + return fs.err + } + return os.Remove(name) +} + +// Rename renames (moves) source to target +func (fs MockOsFs) Rename(source, target string, _ int) (int, int64, error) { + if fs.err != nil { + return -1, -1, fs.err + } + err := os.Rename(source, target) + return -1, -1, err +} + +func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs { + return &MockOsFs{ + Fs: vfs.NewOsFs(connectionID, rootDir, "", nil), + err: err, + statErr: statErr, + isAtomicUploadSupported: atomicUpload, + } +} + +func TestInitialization(t *testing.T) { + oldMgr := certMgr + certMgr = nil + + binding := Binding{ + Port: 2121, + } + c := &Configuration{ + Bindings: []Binding{binding}, + CertificateFile: "acert", + CertificateKeyFile: "akey", + } + assert.False(t, binding.HasProxy()) + assert.Equal(t, util.I18nFTPTLSDisabled, binding.GetTLSDescription()) + err := c.Initialize(configDir) + assert.Error(t, err) + c.CertificateFile = "" + c.CertificateKeyFile = "" + c.BannerFile = "afile" + server := NewServer(c, configDir, binding, 0) + assert.Equal(t, version.GetServerVersion("_", false), server.initialMsg) + _, err = server.GetTLSConfig() + assert.Error(t, err) + + binding.TLSMode = 1 + server = NewServer(c, configDir, binding, 0) + _, err = server.GetSettings() + assert.Error(t, err) + + binding.PassiveConnectionsSecurity = 100 + binding.ActiveConnectionsSecurity = 100 + server = NewServer(c, configDir, binding, 0) + _, err = server.GetSettings() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid passive_connections_security") + } + binding.PassiveConnectionsSecurity = 1 + server = NewServer(c, configDir, binding, 0) + _, err = server.GetSettings() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid active_connections_security") + } + binding = Binding{ + Port: 2121, + ForcePassiveIP: "192.168.1", + } + server = NewServer(c, configDir, 
binding, 0) + _, err = server.GetSettings() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is not valid") + } + + binding.ForcePassiveIP = "::ffff:192.168.89.9" + err = binding.checkPassiveIP() + assert.NoError(t, err) + assert.Equal(t, "192.168.89.9", binding.ForcePassiveIP) + + binding.ForcePassiveIP = "::1" + err = binding.checkPassiveIP() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is not a valid IPv4 address") + } + + err = ReloadCertificateMgr() + assert.NoError(t, err) + + binding = Binding{ + Port: 2121, + ClientAuthType: 1, + } + assert.Equal(t, util.I18nFTPTLSDisabled, binding.GetTLSDescription()) + certPath := filepath.Join(os.TempDir(), "test_ftpd.crt") + keyPath := filepath.Join(os.TempDir(), "test_ftpd.key") + binding.CertificateFile = certPath + binding.CertificateKeyFile = keyPath + keyPairs := []common.TLSKeyPair{ + { + Cert: certPath, + Key: keyPath, + ID: binding.GetAddress(), + }, + } + certMgr, err = common.NewCertManager(keyPairs, configDir, "") + require.NoError(t, err) + + assert.Equal(t, util.I18nFTPTLSMixed, binding.GetTLSDescription()) + server = NewServer(c, configDir, binding, 0) + cfg, err := server.GetTLSConfig() + require.NoError(t, err) + assert.Equal(t, tls.RequireAndVerifyClientCert, cfg.ClientAuth) + + certMgr = oldMgr +} + +func TestServerGetSettings(t *testing.T) { + oldConfig := common.Config + oldMgr := certMgr + + binding := Binding{ + Port: 2121, + ApplyProxyConfig: true, + } + c := &Configuration{ + Bindings: []Binding{binding}, + PassivePortRange: PortRange{ + Start: 10000, + End: 10000, + }, + } + assert.False(t, binding.HasProxy()) + server := NewServer(c, configDir, binding, 0) + settings, err := server.GetSettings() + assert.NoError(t, err) + if ranger, ok := settings.PassiveTransferPortRange.(*ftpserver.PortRange); ok { + assert.Equal(t, 10000, ranger.Start) + assert.Equal(t, 10000, ranger.End) + } + c.PassivePortRange.End = 11000 + settings, err = server.GetSettings() + 
assert.NoError(t, err) + if ranger, ok := settings.PassiveTransferPortRange.(*ftpserver.PortRange); ok { + assert.Equal(t, 10000, ranger.Start) + assert.Equal(t, 11000, ranger.End) + } + + common.Config.ProxyProtocol = 1 + _, err = server.GetSettings() + assert.Error(t, err) + server.binding.Port = 8021 + + assert.Equal(t, util.I18nFTPTLSDisabled, binding.GetTLSDescription()) + _, err = server.GetTLSConfig() + assert.Error(t, err) // TLS configured but cert manager has no certificate + + binding.TLSMode = 1 + assert.Equal(t, util.I18nFTPTLSExplicit, binding.GetTLSDescription()) + + binding.TLSMode = 2 + assert.Equal(t, util.I18nFTPTLSImplicit, binding.GetTLSDescription()) + + certPath := filepath.Join(os.TempDir(), "test_ftpd.crt") + keyPath := filepath.Join(os.TempDir(), "test_ftpd.key") + err = os.WriteFile(certPath, []byte(ftpsCert), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(keyPath, []byte(ftpsKey), os.ModePerm) + assert.NoError(t, err) + + keyPairs := []common.TLSKeyPair{ + { + Cert: certPath, + Key: keyPath, + ID: common.DefaultTLSKeyPaidID, + }, + } + certMgr, err = common.NewCertManager(keyPairs, configDir, "") + require.NoError(t, err) + common.Config.ProxyAllowed = nil + c.CertificateFile = certPath + c.CertificateKeyFile = keyPath + server = NewServer(c, configDir, binding, 0) + server.binding.Port = 9021 + settings, err = server.GetSettings() + assert.NoError(t, err) + assert.NotNil(t, settings.Listener) + + listener, err := net.Listen("tcp", ":0") + assert.NoError(t, err) + listener, err = server.WrapPassiveListener(listener) + assert.NoError(t, err) + + _, ok := listener.(*proxyproto.Listener) + assert.True(t, ok) + + common.Config = oldConfig + certMgr = oldMgr +} + +func TestUserInvalidParams(t *testing.T) { + u := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: "invalid", + }, + } + binding := Binding{ + Port: 2121, + } + c := &Configuration{ + Bindings: []Binding{binding}, + PassivePortRange: PortRange{ + Start: 10000, + 
End: 11000, + }, + } + server := NewServer(c, configDir, binding, 3) + _, err := server.validateUser(u, &mockFTPClientContext{}, dataprovider.LoginMethodPassword) + assert.Error(t, err) + + u.Username = "a" + u.HomeDir = filepath.Clean(os.TempDir()) + subDir := "subdir" + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir1", subDir) + vdirPath2 := "/vdir2" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + }) + _, err = server.validateUser(u, &mockFTPClientContext{}, dataprovider.LoginMethodPassword) + assert.Error(t, err) + u.VirtualFolders = nil + _, err = server.validateUser(u, &mockFTPClientContext{}, dataprovider.LoginMethodPassword) + assert.Error(t, err) +} + +func TestFTPMode(t *testing.T) { + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolFTP, "", "", dataprovider.User{}), + } + assert.Empty(t, connection.getFTPMode()) + connection.clientContext = &mockFTPClientContext{lastDataChannel: ftpserver.DataChannelActive} + assert.Equal(t, "active", connection.getFTPMode()) + connection.clientContext = &mockFTPClientContext{lastDataChannel: ftpserver.DataChannelPassive} + assert.Equal(t, "passive", connection.getFTPMode()) + connection.clientContext = &mockFTPClientContext{lastDataChannel: 0} + assert.Empty(t, connection.getFTPMode()) +} + +func TestClientVersion(t *testing.T) { + mockCC := &mockFTPClientContext{} + connID := fmt.Sprintf("2_%v", mockCC.ID()) + user := dataprovider.User{} + connection := &Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user), + clientContext: mockCC, + } + err := 
common.Connections.Add(connection) + assert.NoError(t, err) + stats := common.Connections.GetStats("") + if assert.Len(t, stats, 1) { + assert.Equal(t, "mock version", stats[0].ClientVersion) + common.Connections.Remove(connection.GetID()) + } + assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestDriverMethodsNotImplemented(t *testing.T) { + mockCC := &mockFTPClientContext{} + connID := fmt.Sprintf("2_%v", mockCC.ID()) + user := dataprovider.User{} + connection := &Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user), + clientContext: mockCC, + } + _, err := connection.Create("") + assert.EqualError(t, err, errNotImplemented.Error()) + err = connection.MkdirAll("", os.ModePerm) + assert.EqualError(t, err, errNotImplemented.Error()) + _, err = connection.Open("") + assert.EqualError(t, err, errNotImplemented.Error()) + _, err = connection.OpenFile("", 0, os.ModePerm) + assert.EqualError(t, err, errNotImplemented.Error()) + err = connection.RemoveAll("") + assert.EqualError(t, err, errNotImplemented.Error()) + assert.Equal(t, connection.GetID(), connection.Name()) +} + +func TestExtraData(t *testing.T) { + mockCC := mockFTPClientContext{} + _, ok := mockCC.Extra().(*tlsState) + require.False(t, ok) + mockCC.SetExtra(&tlsState{ + LoginWithMutualTLS: false, + Version: tls.VersionName(tls.VersionTLS13), + Cipher: tls.CipherSuiteName(tls.TLS_AES_128_GCM_SHA256), + KEX: tls.X25519MLKEM768.String(), + }) + state, ok := mockCC.Extra().(*tlsState) + require.True(t, ok) + require.False(t, state.LoginWithMutualTLS) + require.Equal(t, tls.VersionName(tls.VersionTLS13), state.Version) + require.Equal(t, tls.CipherSuiteName(tls.TLS_AES_128_GCM_SHA256), state.Cipher) + require.Equal(t, tls.X25519MLKEM768.String(), state.KEX) + mockCC.SetExtra(&tlsState{ + LoginWithMutualTLS: true, + }) + state, ok = mockCC.Extra().(*tlsState) + require.True(t, ok) + 
require.True(t, state.LoginWithMutualTLS) +} + +func TestResolvePathErrors(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: "invalid", + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + mockCC := &mockFTPClientContext{} + connID := fmt.Sprintf("%v", mockCC.ID()) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user), + clientContext: mockCC, + } + err := connection.Mkdir("", os.ModePerm) + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + err = connection.Remove("") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + err = connection.RemoveDir("") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + err = connection.Rename("", "") + assert.ErrorIs(t, err, common.ErrOpUnsupported) + err = connection.Symlink("", "") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + _, err = connection.Stat("") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + err = connection.Chmod("", os.ModePerm) + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + err = connection.Chtimes("", time.Now(), time.Now()) + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + _, err = connection.ReadDir("") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + _, err = connection.GetHandle("", 0, 0) + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + _, err = connection.GetAvailableSpace("") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } +} + +func TestUploadFileStatError(t *testing.T) { + if runtime.GOOS == "windows" 
{ + t.Skip("this test is not available on Windows") + } + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "user", + HomeDir: filepath.Clean(os.TempDir()), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + mockCC := &mockFTPClientContext{} + connID := fmt.Sprintf("%v", mockCC.ID()) + fs := vfs.NewOsFs(connID, user.HomeDir, "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user), + clientContext: mockCC, + } + testFile := filepath.Join(user.HomeDir, "test", "testfile") + err := os.MkdirAll(filepath.Dir(testFile), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testFile, []byte("data"), os.ModePerm) + assert.NoError(t, err) + err = os.Chmod(filepath.Dir(testFile), 0001) + assert.NoError(t, err) + _, err = connection.uploadFile(fs, testFile, "test", 0) + assert.Error(t, err) + err = os.Chmod(filepath.Dir(testFile), os.ModePerm) + assert.NoError(t, err) + err = os.RemoveAll(filepath.Dir(testFile)) + assert.NoError(t, err) +} + +func TestAVBLErrors(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "user", + HomeDir: filepath.Clean(os.TempDir()), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + mockCC := &mockFTPClientContext{} + connID := fmt.Sprintf("%v", mockCC.ID()) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user), + clientContext: mockCC, + } + _, err := connection.GetAvailableSpace("/") + assert.NoError(t, err) + _, err = connection.GetAvailableSpace("/missing-path") + assert.Error(t, err) + assert.True(t, errors.Is(err, fs.ErrNotExist)) +} + +func TestUploadOverwriteErrors(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "user", + HomeDir: filepath.Clean(os.TempDir()), + }, + } + user.Permissions = 
make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + mockCC := &mockFTPClientContext{} + connID := fmt.Sprintf("%v", mockCC.ID()) + fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir()) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user), + clientContext: mockCC, + } + flags := 0 + flags |= os.O_APPEND + _, err := connection.handleFTPUploadToExistingFile(fs, flags, "", "", 0, "") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrOpUnsupported.Error()) + } + + f, err := os.CreateTemp("", "temp") + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + flags = 0 + flags |= os.O_CREATE + flags |= os.O_TRUNC + tr, err := connection.handleFTPUploadToExistingFile(fs, flags, f.Name(), f.Name(), 123, f.Name()) + if assert.NoError(t, err) { + transfer := tr.(*transfer) + transfers := connection.GetTransfers() + if assert.Equal(t, 1, len(transfers)) { + assert.Equal(t, transfers[0].ID, transfer.GetID()) + assert.Equal(t, int64(123), transfer.InitialSize) + err = transfer.Close() + assert.NoError(t, err) + assert.Equal(t, 0, len(connection.GetTransfers())) + } + } + err = os.Remove(f.Name()) + assert.NoError(t, err) + + _, err = connection.handleFTPUploadToExistingFile(fs, os.O_TRUNC, filepath.Join(os.TempDir(), "sub", "file"), + filepath.Join(os.TempDir(), "sub", "file1"), 0, "/sub/file1") + assert.Error(t, err) + fs = vfs.NewOsFs(connID, user.GetHomeDir(), "", nil) + _, err = connection.handleFTPUploadToExistingFile(fs, 0, "missing1", "missing2", 0, "missing") + assert.Error(t, err) +} + +func TestTransferErrors(t *testing.T) { + testfile := "testfile" + file, err := os.Create(testfile) + assert.NoError(t, err) + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "user", + HomeDir: filepath.Clean(os.TempDir()), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + mockCC 
:= &mockFTPClientContext{} + connID := fmt.Sprintf("%v", mockCC.ID()) + fs := newMockOsFs(nil, nil, false, connID, user.GetHomeDir()) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, "", "", user), + } + baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), file.Name(), testfile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + tr := newTransfer(baseTransfer, nil, nil, 0) + err = tr.Close() + assert.NoError(t, err) + _, err = tr.Seek(10, 0) + assert.Error(t, err) + buf := make([]byte, 64) + _, err = tr.Read(buf) + assert.Error(t, err) + err = tr.Close() + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrTransferClosed.Error()) + } + assert.Len(t, connection.GetTransfers(), 0) + + r, _, err := pipeat.Pipe() + assert.NoError(t, err) + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile, + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + tr = newTransfer(baseTransfer, nil, vfs.NewPipeReader(r), 10) + pos, err := tr.Seek(10, 0) + assert.NoError(t, err) + assert.Equal(t, pos, tr.expectedOffset) + err = tr.closeIO() + assert.NoError(t, err) + + r, w, err := pipeat.Pipe() + assert.NoError(t, err) + pipeWriter := vfs.NewPipeWriter(w) + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, testfile, + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + tr.Connection.RemoveTransfer(tr) + tr = newTransfer(baseTransfer, pipeWriter, nil, 0) + + err = r.Close() + assert.NoError(t, err) + errFake := fmt.Errorf("fake upload error") + go func() { + time.Sleep(100 * time.Millisecond) + pipeWriter.Done(errFake) + }() + err = tr.closeIO() + assert.EqualError(t, err, errFake.Error()) + _, err = tr.Seek(1, 0) + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrOpUnsupported.Error()) + } + 
tr.Connection.RemoveTransfer(tr) + err = os.Remove(testfile) + assert.NoError(t, err) +} + +func TestVerifyTLSConnection(t *testing.T) { + oldCertMgr := certMgr + + caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt") + certPath := filepath.Join(os.TempDir(), "test.crt") + keyPath := filepath.Join(os.TempDir(), "test.key") + err := os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(certPath, []byte(ftpsCert), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(keyPath, []byte(ftpsKey), os.ModePerm) + assert.NoError(t, err) + keyPairs := []common.TLSKeyPair{ + { + Cert: certPath, + Key: keyPath, + ID: common.DefaultTLSKeyPaidID, + }, + } + certMgr, err = common.NewCertManager(keyPairs, "", "ftp_test") + assert.NoError(t, err) + + certMgr.SetCARevocationLists([]string{caCrlPath}) + err = certMgr.LoadCRLs() + assert.NoError(t, err) + + crt, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + x509crt, err := x509.ParseCertificate(crt.Certificate[0]) + assert.NoError(t, err) + + server := Server{} + state := tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{x509crt}, + } + + err = server.verifyTLSConnection(state) + assert.Error(t, err) // no verified certification chain + err = server.VerifyTLSConnectionState(nil, state) + assert.NoError(t, err) + server.binding.ClientAuthType = 1 + err = server.VerifyTLSConnectionState(nil, state) + assert.Error(t, err) + + crt, err = tls.X509KeyPair([]byte(caCRT), []byte(caKey)) + assert.NoError(t, err) + + x509CAcrt, err := x509.ParseCertificate(crt.Certificate[0]) + assert.NoError(t, err) + + state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crt, x509CAcrt}) + err = server.verifyTLSConnection(state) + assert.NoError(t, err) + + crt, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) + assert.NoError(t, err) + x509crtRevoked, err := x509.ParseCertificate(crt.Certificate[0]) + 
assert.NoError(t, err) + + state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crtRevoked, x509CAcrt}) + state.PeerCertificates = []*x509.Certificate{x509crtRevoked} + err = server.verifyTLSConnection(state) + assert.EqualError(t, err, common.ErrCrtRevoked.Error()) + + err = os.Remove(caCrlPath) + assert.NoError(t, err) + err = os.Remove(certPath) + assert.NoError(t, err) + err = os.Remove(keyPath) + assert.NoError(t, err) + + certMgr = oldCertMgr +} + +func TestCiphers(t *testing.T) { + b := Binding{ + TLSCipherSuites: []string{}, + } + b.setCiphers() + require.Equal(t, util.GetTLSCiphersFromNames(nil), b.ciphers) + b.TLSCipherSuites = []string{"TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384"} + b.setCiphers() + require.Len(t, b.ciphers, 2) + require.Equal(t, []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384}, b.ciphers) +} + +func TestPassiveIPResolver(t *testing.T) { + b := Binding{ + PassiveIPOverrides: []PassiveIPOverride{ + {}, + }, + } + err := b.checkPassiveIP() + assert.Error(t, err) + assert.Contains(t, err.Error(), "passive IP networks override cannot be empty") + b = Binding{ + PassiveIPOverrides: []PassiveIPOverride{ + { + IP: "invalid ip", + }, + }, + } + err = b.checkPassiveIP() + assert.Error(t, err) + assert.Contains(t, err.Error(), "is not valid") + + b = Binding{ + PassiveIPOverrides: []PassiveIPOverride{ + { + IP: "192.168.1.1", + Networks: []string{"192.168.1.0/24", "invalid cidr"}, + }, + }, + } + err = b.checkPassiveIP() + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid passive IP networks override") + b = Binding{ + ForcePassiveIP: "192.168.2.1", + PassiveIPOverrides: []PassiveIPOverride{ + { + IP: "::ffff:192.168.1.1", + Networks: []string{"192.168.1.0/24"}, + }, + }, + } + err = b.checkPassiveIP() + assert.NoError(t, err) + assert.NotEmpty(t, b.PassiveIPOverrides[0].GetNetworksAsString()) + assert.Equal(t, "192.168.1.1", b.PassiveIPOverrides[0].IP) + require.Len(t, 
b.PassiveIPOverrides[0].parsedNetworks, 1) + ip := net.ParseIP("192.168.1.2") + assert.True(t, b.PassiveIPOverrides[0].parsedNetworks[0](ip)) + ip = net.ParseIP("192.168.0.2") + assert.False(t, b.PassiveIPOverrides[0].parsedNetworks[0](ip)) + + mockCC := &mockFTPClientContext{ + remoteIP: "192.168.1.10", + localIP: "192.168.1.3", + } + passiveIP, err := b.passiveIPResolver(mockCC) + assert.NoError(t, err) + assert.Equal(t, "192.168.1.1", passiveIP) + b.PassiveIPOverrides[0].IP = "" + passiveIP, err = b.passiveIPResolver(mockCC) + assert.NoError(t, err) + assert.Equal(t, "192.168.1.3", passiveIP) + mockCC.remoteIP = "172.16.2.3" + passiveIP, err = b.passiveIPResolver(mockCC) + assert.NoError(t, err) + assert.Equal(t, b.ForcePassiveIP, passiveIP) +} + +func TestRelativePath(t *testing.T) { + rel := getPathRelativeTo("/testpath", "/testpath") + assert.Empty(t, rel) + rel = getPathRelativeTo("/", "/") + assert.Empty(t, rel) + rel = getPathRelativeTo("/", "/dir/sub") + assert.Equal(t, "dir/sub", rel) + rel = getPathRelativeTo("./", "/dir/sub") + assert.Equal(t, "/dir/sub", rel) + rel = getPathRelativeTo("/sub", "/dir/sub") + assert.Equal(t, "../dir/sub", rel) + rel = getPathRelativeTo("/dir", "/dir/sub") + assert.Equal(t, "sub", rel) + rel = getPathRelativeTo("/dir/sub", "/dir") + assert.Equal(t, "../", rel) + rel = getPathRelativeTo("dir", "/dir1") + assert.Equal(t, "/dir1", rel) + rel = getPathRelativeTo("", "/dir2") + assert.Equal(t, "dir2", rel) + rel = getPathRelativeTo(".", "/dir2") + assert.Equal(t, "/dir2", rel) + rel = getPathRelativeTo("/dir3", "dir3") + assert.Equal(t, "dir3", rel) +} + +func TestConfigsFromProvider(t *testing.T) { + err := dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) + c := Configuration{} + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + configs := dataprovider.Configs{ + ACME: &dataprovider.ACMEConfigs{ + Domain: "domain.com", + Email: "info@domain.com", + HTTP01Challenge: 
dataprovider.ACMEHTTP01Challenge{Port: 80}, + Protocols: 2, + }, + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + util.CertsBasePath = "" + // crt and key empty + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + util.CertsBasePath = filepath.Clean(os.TempDir()) + // crt not found + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + keyPairs := c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 0) + crtPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".crt") + err = os.WriteFile(crtPath, nil, 0666) + assert.NoError(t, err) + // key not found + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + keyPairs = c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 0) + keyPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".key") + err = os.WriteFile(keyPath, nil, 0666) + assert.NoError(t, err) + // acme cert used + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Equal(t, configs.ACME.Domain, c.acmeDomain) + keyPairs = c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 1) + // protocols does not match + configs.ACME.Protocols = 5 + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + c.acmeDomain = "" + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + keyPairs = c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 0) + + err = os.Remove(crtPath) + assert.NoError(t, err) + err = os.Remove(keyPath) + assert.NoError(t, err) + util.CertsBasePath = "" + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) +} + +func TestPassiveHost(t *testing.T) { + b := Binding{ + PassiveHost: "invalid hostname", + } + _, err := b.getPassiveIP(nil) + assert.Error(t, err) + b.PassiveHost = "localhost" + ip, err := b.getPassiveIP(nil) + assert.NoError(t, err, ip) + assert.Equal(t, "127.0.0.1", 
ip) +} diff --git a/internal/ftpd/server.go b/internal/ftpd/server.go new file mode 100644 index 00000000..2e5790ae --- /dev/null +++ b/internal/ftpd/server.go @@ -0,0 +1,460 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package ftpd + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "slices" + + ftpserver "github.com/fclairamb/ftpserverlib" + "github.com/sftpgo/sdk/plugin/notifier" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/metric" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" +) + +// tlsState tracks TLS connection state for a client +type tlsState struct { + // LoginWithMutualTLS indicates whether the user logged in using TLS certificate authentication + LoginWithMutualTLS bool + Version string + Cipher string + KEX string +} + +// Server implements the ftpserverlib MainDriver interface +type Server struct { + ID int + config *Configuration + initialMsg string + statusBanner string + binding Binding + tlsConfig *tls.Config +} + +// NewServer returns a new FTP server driver +func NewServer(config *Configuration, configDir string, binding Binding, id int) *Server { + binding.setCiphers() + vers := 
version.GetServerVersion("_", false) + server := &Server{ + config: config, + initialMsg: vers, + statusBanner: fmt.Sprintf("%s FTP Server", vers), + binding: binding, + ID: id, + } + if config.BannerFile != "" { + bannerFilePath := config.BannerFile + if !filepath.IsAbs(bannerFilePath) { + bannerFilePath = filepath.Join(configDir, bannerFilePath) + } + bannerContent, err := os.ReadFile(bannerFilePath) + if err == nil { + server.initialMsg = util.BytesToString(bannerContent) + } else { + logger.WarnToConsole("unable to read FTPD banner file: %v", err) + logger.Warn(logSender, "", "unable to read banner file: %v", err) + } + } + server.buildTLSConfig() + return server +} + +// GetSettings returns FTP server settings +func (s *Server) GetSettings() (*ftpserver.Settings, error) { + if err := s.binding.checkPassiveIP(); err != nil { + return nil, err + } + if err := s.binding.checkSecuritySettings(); err != nil { + return nil, err + } + var portRange *ftpserver.PortRange + if s.config.PassivePortRange.Start > 0 && s.config.PassivePortRange.End >= s.config.PassivePortRange.Start { + portRange = &ftpserver.PortRange{ + Start: s.config.PassivePortRange.Start, + End: s.config.PassivePortRange.End, + } + } + var ftpListener net.Listener + if s.binding.HasProxy() { + listener, err := net.Listen("tcp", s.binding.GetAddress()) + if err != nil { + logger.Warn(logSender, "", "error starting listener on address %v: %v", s.binding.GetAddress(), err) + return nil, err + } + ftpListener, err = common.Config.GetProxyListener(listener) + if err != nil { + logger.Warn(logSender, "", "error enabling proxy listener: %v", err) + return nil, err + } + if s.binding.TLSMode == 2 && s.tlsConfig != nil { + ftpListener = tls.NewListener(ftpListener, s.tlsConfig) + } + } + + if !s.binding.isTLSModeValid() { + return nil, fmt.Errorf("unsupported TLS mode: %d", s.binding.TLSMode) + } + + if s.binding.TLSMode > 0 && certMgr == nil { + return nil, errors.New("to enable TLS you need to provide a 
certificate") + } + + settings := &ftpserver.Settings{ + Listener: ftpListener, + ListenAddr: s.binding.GetAddress(), + PublicIPResolver: s.binding.passiveIPResolver, + ActiveTransferPortNon20: s.config.ActiveTransfersPortNon20, + IdleTimeout: -1, + ConnectionTimeout: 20, + Banner: s.statusBanner, + TLSRequired: ftpserver.TLSRequirement(s.binding.TLSMode), + DisableSite: !s.config.EnableSite, + DisableActiveMode: s.config.DisableActiveMode, + EnableHASH: s.config.HASHSupport > 0, + EnableCOMB: s.config.CombineSupport > 0, + DefaultTransferType: ftpserver.TransferTypeBinary, + ActiveConnectionsCheck: ftpserver.DataConnectionRequirement(s.binding.ActiveConnectionsSecurity), + PasvConnectionsCheck: ftpserver.DataConnectionRequirement(s.binding.PassiveConnectionsSecurity), + } + if portRange != nil { + settings.PassiveTransferPortRange = portRange + } + return settings, nil +} + +// ClientConnected is called to send the very first welcome message +func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) { + cc.SetDebug(s.binding.Debug) + ipAddr := util.GetIPFromRemoteAddress(cc.RemoteAddr().String()) + common.Connections.AddClientConnection(ipAddr) + if common.IsBanned(ipAddr, common.ProtocolFTP) { + logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %q is banned", ipAddr) + return "Access denied: banned client IP", common.ErrConnectionDenied + } + if err := common.Connections.IsNewConnectionAllowed(ipAddr, common.ProtocolFTP); err != nil { + logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection not allowed from ip %q: %v", ipAddr, err) + return "Access denied", err + } + _, err := common.LimitRate(common.ProtocolFTP, ipAddr) + if err != nil { + return fmt.Sprintf("Access denied: %v", err.Error()), err + } + if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolFTP); err != nil { + return "Access denied", err + } + connID := fmt.Sprintf("%v_%v", s.ID, cc.ID()) + user := dataprovider.User{} + 
// ClientConnected is called to send the very first welcome message.
// It enforces bans, per-IP connection limits, rate limiting and the
// post-connect hook; a non-nil error makes ftpserverlib reject the client
// after sending the returned message.
func (s *Server) ClientConnected(cc ftpserver.ClientContext) (string, error) {
	cc.SetDebug(s.binding.Debug)
	ipAddr := util.GetIPFromRemoteAddress(cc.RemoteAddr().String())
	common.Connections.AddClientConnection(ipAddr)
	if common.IsBanned(ipAddr, common.ProtocolFTP) {
		logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection refused, ip %q is banned", ipAddr)
		return "Access denied: banned client IP", common.ErrConnectionDenied
	}
	if err := common.Connections.IsNewConnectionAllowed(ipAddr, common.ProtocolFTP); err != nil {
		logger.Log(logger.LevelDebug, common.ProtocolFTP, "", "connection not allowed from ip %q: %v", ipAddr, err)
		return "Access denied", err
	}
	_, err := common.LimitRate(common.ProtocolFTP, ipAddr)
	if err != nil {
		return fmt.Sprintf("Access denied: %v", err.Error()), err
	}
	if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolFTP); err != nil {
		return "Access denied", err
	}
	connID := fmt.Sprintf("%v_%v", s.ID, cc.ID())
	// register a placeholder connection with an empty user; it is swapped
	// with a fully initialized one after a successful login (validateUser)
	user := dataprovider.User{}
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(connID, common.ProtocolFTP, cc.LocalAddr().String(),
			cc.RemoteAddr().String(), user),
		clientContext: cc,
	}
	err = common.Connections.Add(connection)
	return s.initialMsg, err
}

// ClientDisconnected is called when the user disconnects, even if he never authenticated
func (s *Server) ClientDisconnected(cc ftpserver.ClientContext) {
	connID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID())
	common.Connections.Remove(connID)
	common.Connections.RemoveClientConnection(util.GetIPFromRemoteAddress(cc.RemoteAddr().String()))
}

// AuthUser authenticates the user and selects an handling driver
func (s *Server) AuthUser(cc ftpserver.ClientContext, username, password string) (ftpserver.ClientDriver, error) {
	loginMethod := dataprovider.LoginMethodPassword
	// a previous successful mutual-TLS handshake (see VerifyConnection)
	// upgrades the login method to TLS-certificate-plus-password
	tlsState, ok := cc.Extra().(*tlsState)
	if ok && tlsState != nil && tlsState.LoginWithMutualTLS {
		loginMethod = dataprovider.LoginMethodTLSCertificateAndPwd
	}
	ipAddr := util.GetIPFromRemoteAddress(cc.RemoteAddr().String())
	user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, common.ProtocolFTP)
	if err != nil {
		user.Username = username
		updateLoginMetrics(&user, ipAddr, loginMethod, err, nil)
		// return a generic error so the real failure reason is not leaked to the client
		return nil, dataprovider.ErrInvalidCredentials
	}

	connection, err := s.validateUser(user, cc, loginMethod)

	// deferred call arguments are evaluated HERE, so the metrics see the
	// err/connection values produced by validateUser above; when err is nil
	// connection is guaranteed non-nil
	defer updateLoginMetrics(&user, ipAddr, loginMethod, err, connection)

	if err != nil {
		return nil, err
	}
	setStartDirectory(user.Filters.StartDirectory, cc)
	dataprovider.UpdateLastLogin(&user)
	return connection, nil
}
// PreAuthUser implements the MainDriverExtensionUserVerifier interface.
// When TLS is optional for the binding it can still be forced per-user
// (FTPSecurity == 1) before credentials are sent.
func (s *Server) PreAuthUser(cc ftpserver.ClientContext, username string) error {
	if s.binding.TLSMode == 0 && s.tlsConfig != nil {
		user, err := dataprovider.GetFTPPreAuthUser(username, util.GetIPFromRemoteAddress(cc.RemoteAddr().String()))
		if err == nil {
			if user.Filters.FTPSecurity == 1 {
				return cc.SetTLSRequirement(ftpserver.MandatoryEncryption)
			}
			return nil
		}
		// an unknown user is not an error here: authentication will fail later
		if !errors.Is(err, util.ErrNotFound) {
			logger.Error(logSender, fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID()),
				"unable to get user on pre auth: %v", err)
			return common.ErrInternalFailure
		}
	}
	return nil
}

// WrapPassiveListener implements the MainDriverExtensionPassiveWrapper interface
func (s *Server) WrapPassiveListener(listener net.Listener) (net.Listener, error) {
	if s.binding.HasProxy() {
		return common.Config.GetProxyListener(listener)
	}
	return listener, nil
}

// VerifyConnection checks whether a user should be authenticated using a
// client certificate without prompting for a password.
// Returning (nil, nil) means "no decision": the normal password flow continues.
func (s *Server) VerifyConnection(cc ftpserver.ClientContext, user string, tlsConn *tls.Conn) (ftpserver.ClientDriver, error) {
	if tlsConn == nil {
		return nil, nil
	}
	state := tlsConn.ConnectionState()
	// record the negotiated TLS parameters so login logging can report them
	cc.SetExtra(&tlsState{
		LoginWithMutualTLS: false,
		Cipher:             tls.CipherSuiteName(state.CipherSuite),
		Version:            tls.VersionName(state.Version),
		KEX:                state.CurveID.String(),
	})
	if !s.binding.isMutualTLSEnabled() {
		return nil, nil
	}

	if len(state.PeerCertificates) > 0 {
		ipAddr := util.GetIPFromRemoteAddress(cc.RemoteAddr().String())
		dbUser, err := dataprovider.CheckUserBeforeTLSAuth(user, ipAddr, common.ProtocolFTP, state.PeerCertificates[0])
		if err != nil {
			dbUser.Username = user
			updateLoginMetrics(&dbUser, ipAddr, dataprovider.LoginMethodTLSCertificate, err, nil)
			return nil, dataprovider.ErrInvalidCredentials
		}
		if dbUser.IsTLSVerificationEnabled() {
			dbUser, err = dataprovider.CheckUserAndTLSCert(user, ipAddr, common.ProtocolFTP, state.PeerCertificates[0])
			if err != nil {
				return nil, err
			}

			// the certificate matched: mark the session so AuthUser can
			// upgrade the login method if a password is also provided
			cc.SetExtra(&tlsState{
				LoginWithMutualTLS: true,
				Cipher:             tls.CipherSuiteName(state.CipherSuite),
				Version:            tls.VersionName(state.Version),
				KEX:                state.CurveID.String(),
			})

			if dbUser.IsLoginMethodAllowed(dataprovider.LoginMethodTLSCertificate, common.ProtocolFTP) {
				connection, err := s.validateUser(dbUser, cc, dataprovider.LoginMethodTLSCertificate)

				// defer args are evaluated now: metrics reflect validateUser's outcome
				defer updateLoginMetrics(&dbUser, ipAddr, dataprovider.LoginMethodTLSCertificate, err, connection)

				if err != nil {
					return nil, err
				}
				setStartDirectory(dbUser.Filters.StartDirectory, cc)
				dataprovider.UpdateLastLogin(&dbUser)
				return connection, nil
			}
		}
	}

	return nil, nil
}
dbUser.IsLoginMethodAllowed(dataprovider.LoginMethodTLSCertificate, common.ProtocolFTP) { + connection, err := s.validateUser(dbUser, cc, dataprovider.LoginMethodTLSCertificate) + + defer updateLoginMetrics(&dbUser, ipAddr, dataprovider.LoginMethodTLSCertificate, err, connection) + + if err != nil { + return nil, err + } + setStartDirectory(dbUser.Filters.StartDirectory, cc) + dataprovider.UpdateLastLogin(&dbUser) + return connection, nil + } + } + } + + return nil, nil +} + +func (s *Server) buildTLSConfig() { + if certMgr != nil { + certID := common.DefaultTLSKeyPaidID + if getConfigPath(s.binding.CertificateFile, "") != "" && getConfigPath(s.binding.CertificateKeyFile, "") != "" { + certID = s.binding.GetAddress() + } + if !certMgr.HasCertificate(certID) { + return + } + s.tlsConfig = &tls.Config{ + GetCertificate: certMgr.GetCertificateFunc(certID), + MinVersion: util.GetTLSVersion(s.binding.MinTLSVersion), + CipherSuites: s.binding.ciphers, + } + logger.Debug(logSender, "", "configured TLS cipher suites for binding %q: %v, certID: %v", + s.binding.GetAddress(), s.binding.ciphers, certID) + if s.binding.isMutualTLSEnabled() { + s.tlsConfig.ClientCAs = certMgr.GetRootCAs() + s.tlsConfig.VerifyConnection = s.verifyTLSConnection + switch s.binding.ClientAuthType { + case 1: + s.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + case 2: + s.tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven + } + } + } +} + +// GetTLSConfig returns the TLS configuration for this server +func (s *Server) GetTLSConfig() (*tls.Config, error) { + if s.tlsConfig != nil { + return s.tlsConfig, nil + } + return nil, errors.New("no TLS certificate configured") +} + +// VerifyTLSConnectionState implements the MainDriverExtensionTLSConnectionStateVerifier extension +func (s *Server) VerifyTLSConnectionState(_ ftpserver.ClientContext, cs tls.ConnectionState) error { + if !s.binding.isMutualTLSEnabled() { + return nil + } + return s.verifyTLSConnection(cs) +} + +func (s *Server) 
verifyTLSConnection(state tls.ConnectionState) error { + if certMgr != nil { + var clientCrt *x509.Certificate + var clientCrtName string + if len(state.PeerCertificates) > 0 { + clientCrt = state.PeerCertificates[0] + clientCrtName = clientCrt.Subject.String() + } + if len(state.VerifiedChains) == 0 { + if s.binding.ClientAuthType == 2 { + return nil + } + logger.Warn(logSender, "", "TLS connection cannot be verified: unable to get verification chain") + return errors.New("TLS connection cannot be verified: unable to get verification chain") + } + for _, verifiedChain := range state.VerifiedChains { + var caCrt *x509.Certificate + if len(verifiedChain) > 0 { + caCrt = verifiedChain[len(verifiedChain)-1] + } + if certMgr.IsRevoked(clientCrt, caCrt) { + logger.Debug(logSender, "", "tls handshake error, client certificate %q has beed revoked", clientCrtName) + return common.ErrCrtRevoked + } + } + } + + return nil +} + +func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext, loginMethod string) (*Connection, error) { + connectionID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID()) + if !filepath.IsAbs(user.HomeDir) { + logger.Warn(logSender, connectionID, "user %q has an invalid home dir: %q. 
// validateUser applies all login policies (home dir, allowed protocols and
// login methods, second factor, session limits, source address) and, on
// success, swaps the placeholder connection registered by ClientConnected
// with a fully initialized one. It returns the new connection or an error
// describing the first failed check.
func (s *Server) validateUser(user dataprovider.User, cc ftpserver.ClientContext, loginMethod string) (*Connection, error) {
	connectionID := fmt.Sprintf("%v_%v_%v", common.ProtocolFTP, s.ID, cc.ID())
	if !filepath.IsAbs(user.HomeDir) {
		logger.Warn(logSender, connectionID, "user %q has an invalid home dir: %q. Home dir must be an absolute path, login not allowed",
			user.Username, user.HomeDir)
		return nil, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir)
	}
	if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolFTP) {
		logger.Info(logSender, connectionID, "cannot login user %q, protocol FTP is not allowed", user.Username)
		return nil, fmt.Errorf("protocol FTP is not allowed for user %q", user.Username)
	}
	if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolFTP) {
		logger.Info(logSender, connectionID, "cannot login user %q, %v login method is not allowed",
			user.Username, loginMethod)
		return nil, fmt.Errorf("login method %v is not allowed for user %q", loginMethod, user.Username)
	}
	if user.MustSetSecondFactorForProtocol(common.ProtocolFTP) {
		logger.Info(logSender, connectionID, "cannot login user %q, second factor authentication is not set",
			user.Username)
		return nil, fmt.Errorf("second factor authentication is not set for user %q", user.Username)
	}
	if user.MaxSessions > 0 {
		activeSessions := common.Connections.GetActiveSessions(user.Username)
		if activeSessions >= user.MaxSessions {
			logger.Info(logSender, connectionID, "authentication refused for user: %q, too many open sessions: %v/%v",
				user.Username, activeSessions, user.MaxSessions)
			return nil, fmt.Errorf("too many open sessions: %v", activeSessions)
		}
	}
	remoteAddr := cc.RemoteAddr().String()
	if !user.IsLoginFromAddrAllowed(remoteAddr) {
		logger.Info(logSender, connectionID, "cannot login user %q, remote address is not allowed: %v",
			user.Username, remoteAddr)
		return nil, fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, remoteAddr)
	}
	err := user.CheckFsRoot(connectionID)
	if err != nil {
		// release filesystem resources acquired by CheckFsRoot before failing
		errClose := user.CloseFs()
		logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose)
		return nil, common.ErrInternalFailure
	}
	connection := &Connection{
		BaseConnection: common.NewBaseConnection(fmt.Sprintf("%v_%v", s.ID, cc.ID()), common.ProtocolFTP,
			cc.LocalAddr().String(), remoteAddr, user),
		clientContext: cc,
	}
	// replace the anonymous connection added in ClientConnected with the
	// authenticated one, keeping the same connection ID
	err = common.Connections.Swap(connection)
	if err != nil {
		errClose := user.CloseFs()
		logger.Warn(logSender, connectionID, "unable to swap connection: %v, close fs error: %v", err, errClose)
		return nil, err
	}
	return connection, nil
}
// setStartDirectory sets the initial working directory for the session,
// if the user has one configured.
func setStartDirectory(startDirectory string, cc ftpserver.ClientContext) {
	if startDirectory == "" {
		return
	}
	cc.SetPath(startDirectory)
}

// updateLoginMetrics records a login attempt/result in metrics and logs,
// notifies plugins, feeds the defender and runs the post-login hook.
//
// NOTE(review): c is dereferenced only in the err == nil branch; all
// current callers pass a non-nil connection on success and nil only
// together with a non-nil err — keep that invariant when adding callers.
func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error, c *Connection) {
	metric.AddLoginAttempt(loginMethod)
	if err == nil {
		info := ""
		if tlsState, ok := c.clientContext.Extra().(*tlsState); ok && tlsState != nil {
			info = fmt.Sprintf("%s - %s - %s", tlsState.Version, tlsState.Cipher, tlsState.KEX)
		}
		logger.LoginLog(user.Username, ip, loginMethod, common.ProtocolFTP, c.ID, c.GetClientVersion(),
			c.clientContext.HasTLSForControl(), info)
		plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, common.ProtocolFTP, user.Username, ip, "", nil)
		common.DelayLogin(nil)
	} else if err != common.ErrInternalFailure {
		// internal failures are not counted as authentication failures
		logger.ConnectionFailedLog(user.Username, ip, loginMethod, common.ProtocolFTP, err.Error())
		event := common.HostEventLoginFailed
		logEv := notifier.LogEventTypeLoginFailed
		if errors.Is(err, util.ErrNotFound) {
			event = common.HostEventUserNotFound
			logEv = notifier.LogEventTypeLoginNoUser
		}
		common.AddDefenderEvent(ip, common.ProtocolFTP, event)
		plugin.Handler.NotifyLogEvent(logEv, common.ProtocolFTP, user.Username, ip, "", err)
		// TLS-certificate failures happen during the handshake, delaying
		// there would stall the handshake itself
		if loginMethod != dataprovider.LoginMethodTLSCertificate {
			common.DelayLogin(err)
		}
	}
	metric.AddLoginResult(loginMethod, err)
	dataprovider.ExecutePostLoginHook(user, loginMethod, ip, common.ProtocolFTP, err)
}
index 00000000..071f8b1e --- /dev/null +++ b/internal/ftpd/transfer.go @@ -0,0 +1,153 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package ftpd + +import ( + "errors" + "io" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +// transfer contains the transfer details for an upload or a download. +// It implements the ftpserver.FileTransfer interface to handle files downloads and uploads +type transfer struct { + *common.BaseTransfer + writer io.WriteCloser + reader io.ReadCloser + isFinished bool + expectedOffset int64 +} + +func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader, + expectedOffset int64) *transfer { + var writer io.WriteCloser + var reader io.ReadCloser + if baseTransfer.File != nil { + writer = baseTransfer.File + reader = baseTransfer.File + } else if pipeWriter != nil { + writer = pipeWriter + } else if pipeReader != nil { + reader = pipeReader + } + return &transfer{ + BaseTransfer: baseTransfer, + writer: writer, + reader: reader, + isFinished: false, + expectedOffset: expectedOffset, + } +} + +// Read reads the contents to downloads. 
+func (t *transfer) Read(p []byte) (n int, err error) { + t.Connection.UpdateLastActivity() + + n, err = t.reader.Read(p) + t.BytesSent.Add(int64(n)) + + if err == nil { + err = t.CheckRead() + } + if err != nil && err != io.EOF { + t.TransferError(err) + err = t.ConvertError(err) + return + } + t.HandleThrottle() + return +} + +// Write writes the uploaded contents. +func (t *transfer) Write(p []byte) (n int, err error) { + t.Connection.UpdateLastActivity() + + n, err = t.writer.Write(p) + t.BytesReceived.Add(int64(n)) + + if err == nil { + err = t.CheckWrite() + } + if err != nil { + t.TransferError(err) + err = t.ConvertError(err) + return + } + t.HandleThrottle() + return +} + +// Seek sets the offset to resume an upload or a download +func (t *transfer) Seek(offset int64, whence int) (int64, error) { + t.Connection.UpdateLastActivity() + if t.File != nil { + ret, err := t.File.Seek(offset, whence) + if err != nil { + t.TransferError(err) + } + return ret, err + } + if (t.reader != nil || t.writer != nil) && t.expectedOffset == offset && whence == io.SeekStart { + return offset, nil + } + t.TransferError(errors.New("seek is unsupported for this transfer")) + return 0, common.ErrOpUnsupported +} + +// Close it is called when the transfer is completed. 
+func (t *transfer) Close() error { + if err := t.setFinished(); err != nil { + return err + } + err := t.closeIO() + errBaseClose := t.BaseTransfer.Close() + if errBaseClose != nil { + err = errBaseClose + } + return t.Connection.GetFsError(t.Fs, err) +} + +func (t *transfer) closeIO() error { + var err error + if t.File != nil { + err = t.File.Close() + } else if t.writer != nil { + err = t.writer.Close() + t.Lock() + // we set ErrTransfer here so quota is not updated, in this case the uploads are atomic + if err != nil && t.ErrTransfer == nil { + t.ErrTransfer = err + } + t.Unlock() + } else if t.reader != nil { + err = t.reader.Close() + if metadater, ok := t.reader.(vfs.Metadater); ok { + t.SetMetadata(metadater.Metadata()) + } + } + return err +} + +func (t *transfer) setFinished() error { + t.Lock() + defer t.Unlock() + if t.isFinished { + return common.ErrTransferClosed + } + t.isFinished = true + return nil +} diff --git a/internal/httpclient/httpclient.go b/internal/httpclient/httpclient.go new file mode 100644 index 00000000..61fb3d18 --- /dev/null +++ b/internal/httpclient/httpclient.go @@ -0,0 +1,276 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
// TLSKeyPair defines the paths for a TLS key pair
type TLSKeyPair struct {
	Cert string `json:"cert" mapstructure:"cert"`
	Key  string `json:"key" mapstructure:"key"`
}

// Header defines an HTTP header.
// If the URL is not empty, the header is added only if the
// requested URL starts with the one specified
type Header struct {
	Key   string `json:"key" mapstructure:"key"`
	Value string `json:"value" mapstructure:"value"`
	URL   string `json:"url" mapstructure:"url"`
}

// Config defines the configuration for HTTP clients.
// HTTP clients are used for executing hooks such as the ones used for
// custom actions, external authentication and pre-login user modifications
type Config struct {
	// Timeout specifies a time limit, in seconds, for a request
	Timeout float64 `json:"timeout" mapstructure:"timeout"`
	// RetryWaitMin defines the minimum waiting time between attempts in seconds
	RetryWaitMin int `json:"retry_wait_min" mapstructure:"retry_wait_min"`
	// RetryWaitMax defines the maximum waiting time between attempts in seconds
	RetryWaitMax int `json:"retry_wait_max" mapstructure:"retry_wait_max"`
	// RetryMax defines the maximum number of attempts
	RetryMax int `json:"retry_max" mapstructure:"retry_max"`
	// CACertificates defines extra CA certificates to trust.
	// The paths can be absolute or relative to the config dir.
	// Adding trusted CA certificates is a convenient way to use self-signed
	// certificates without defeating the purpose of using TLS
	CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"`
	// Certificates defines the certificates to use for mutual TLS
	Certificates []TLSKeyPair `json:"certificates" mapstructure:"certificates"`
	// if enabled the HTTP client accepts any TLS certificate presented by
	// the server and any host name in that certificate.
	// In this mode, TLS is susceptible to man-in-the-middle attacks.
	// This should be used only for testing.
	SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"`
	// Headers defines a list of http headers to add to each request
	Headers []Header `json:"headers" mapstructure:"headers"`
	// transport shared by all clients, built once in Initialize
	customTransport *http.Transport
}
+ // Adding trusted CA certificates is a convenient way to use self-signed + // certificates without defeating the purpose of using TLS + CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"` + // Certificates defines the certificates to use for mutual TLS + Certificates []TLSKeyPair `json:"certificates" mapstructure:"certificates"` + // if enabled the HTTP client accepts any TLS certificate presented by + // the server and any host name in that certificate. + // In this mode, TLS is susceptible to man-in-the-middle attacks. + // This should be used only for testing. + SkipTLSVerify bool `json:"skip_tls_verify" mapstructure:"skip_tls_verify"` + // Headers defines a list of http headers to add to each request + Headers []Header `json:"headers" mapstructure:"headers"` + customTransport *http.Transport +} + +const logSender = "httpclient" + +var httpConfig Config + +// Initialize configures HTTP clients +func (c *Config) Initialize(configDir string) error { + if c.Timeout <= 0 { + return fmt.Errorf("invalid timeout: %v", c.Timeout) + } + rootCAs, err := c.loadCACerts(configDir) + if err != nil { + return err + } + customTransport := http.DefaultTransport.(*http.Transport).Clone() + if customTransport.TLSClientConfig != nil { + customTransport.TLSClientConfig.RootCAs = rootCAs + } else { + customTransport.TLSClientConfig = &tls.Config{ + RootCAs: rootCAs, + } + } + customTransport.TLSClientConfig.InsecureSkipVerify = c.SkipTLSVerify + c.customTransport = customTransport + + err = c.loadCertificates(configDir) + if err != nil { + return err + } + var headers []Header + for _, h := range c.Headers { + if h.Key != "" && h.Value != "" { + headers = append(headers, h) + } + } + c.Headers = headers + httpConfig = *c + return nil +} + +// loadCACerts returns system cert pools and try to add the configured +// CA certificates to it +func (c *Config) loadCACerts(configDir string) (*x509.CertPool, error) { + if len(c.CACertificates) == 0 { + return nil, 
nil + } + rootCAs, err := x509.SystemCertPool() + if err != nil { + rootCAs = x509.NewCertPool() + } + + for _, ca := range c.CACertificates { + if !util.IsFileInputValid(ca) { + return nil, fmt.Errorf("unable to load invalid CA certificate: %q", ca) + } + if !filepath.IsAbs(ca) { + ca = filepath.Join(configDir, ca) + } + certs, err := os.ReadFile(ca) + if err != nil { + return nil, fmt.Errorf("unable to load CA certificate: %v", err) + } + if rootCAs.AppendCertsFromPEM(certs) { + logger.Debug(logSender, "", "CA certificate %q added to the trusted certificates", ca) + } else { + return nil, fmt.Errorf("unable to add CA certificate %q to the trusted cetificates", ca) + } + } + return rootCAs, nil +} + +func (c *Config) loadCertificates(configDir string) error { + if len(c.Certificates) == 0 { + return nil + } + + for _, keyPair := range c.Certificates { + cert := keyPair.Cert + key := keyPair.Key + if !util.IsFileInputValid(cert) { + return fmt.Errorf("unable to load invalid certificate: %q", cert) + } + if !util.IsFileInputValid(key) { + return fmt.Errorf("unable to load invalid key: %q", key) + } + if !filepath.IsAbs(cert) { + cert = filepath.Join(configDir, cert) + } + if !filepath.IsAbs(key) { + key = filepath.Join(configDir, key) + } + tlsCert, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + return fmt.Errorf("unable to load key pair %q, %q: %v", cert, key, err) + } + x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0]) + if err == nil { + logger.Debug(logSender, "", "adding leaf certificate for key pair %q, %q", cert, key) + tlsCert.Leaf = x509Cert + } + logger.Debug(logSender, "", "client certificate %q and key %q successfully loaded", cert, key) + c.customTransport.TLSClientConfig.Certificates = append(c.customTransport.TLSClientConfig.Certificates, tlsCert) + } + return nil +} + +// GetHTTPClient returns a new HTTP client with the configured parameters +func GetHTTPClient() *http.Client { + return &http.Client{ + Timeout: 
time.Duration(httpConfig.Timeout * float64(time.Second)), + Transport: httpConfig.customTransport, + } +} + +// GetRetraybleHTTPClient returns an HTTP client that retry a request on error. +// It uses the configured retry parameters +func GetRetraybleHTTPClient() *retryablehttp.Client { + client := retryablehttp.NewClient() + client.HTTPClient.Timeout = time.Duration(httpConfig.Timeout * float64(time.Second)) + client.HTTPClient.Transport.(*http.Transport).TLSClientConfig = httpConfig.customTransport.TLSClientConfig + client.Logger = &logger.LeveledLogger{Sender: "RetryableHTTPClient"} + client.RetryWaitMin = time.Duration(httpConfig.RetryWaitMin) * time.Second + client.RetryWaitMax = time.Duration(httpConfig.RetryWaitMax) * time.Second + client.RetryMax = httpConfig.RetryMax + + return client +} + +// Get issues a GET to the specified URL +func Get(url string) (*http.Response, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + addHeaders(req, url) + client := GetHTTPClient() + defer client.CloseIdleConnections() + + return client.Do(req) +} + +// Post issues a POST to the specified URL +func Post(url string, contentType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest(http.MethodPost, url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", contentType) + addHeaders(req, url) + client := GetHTTPClient() + defer client.CloseIdleConnections() + + return client.Do(req) +} + +// RetryableGet issues a GET to the specified URL using the retryable client +func RetryableGet(url string) (*http.Response, error) { + req, err := retryablehttp.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + addHeadersToRetryableReq(req, url) + client := GetRetraybleHTTPClient() + defer client.HTTPClient.CloseIdleConnections() + + return client.Do(req) +} + +// RetryablePost issues a POST to the specified URL using the retryable client +func 
RetryablePost(url string, contentType string, body io.Reader) (*http.Response, error) { + req, err := retryablehttp.NewRequest(http.MethodPost, url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", contentType) + addHeadersToRetryableReq(req, url) + client := GetRetraybleHTTPClient() + defer client.HTTPClient.CloseIdleConnections() + + return client.Do(req) +} + +func addHeaders(req *http.Request, url string) { + for idx := range httpConfig.Headers { + h := &httpConfig.Headers[idx] + if h.URL == "" || strings.HasPrefix(url, h.URL) { + req.Header.Set(h.Key, h.Value) + } + } +} + +func addHeadersToRetryableReq(req *retryablehttp.Request, url string) { + for idx := range httpConfig.Headers { + h := &httpConfig.Headers[idx] + if h.URL == "" || strings.HasPrefix(url, h.URL) { + req.Header.Set(h.Key, h.Value) + } + } +} diff --git a/internal/httpd/api_admin.go b/internal/httpd/api_admin.go new file mode 100644 index 00000000..2fd093a0 --- /dev/null +++ b/internal/httpd/api_admin.go @@ -0,0 +1,337 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
package httpd

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"

	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// getAdmins returns the admins as a JSON list. Pagination (limit, offset,
// order) is parsed from the query string by getSearchFilters, which writes
// the error response itself on invalid input.
func getAdmins(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	limit, offset, order, err := getSearchFilters(w, r)
	if err != nil {
		return
	}

	admins, err := dataprovider.GetAdmins(limit, offset, order)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	render.JSON(w, r, admins)
}

// getAdminByUsername renders the admin identified by the "username" URL
// parameter with a 200 status.
func getAdminByUsername(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	username := getURLParam(r, "username")
	renderAdmin(w, r, username, http.StatusOK)
}

// renderAdmin looks up the admin, strips confidential data and writes it as
// JSON using the requested status code.
func renderAdmin(w http.ResponseWriter, r *http.Request, username string, status int) {
	admin, err := dataprovider.AdminExists(username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	admin.HideConfidentialData()
	if status != http.StatusOK {
		// NOTE(review): the context value is hard-coded to StatusCreated rather
		// than using the status parameter; within this file the only non-OK
		// caller passes http.StatusCreated, so behavior is unchanged — confirm
		// no other caller passes a different non-OK status.
		ctx := context.WithValue(r.Context(), render.StatusCtxKey, http.StatusCreated)
		render.JSON(w, r.WithContext(ctx), admin)
	} else {
		render.JSON(w, r, admin)
	}
}

// addAdmin creates a new admin from the JSON request body. On success it sets
// the Location header to the new resource and renders it with 201 Created.
func addAdmin(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin := dataprovider.Admin{}
	err = render.DecodeJSON(r.Body, &admin)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	err = dataprovider.AddAdmin(&admin, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", adminPath, url.PathEscape(admin.Username)))
	renderAdmin(w, r, admin.Username, http.StatusCreated)
}

// disableAdmin2FA clears the TOTP configuration and recovery codes for the
// admin named in the URL. An admin cannot disable 2FA on its own account when
// RequireTwoFactor is set.
func disableAdmin2FA(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin, err := dataprovider.AdminExists(getURLParam(r, "username"))
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if !admin.Filters.TOTPConfig.Enabled {
		sendAPIResponse(w, r, nil, "two-factor authentication is not enabled", http.StatusBadRequest)
		return
	}
	if admin.Username == claims.Username {
		if admin.Filters.RequireTwoFactor {
			err := util.NewValidationError("two-factor authentication must be enabled")
			sendAPIResponse(w, r, err, "", getRespStatus(err))
			return
		}
	}
	admin.Filters.RecoveryCodes = nil
	admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled: false,
	}
	if err := dataprovider.UpdateAdmin(&admin, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "2FA disabled", http.StatusOK)
}

// updateAdmin replaces the admin named in the URL with the JSON request body.
// Self-updates are restricted: no API-key impersonation, no permission/role
// changes, no self-disable. Immutable fields (ID, username, TOTP config,
// recovery codes, and the password when the request omits it) are carried
// over from the stored admin.
func updateAdmin(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	username := getURLParam(r, "username")
	admin, err := dataprovider.AdminExists(username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var updatedAdmin dataprovider.Admin
	err = render.DecodeJSON(r.Body, &updatedAdmin)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	if username == claims.Username {
		if claims.APIKeyID != "" {
			sendAPIResponse(w, r, errors.New("updating the admin impersonated with an API key is not allowed"), "",
				http.StatusBadRequest)
			return
		}
		if !util.SlicesEqual(admin.Permissions, updatedAdmin.Permissions) {
			sendAPIResponse(w, r, errors.New("you cannot change your permissions"), "", http.StatusBadRequest)
			return
		}
		if updatedAdmin.Status == 0 {
			sendAPIResponse(w, r, errors.New("you cannot disable yourself"), "", http.StatusBadRequest)
			return
		}
		if updatedAdmin.Role != claims.Role {
			sendAPIResponse(w, r, errors.New("you cannot add/change your role"), "", http.StatusBadRequest)
			return
		}
		// self-updates cannot relax these security flags
		updatedAdmin.Filters.RequirePasswordChange = admin.Filters.RequirePasswordChange
		updatedAdmin.Filters.RequireTwoFactor = admin.Filters.RequireTwoFactor
	}
	updatedAdmin.ID = admin.ID
	updatedAdmin.Username = admin.Username
	if updatedAdmin.Password == "" {
		// empty password in the request means "keep the current one"
		updatedAdmin.Password = admin.Password
	}
	updatedAdmin.Filters.TOTPConfig = admin.Filters.TOTPConfig
	updatedAdmin.Filters.RecoveryCodes = admin.Filters.RecoveryCodes
	err = dataprovider.UpdateAdmin(&updatedAdmin, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Admin updated", http.StatusOK)
}

// deleteAdmin removes the admin named in the URL. Self-deletion is rejected.
func deleteAdmin(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	username := getURLParam(r, "username")
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	if username == claims.Username {
		sendAPIResponse(w, r, errors.New("you cannot delete yourself"), "", http.StatusBadRequest)
		return
	}

	err = dataprovider.DeleteAdmin(username, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	// err is nil here; it is forwarded unchanged in the success response
	sendAPIResponse(w, r, err, "Admin deleted", http.StatusOK)
}

// getAdminProfile returns the profile (email, description, API-key auth flag)
// of the authenticated admin.
func getAdminProfile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	resp := adminProfile{
		baseProfile: baseProfile{
			Email:           admin.Email,
			Description:     admin.Description,
			AllowAPIKeyAuth: admin.Filters.AllowAPIKeyAuth,
		},
	}
	render.JSON(w, r, resp)
}

// updateAdminProfile updates only the profile fields (email, description,
// API-key auth flag) of the authenticated admin; everything else is untouched.
func updateAdminProfile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	var req adminProfile
	err = render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	admin.Email = req.Email
	admin.Description = req.Description
	admin.Filters.AllowAPIKeyAuth = req.AllowAPIKeyAuth
	if err := dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "Profile updated", http.StatusOK)
}

// forgotAdminPassword starts the admin password-reset flow (confirmation code
// via email). It requires SMTP to be configured.
func forgotAdminPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	if !smtp.IsEnabled() {
		sendAPIResponse(w, r, nil, "No SMTP configuration", http.StatusBadRequest)
		return
	}

	err := handleForgotPassword(r, getURLParam(r, "username"), true)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	sendAPIResponse(w, r, err, "Check your email for the confirmation code", http.StatusOK)
}

// resetAdminPassword completes the password-reset flow using the confirmation
// code and new password from the JSON request body.
func resetAdminPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var req pwdReset
	err := render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	// the API accepts a single password field, so it is passed as both the
	// new password and its confirmation
	_, _, err = handleResetPassword(r, req.Code, req.Password, req.Password, true)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "Password reset successful", http.StatusOK)
}

// changeAdminPassword changes the authenticated admin's password and
// invalidates the current token so the admin must log in again.
func changeAdminPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var pwd pwdChange
	err := render.DecodeJSON(r.Body, &pwd)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	err = doChangeAdminPassword(r, pwd.CurrentPassword, pwd.NewPassword, pwd.NewPassword)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	invalidateToken(r)
	sendAPIResponse(w, r, err, "Password updated", http.StatusOK)
}

// doChangeAdminPassword validates and applies a password change for the admin
// identified by the JWT claims in the request context. It returns i18n-aware
// validation errors for: missing fields, mismatched confirmation, new password
// equal to the current one, invalid token, and wrong current password. On
// success it also clears the RequirePasswordChange flag.
func doChangeAdminPassword(r *http.Request, currentPassword, newPassword, confirmNewPassword string) error {
	if currentPassword == "" || newPassword == "" || confirmNewPassword == "" {
		return util.NewI18nError(
			util.NewValidationError("please provide the current password and the new one two times"),
			util.I18nErrorChangePwdRequiredFields,
		)
	}
	if newPassword != confirmNewPassword {
		return util.NewI18nError(util.NewValidationError("the two password fields do not match"), util.I18nErrorChangePwdNoMatch)
	}
	if currentPassword == newPassword {
		return util.NewI18nError(
			util.NewValidationError("the new password must be different from the current one"),
			util.I18nErrorChangePwdNoDifferent,
		)
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil {
		return util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		return err
	}
	match, err := admin.CheckPassword(currentPassword)
	if !match || err != nil {
		return util.NewI18nError(util.NewValidationError("current password does not match"), util.I18nErrorChangePwdCurrentNoMatch)
	}

	admin.Password = newPassword
	admin.Filters.RequirePasswordChange = false

	return dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
}
diff --git a/internal/httpd/api_configs.go b/internal/httpd/api_configs.go
new file mode 100644
index 00000000..520f81a9
--- /dev/null
+++ b/internal/httpd/api_configs.go
@@ -0,0 +1,125 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
package httpd

import (
	"net/http"

	"github.com/go-chi/render"
	"github.com/rs/xid"
	"golang.org/x/oauth2"

	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// smtpTestRequest is the payload for the SMTP test endpoint: an SMTP
// configuration plus the address to send the test email to.
type smtpTestRequest struct {
	smtp.Config
	Recipient string `json:"recipient"`
}

// hasRedactedSecret reports whether any secret field still carries the
// redaction placeholder, meaning the client did not resend the real value.
func (r *smtpTestRequest) hasRedactedSecret() bool {
	return r.Password == redactedSecret || r.OAuth2.ClientSecret == redactedSecret || r.OAuth2.RefreshToken == redactedSecret
}

// testSMTPConfig sends a test email using the configuration from the request
// body. Redacted secrets are replaced with the decrypted stored values before
// the attempt, so a saved configuration can be tested without resubmitting
// its secrets.
func testSMTPConfig(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var req smtpTestRequest
	err := render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	if req.hasRedactedSecret() {
		configs, err := dataprovider.GetConfigs()
		if err != nil {
			sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
			return
		}
		configs.SetNilsToEmpty()
		// best-effort: on decryption failure the redacted placeholders are
		// left in place and the send below will fail accordingly
		if err := configs.SMTP.TryDecrypt(); err == nil {
			if req.Password == redactedSecret {
				req.Password = configs.SMTP.Password.GetPayload()
			}
			if req.OAuth2.ClientSecret == redactedSecret {
				req.OAuth2.ClientSecret = configs.SMTP.OAuth2.ClientSecret.GetPayload()
			}
			if req.OAuth2.RefreshToken == redactedSecret {
				req.OAuth2.RefreshToken = configs.SMTP.OAuth2.RefreshToken.GetPayload()
			}
		}
	}
	// auth type 3 — presumably the OAuth2 auth mode; confirm against
	// smtp.Config's AuthType documentation
	if req.AuthType == 3 {
		if err := req.OAuth2.Validate(); err != nil {
			sendAPIResponse(w, r, err, "", http.StatusBadRequest)
			return
		}
	}
	if err := req.SendEmail([]string{req.Recipient}, nil, "SFTPGo - Testing Email Settings",
		"It appears your SFTPGo email is setup correctly!", smtp.EmailContentTypeTextPlain); err != nil {
		logger.Info(logSender, "", "unable to send test email: %v", err)
		sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
		return
	}
	sendAPIResponse(w, r, nil, "SMTP connection OK", http.StatusOK)
}

// oauth2TokenRequest is the payload used to start an SMTP OAuth2 token flow;
// BaseRedirectURL is combined with webOAuth2RedirectPath to build the
// redirect URL.
type oauth2TokenRequest struct {
	smtp.OAuth2Config
	BaseRedirectURL string `json:"base_redirect_url"`
}

// handleSMTPOAuth2TokenRequestPost starts an OAuth2 authorization-code flow
// for SMTP: it registers a pending auth keyed by a CSRF-protected state token
// and returns the provider's authorization URL (with offline access and PKCE
// S256 challenge) in the response message.
func (s *httpdServer) handleSMTPOAuth2TokenRequestPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var req oauth2TokenRequest
	err := render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	if req.BaseRedirectURL == "" {
		sendAPIResponse(w, r, nil, "base redirect url is required", http.StatusBadRequest)
		return
	}
	if req.ClientSecret == redactedSecret {
		// the client did not resend the secret: fall back to the stored one
		configs, err := dataprovider.GetConfigs()
		if err != nil {
			sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
			return
		}
		configs.SetNilsToEmpty()
		if err := configs.SMTP.TryDecrypt(); err == nil {
			req.ClientSecret = configs.SMTP.OAuth2.ClientSecret.GetPayload()
		}
	}
	cfg := req.GetOAuth2()
	cfg.RedirectURL = req.BaseRedirectURL + webOAuth2RedirectPath
	clientSecret := kms.NewPlainSecret(cfg.ClientSecret)
	clientSecret.SetAdditionalData(xid.New().String())
	pendingAuth := newOAuth2PendingAuth(req.Provider, cfg.RedirectURL, cfg.ClientID, clientSecret)
	oauth2Mgr.addPendingAuth(pendingAuth)
	stateToken := createOAuth2Token(s.csrfTokenAuth, pendingAuth.State, util.GetIPFromRemoteAddress(r.RemoteAddr))
	if stateToken == "" {
		sendAPIResponse(w, r, nil, "unable to create state token", http.StatusInternalServerError)
		return
	}
	u := cfg.AuthCodeURL(stateToken, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(pendingAuth.Verifier))
	sendAPIResponse(w, r, nil, u, http.StatusOK)
}
diff --git a/internal/httpd/api_defender.go b/internal/httpd/api_defender.go
new file mode 100644
index 00000000..8da5c3df
--- /dev/null
+++ b/internal/httpd/api_defender.go
@@ -0,0 +1,93 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .

package httpd

import (
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"net/http"

	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// getDefenderHosts returns the hosts currently tracked by the defender. A nil
// result is normalized to an empty JSON array rather than null.
func getDefenderHosts(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	hosts, err := common.GetDefenderHosts()
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if hosts == nil {
		render.JSON(w, r, make([]dataprovider.DefenderEntry, 0))
		return
	}
	render.JSON(w, r, hosts)
}

// getDefenderHostByID returns a single defender entry; the "id" URL parameter
// is the hex encoding of the host's IP address (see getIPFromID).
func getDefenderHostByID(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	ip, err := getIPFromID(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	host, err := common.GetDefenderHost(ip)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	render.JSON(w, r, host)
}

// deleteDefenderHostByID removes a host from the defender list, replying 404
// when the host is not tracked.
func deleteDefenderHostByID(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	ip, err := getIPFromID(r)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	if !common.DeleteDefenderHost(ip) {
		sendAPIResponse(w, r, nil, "Not found", http.StatusNotFound)
		return
	}

	sendAPIResponse(w, r, nil, "OK", http.StatusOK)
}

// getIPFromID decodes the hex "id" URL parameter into an IP address string
// and validates it.
func getIPFromID(r *http.Request) (string, error) {
	decoded, err := hex.DecodeString(getURLParam(r, "id"))
	if err != nil {
		return "", errors.New("invalid host id")
	}
	ip := util.BytesToString(decoded)
	err = validateIPAddress(ip)
	if err != nil {
		return "", err
	}
	return ip, nil
}

// validateIPAddress returns an error when ip is not a parseable IPv4/IPv6
// address.
func validateIPAddress(ip string) error {
	if net.ParseIP(ip) == nil {
		return fmt.Errorf("ip address %q is not valid", ip)
	}
	return nil
}
diff --git a/internal/httpd/api_eventrule.go b/internal/httpd/api_eventrule.go
new file mode 100644
index 00000000..b8474af4
--- /dev/null
+++ b/internal/httpd/api_eventrule.go
@@ -0,0 +1,276 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
package httpd

import (
	"context"
	"fmt"
	"net/http"
	"net/url"

	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

// getEventActions returns event actions as a paginated JSON list.
func getEventActions(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	limit, offset, order, err := getSearchFilters(w, r)
	if err != nil {
		return
	}

	actions, err := dataprovider.GetEventActions(limit, offset, order, false)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
		return
	}
	render.JSON(w, r, actions)
}

// renderEventAction writes the named event action as JSON with the given
// status; confidential data is stripped unless hideConfidentialData says
// otherwise for these claims/request.
func renderEventAction(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) {
	action, err := dataprovider.EventActionExists(name)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if hideConfidentialData(claims, r) {
		action.PrepareForRendering()
	}
	if status != http.StatusOK {
		ctx := context.WithValue(r.Context(), render.StatusCtxKey, status)
		render.JSON(w, r.WithContext(ctx), action)
	} else {
		render.JSON(w, r, action)
	}
}

// getEventActionByName renders the event action named in the URL with 200.
func getEventActionByName(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	name := getURLParam(r, "name")
	renderEventAction(w, r, name, claims, http.StatusOK)
}

// addEventAction creates a new event action from the JSON request body and
// renders it with 201 Created plus a Location header.
func addEventAction(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	var action dataprovider.BaseEventAction
	err = render.DecodeJSON(r.Body, &action)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	err = dataprovider.AddEventAction(&action, claims.Username, ipAddr, claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", eventActionsPath, url.PathEscape(action.Name)))
	renderEventAction(w, r, action.Name, claims, http.StatusCreated)
}

// updateEventAction replaces the event action named in the URL with the JSON
// request body, preserving ID and name. For HTTP actions, a password that is
// neither plain text nor empty (i.e. still in its stored encrypted/redacted
// form) is carried over from the existing action.
func updateEventAction(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	name := getURLParam(r, "name")
	action, err := dataprovider.EventActionExists(name)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var updatedAction dataprovider.BaseEventAction
	err = render.DecodeJSON(r.Body, &updatedAction)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	updatedAction.ID = action.ID
	updatedAction.Name = action.Name
	updatedAction.Options.SetEmptySecretsIfNil()

	switch updatedAction.Type {
	case dataprovider.ActionTypeHTTP:
		if updatedAction.Options.HTTPConfig.Password.IsNotPlainAndNotEmpty() {
			updatedAction.Options.HTTPConfig.Password = action.Options.HTTPConfig.Password
		}
	}

	err = dataprovider.UpdateEventAction(&updatedAction, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Event action updated", http.StatusOK)
}

// deleteEventAction removes the event action named in the URL.
func deleteEventAction(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	name := getURLParam(r, "name")
	err = dataprovider.DeleteEventAction(name, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "Event action deleted", http.StatusOK)
}

// getEventRules returns event rules as a paginated JSON list.
func getEventRules(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	limit, offset, order, err := getSearchFilters(w, r)
	if err != nil {
		return
	}

	rules, err := dataprovider.GetEventRules(limit, offset, order)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
		return
	}
	render.JSON(w, r, rules)
}

// renderEventRule writes the named event rule as JSON with the given status;
// confidential data is stripped unless hideConfidentialData says otherwise.
func renderEventRule(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) {
	rule, err := dataprovider.EventRuleExists(name)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if hideConfidentialData(claims, r) {
		rule.PrepareForRendering()
	}
	if status != http.StatusOK {
		ctx := context.WithValue(r.Context(), render.StatusCtxKey, status)
		render.JSON(w, r.WithContext(ctx), rule)
	} else {
		render.JSON(w, r, rule)
	}
}

// getEventRuleByName renders the event rule named in the URL with 200.
func getEventRuleByName(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	name := getURLParam(r, "name")
	renderEventRule(w, r, name, claims, http.StatusOK)
}

// addEventRule creates a new event rule from the JSON request body and
// renders it with 201 Created plus a Location header.
func addEventRule(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	var rule dataprovider.EventRule
	err = render.DecodeJSON(r.Body, &rule)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := dataprovider.AddEventRule(&rule, claims.Username, ipAddr, claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", eventRulesPath, url.PathEscape(rule.Name)))
	renderEventRule(w, r, rule.Name, claims, http.StatusCreated)
}

// updateEventRule replaces the event rule named in the URL with the JSON
// request body, preserving ID and name.
func updateEventRule(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	rule, err := dataprovider.EventRuleExists(getURLParam(r, "name"))
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var updatedRule dataprovider.EventRule
	err = render.DecodeJSON(r.Body, &updatedRule)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	updatedRule.ID = rule.ID
	updatedRule.Name = rule.Name

	err = dataprovider.UpdateEventRule(&updatedRule, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Event rules updated", http.StatusOK)
}

// deleteEventRule removes the event rule named in the URL.
func deleteEventRule(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	name := getURLParam(r, "name")
	err = dataprovider.DeleteEventRule(name, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Event rule deleted", http.StatusOK)
}

// runOnDemandRule triggers the named rule immediately and replies 202
// Accepted: the rule is started, not awaited.
func runOnDemandRule(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	name := getURLParam(r, "name")
	if err := common.RunOnDemandRule(name); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Event rule started", http.StatusAccepted)
}
diff --git a/internal/httpd/api_events.go b/internal/httpd/api_events.go
new file mode 100644
index 00000000..a5ad1a5c
--- /dev/null
+++ b/internal/httpd/api_events.go
@@ -0,0 +1,481 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+ +package httpd + +import ( + "encoding/csv" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/sftpgo/sdk/plugin/eventsearcher" + "github.com/sftpgo/sdk/plugin/notifier" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func getCommonSearchParamsFromRequest(r *http.Request) (eventsearcher.CommonSearchParams, error) { + c := eventsearcher.CommonSearchParams{} + c.Limit = 100 + + if _, ok := r.URL.Query()["limit"]; ok { + limit, err := strconv.Atoi(r.URL.Query().Get("limit")) + if err != nil { + return c, util.NewValidationError(fmt.Sprintf("invalid limit: %v", err)) + } + if limit < 1 || limit > 1000 { + return c, util.NewValidationError(fmt.Sprintf("limit is out of the 1-1000 range: %v", limit)) + } + c.Limit = limit + } + if _, ok := r.URL.Query()["order"]; ok { + order := r.URL.Query().Get("order") + if order != dataprovider.OrderASC && order != dataprovider.OrderDESC { + return c, util.NewValidationError(fmt.Sprintf("invalid order %q", order)) + } + if order == dataprovider.OrderASC { + c.Order = 1 + } + } + if _, ok := r.URL.Query()["start_timestamp"]; ok { + ts, err := strconv.ParseInt(r.URL.Query().Get("start_timestamp"), 10, 64) + if err != nil { + return c, util.NewValidationError(fmt.Sprintf("invalid start_timestamp: %v", err)) + } + c.StartTimestamp = ts + } + if _, ok := r.URL.Query()["end_timestamp"]; ok { + ts, err := strconv.ParseInt(r.URL.Query().Get("end_timestamp"), 10, 64) + if err != nil { + return c, util.NewValidationError(fmt.Sprintf("invalid end_timestamp: %v", err)) + } + c.EndTimestamp = ts + } + c.Username = strings.TrimSpace(r.URL.Query().Get("username")) + c.IP = strings.TrimSpace(r.URL.Query().Get("ip")) + c.InstanceIDs = getCommaSeparatedQueryParam(r, "instance_ids") + c.FromID = r.URL.Query().Get("from_id") + + return c, nil +} + +func 
getFsSearchParamsFromRequest(r *http.Request) (eventsearcher.FsEventSearch, error) { + var err error + s := eventsearcher.FsEventSearch{} + s.CommonSearchParams, err = getCommonSearchParamsFromRequest(r) + if err != nil { + return s, err + } + s.FsProvider = -1 + if _, ok := r.URL.Query()["fs_provider"]; ok { + provider := r.URL.Query().Get("fs_provider") + val, err := strconv.Atoi(provider) + if err != nil { + return s, util.NewValidationError(fmt.Sprintf("invalid fs_provider: %v", provider)) + } + s.FsProvider = val + } + s.Actions = getCommaSeparatedQueryParam(r, "actions") + s.SSHCmd = strings.TrimSpace(r.URL.Query().Get("ssh_cmd")) + s.Bucket = strings.TrimSpace(r.URL.Query().Get("bucket")) + s.Endpoint = strings.TrimSpace(r.URL.Query().Get("endpoint")) + s.Protocols = getCommaSeparatedQueryParam(r, "protocols") + statuses := getCommaSeparatedQueryParam(r, "statuses") + for _, status := range statuses { + val, err := strconv.ParseInt(status, 10, 32) + if err != nil { + return s, util.NewValidationError(fmt.Sprintf("invalid status: %v", status)) + } + s.Statuses = append(s.Statuses, int32(val)) + } + + return s, nil +} + +func getProviderSearchParamsFromRequest(r *http.Request) (eventsearcher.ProviderEventSearch, error) { + var err error + s := eventsearcher.ProviderEventSearch{} + s.CommonSearchParams, err = getCommonSearchParamsFromRequest(r) + if err != nil { + return s, err + } + s.Actions = getCommaSeparatedQueryParam(r, "actions") + s.ObjectName = strings.TrimSpace(r.URL.Query().Get("object_name")) + s.ObjectTypes = getCommaSeparatedQueryParam(r, "object_types") + return s, nil +} + +func getLogSearchParamsFromRequest(r *http.Request) (eventsearcher.LogEventSearch, error) { + var err error + s := eventsearcher.LogEventSearch{} + s.CommonSearchParams, err = getCommonSearchParamsFromRequest(r) + if err != nil { + return s, err + } + s.Protocols = getCommaSeparatedQueryParam(r, "protocols") + events := getCommaSeparatedQueryParam(r, "events") + for _, ev := 
range events { + evType, err := strconv.ParseInt(ev, 10, 32) + if err == nil { + s.Events = append(s.Events, int32(evType)) + } + } + + return s, nil +} + +func searchFsEvents(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + + filters, err := getFsSearchParamsFromRequest(r) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + filters.Role = getRoleFilterForEventSearch(r, claims.Role) + + if getBoolQueryParam(r, "csv_export") { + filters.Limit = 100 + if err := exportFsEvents(w, &filters); err != nil { + panic(http.ErrAbortHandler) + } + return + } + + data, err := plugin.Handler.SearchFsEvents(&filters) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(data) //nolint:errcheck +} + +func searchProviderEvents(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + + var filters eventsearcher.ProviderEventSearch + if filters, err = getProviderSearchParamsFromRequest(r); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + filters.Role = getRoleFilterForEventSearch(r, claims.Role) + filters.OmitObjectData = getBoolQueryParam(r, "omit_object_data") + + if getBoolQueryParam(r, "csv_export") { + filters.Limit = 100 + filters.OmitObjectData = true + if err := exportProviderEvents(w, &filters); err != nil { + panic(http.ErrAbortHandler) + } + return + } + + data, err := plugin.Handler.SearchProviderEvents(&filters) + if err != nil { + sendAPIResponse(w, r, err, "", 
getRespStatus(err)) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(data) //nolint:errcheck +} + +func searchLogEvents(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + + var filters eventsearcher.LogEventSearch + if filters, err = getLogSearchParamsFromRequest(r); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + filters.Role = getRoleFilterForEventSearch(r, claims.Role) + + if getBoolQueryParam(r, "csv_export") { + filters.Limit = 100 + if err := exportLogEvents(w, &filters); err != nil { + panic(http.ErrAbortHandler) + } + return + } + + data, err := plugin.Handler.SearchLogEvents(&filters) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(data) //nolint:errcheck +} + +func exportFsEvents(w http.ResponseWriter, filters *eventsearcher.FsEventSearch) error { + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=fslogs-%s.csv", time.Now().Format("2006-01-02T15-04-05"))) + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Accept-Ranges", "none") + w.WriteHeader(http.StatusOK) + + csvWriter := csv.NewWriter(w) + ev := fsEvent{} + err := csvWriter.Write(ev.getCSVHeader()) + if err != nil { + return err + } + results := make([]fsEvent, 0, filters.Limit) + for { + data, err := plugin.Handler.SearchFsEvents(filters) + if err != nil { + return err + } + if err := json.Unmarshal(data, &results); err != nil { + return err + } + for _, event := range results { + if err := csvWriter.Write(event.getCSVData()); err != nil { + return err + } + } + if len(results) == 0 || len(results) < filters.Limit { + break + } + filters.StartTimestamp = 
results[len(results)-1].Timestamp + filters.FromID = results[len(results)-1].ID + results = nil + } + csvWriter.Flush() + return csvWriter.Error() +} + +func exportProviderEvents(w http.ResponseWriter, filters *eventsearcher.ProviderEventSearch) error { + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=providerlogs-%s.csv", time.Now().Format("2006-01-02T15-04-05"))) + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Accept-Ranges", "none") + w.WriteHeader(http.StatusOK) + + ev := providerEvent{} + csvWriter := csv.NewWriter(w) + err := csvWriter.Write(ev.getCSVHeader()) + if err != nil { + return err + } + results := make([]providerEvent, 0, filters.Limit) + for { + data, err := plugin.Handler.SearchProviderEvents(filters) + if err != nil { + return err + } + if err := json.Unmarshal(data, &results); err != nil { + return err + } + for _, event := range results { + if err := csvWriter.Write(event.getCSVData()); err != nil { + return err + } + } + if len(results) < filters.Limit || len(results) == 0 { + break + } + filters.FromID = results[len(results)-1].ID + filters.StartTimestamp = results[len(results)-1].Timestamp + results = nil + } + csvWriter.Flush() + return csvWriter.Error() +} + +func exportLogEvents(w http.ResponseWriter, filters *eventsearcher.LogEventSearch) error { + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.csv", time.Now().Format("2006-01-02T15-04-05"))) + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Accept-Ranges", "none") + w.WriteHeader(http.StatusOK) + + ev := logEvent{} + csvWriter := csv.NewWriter(w) + err := csvWriter.Write(ev.getCSVHeader()) + if err != nil { + return err + } + results := make([]logEvent, 0, filters.Limit) + for { + data, err := plugin.Handler.SearchLogEvents(filters) + if err != nil { + return err + } + if err := json.Unmarshal(data, &results); err != nil { + return err + } + for _, event := range results { + if err := 
csvWriter.Write(event.getCSVData()); err != nil { + return err + } + } + if len(results) == 0 || len(results) < filters.Limit { + break + } + filters.StartTimestamp = results[len(results)-1].Timestamp + filters.FromID = results[len(results)-1].ID + results = nil + } + csvWriter.Flush() + return csvWriter.Error() +} + +func getRoleFilterForEventSearch(r *http.Request, defaultValue string) string { + if defaultValue != "" { + return defaultValue + } + return r.URL.Query().Get("role") +} + +type fsEvent struct { + ID string `json:"id"` + Timestamp int64 `json:"timestamp"` + Action string `json:"action"` + Username string `json:"username"` + FsPath string `json:"fs_path"` + FsTargetPath string `json:"fs_target_path,omitempty"` + VirtualPath string `json:"virtual_path"` + VirtualTargetPath string `json:"virtual_target_path,omitempty"` + SSHCmd string `json:"ssh_cmd,omitempty"` + FileSize int64 `json:"file_size,omitempty"` + Elapsed int64 `json:"elapsed,omitempty"` + Status int `json:"status"` + Protocol string `json:"protocol"` + IP string `json:"ip,omitempty"` + SessionID string `json:"session_id"` + FsProvider int `json:"fs_provider"` + Bucket string `json:"bucket,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + OpenFlags int `json:"open_flags,omitempty"` + Role string `json:"role,omitempty"` + InstanceID string `json:"instance_id,omitempty"` +} + +func (e *fsEvent) getCSVHeader() []string { + return []string{"Time", "Action", "Path", "Size", "Elapsed", "Status", "User", "Protocol", + "IP", "SSH command"} +} + +func (e *fsEvent) getCSVData() []string { + timestamp := time.Unix(0, e.Timestamp).UTC() + var pathInfo strings.Builder + pathInfo.Write([]byte(e.VirtualPath)) + if e.VirtualTargetPath != "" { + pathInfo.WriteString(" => ") + pathInfo.WriteString(e.VirtualTargetPath) + } + var status string + switch e.Status { + case 1: + status = "OK" + case 2: + status = "KO" + case 3: + status = "Quota exceeded" + } + var fileSize string + if e.FileSize > 0 { + 
fileSize = util.ByteCountIEC(e.FileSize) + } + var elapsed string + if e.Elapsed > 0 { + elapsed = (time.Duration(e.Elapsed) * time.Millisecond).String() + } + return []string{timestamp.Format(time.RFC3339Nano), e.Action, pathInfo.String(), + fileSize, elapsed, status, e.Username, e.Protocol, e.IP, e.SSHCmd} +} + +type providerEvent struct { + ID string `json:"id"` + Timestamp int64 `json:"timestamp"` + Action string `json:"action"` + Username string `json:"username"` + IP string `json:"ip,omitempty"` + ObjectType string `json:"object_type"` + ObjectName string `json:"object_name"` + ObjectData []byte `json:"object_data"` + Role string `json:"role,omitempty"` + InstanceID string `json:"instance_id,omitempty"` +} + +func (e *providerEvent) getCSVHeader() []string { + return []string{"Time", "Action", "Object Type", "Object Name", "User", "IP"} +} + +func (e *providerEvent) getCSVData() []string { + timestamp := time.Unix(0, e.Timestamp).UTC() + return []string{timestamp.Format(time.RFC3339Nano), e.Action, e.ObjectType, e.ObjectName, + e.Username, e.IP} +} + +type logEvent struct { + ID string `json:"id"` + Timestamp int64 `json:"timestamp"` + Event int `json:"event"` + Protocol string `json:"protocol"` + Username string `json:"username,omitempty"` + IP string `json:"ip,omitempty"` + Message string `json:"message,omitempty"` + Role string `json:"role,omitempty"` +} + +func (e *logEvent) getCSVHeader() []string { + return []string{"Time", "Event", "Protocol", "User", "IP", "Message"} +} + +func (e *logEvent) getCSVData() []string { + timestamp := time.Unix(0, e.Timestamp).UTC() + return []string{timestamp.Format(time.RFC3339Nano), getLogEventString(notifier.LogEventType(e.Event)), + e.Protocol, e.Username, e.IP, e.Message} +} + +func getLogEventString(event notifier.LogEventType) string { + switch event { + case notifier.LogEventTypeLoginFailed: + return "Login failed" + case notifier.LogEventTypeLoginNoUser: + return "Login with non-existent user" + case 
notifier.LogEventTypeNoLoginTried: + return "No login tried" + case notifier.LogEventTypeNotNegotiated: + return "Algorithm negotiation failed" + case notifier.LogEventTypeLoginOK: + return "Login succeeded" + default: + return "" + } +} diff --git a/internal/httpd/api_folder.go b/internal/httpd/api_folder.go new file mode 100644 index 00000000..c46d4ab3 --- /dev/null +++ b/internal/httpd/api_folder.go @@ -0,0 +1,146 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/go-chi/render" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +func getFolders(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + limit, offset, order, err := getSearchFilters(w, r) + if err != nil { + return + } + + folders, err := dataprovider.GetFolders(limit, offset, order, false) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusInternalServerError) + return + } + render.JSON(w, r, folders) +} + +func addFolder(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + + var folder vfs.BaseVirtualFolder + err = render.DecodeJSON(r.Body, &folder) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + if err := dataprovider.AddFolder(&folder, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + w.Header().Add("Location", fmt.Sprintf("%s/%s", folderPath, url.PathEscape(folder.Name))) + renderFolder(w, r, folder.Name, claims, http.StatusCreated) +} + +func updateFolder(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + + name := getURLParam(r, "name") + folder, err := dataprovider.GetFolderByName(name) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + 
+ var updatedFolder vfs.BaseVirtualFolder + err = render.DecodeJSON(r.Body, &updatedFolder) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + updatedFolder.ID = folder.ID + updatedFolder.Name = folder.Name + updatedFolder.FsConfig.SetEmptySecretsIfNil() + updateEncryptedSecrets(&updatedFolder.FsConfig, &folder.FsConfig) + + err = dataprovider.UpdateFolder(&updatedFolder, folder.Users, folder.Groups, claims.Username, + util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, nil, "Folder updated", http.StatusOK) +} + +func renderFolder(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) { + folder, err := dataprovider.GetFolderByName(name) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if hideConfidentialData(claims, r) { + folder.PrepareForRendering() + } + if status != http.StatusOK { + ctx := context.WithValue(r.Context(), render.StatusCtxKey, status) + render.JSON(w, r.WithContext(ctx), folder) + } else { + render.JSON(w, r, folder) + } +} + +func getFolderByName(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + name := getURLParam(r, "name") + renderFolder(w, r, name, claims, http.StatusOK) +} + +func deleteFolder(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + name := getURLParam(r, "name") + err = dataprovider.DeleteFolder(name, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), 
claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Folder deleted", http.StatusOK) +} diff --git a/internal/httpd/api_group.go b/internal/httpd/api_group.go new file mode 100644 index 00000000..beae9ad2 --- /dev/null +++ b/internal/httpd/api_group.go @@ -0,0 +1,144 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/go-chi/render" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func getGroups(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + limit, offset, order, err := getSearchFilters(w, r) + if err != nil { + return + } + + groups, err := dataprovider.GetGroups(limit, offset, order, false) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusInternalServerError) + return + } + render.JSON(w, r, groups) +} + +func addGroup(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + var group dataprovider.Group + err = render.DecodeJSON(r.Body, &group) + 
if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + err = dataprovider.AddGroup(&group, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + w.Header().Add("Location", fmt.Sprintf("%s/%s", groupPath, url.PathEscape(group.Name))) + renderGroup(w, r, group.Name, claims, http.StatusCreated) +} + +func updateGroup(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + + name := getURLParam(r, "name") + group, err := dataprovider.GroupExists(name) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + + var updatedGroup dataprovider.Group + err = render.DecodeJSON(r.Body, &updatedGroup) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + updatedGroup.ID = group.ID + updatedGroup.Name = group.Name + updatedGroup.UserSettings.FsConfig.SetEmptySecretsIfNil() + updateEncryptedSecrets(&updatedGroup.UserSettings.FsConfig, &group.UserSettings.FsConfig) + err = dataprovider.UpdateGroup(&updatedGroup, group.Users, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), + claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, nil, "Group updated", http.StatusOK) +} + +func renderGroup(w http.ResponseWriter, r *http.Request, name string, claims *jwt.Claims, status int) { + group, err := dataprovider.GroupExists(name) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if hideConfidentialData(claims, r) { + group.PrepareForRendering() + } + if status != http.StatusOK { + ctx := context.WithValue(r.Context(), render.StatusCtxKey, status) + render.JSON(w, 
r.WithContext(ctx), group) + } else { + render.JSON(w, r, group) + } +} + +func getGroupByName(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + name := getURLParam(r, "name") + renderGroup(w, r, name, claims, http.StatusOK) +} + +func deleteGroup(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + name := getURLParam(r, "name") + err = dataprovider.DeleteGroup(name, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Group deleted", http.StatusOK) +} diff --git a/internal/httpd/api_http_user.go b/internal/httpd/api_http_user.go new file mode 100644 index 00000000..45bbd52f --- /dev/null +++ b/internal/httpd/api_http_user.go @@ -0,0 +1,593 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "context" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path" + "strconv" + "strings" + + "github.com/go-chi/render" + "github.com/rs/xid" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, error) { + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return nil, fmt.Errorf("invalid token claims %w", err) + } + user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "") + if err != nil { + sendAPIResponse(w, r, nil, "Unable to retrieve your user", getRespStatus(err)) + return nil, err + } + connID := xid.New().String() + protocol := getProtocolFromRequest(r) + connectionID := fmt.Sprintf("%v_%v", protocol, connID) + if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil { + sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return nil, err + } + baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) + if err = common.Connections.Add(connection); err != nil { + sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) + return connection, err + } + return connection, nil +} + +func readUserFolder(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + lister, err := connection.ReadDir(name) + 
if err != nil { + sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err)) + return + } + renderAPIDirContents(w, lister, false) +} + +func createUserDir(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + connection.User.CheckFsRoot(connection.ID) //nolint:errcheck + name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + if getBoolQueryParam(r, "mkdir_parents") { + if err = connection.CheckParentDirs(path.Dir(name)); err != nil { + sendAPIResponse(w, r, err, "Error checking parent directories", getMappedStatusCode(err)) + return + } + } + err = connection.CreateDir(name, true) + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to create directory %q", name), getMappedStatusCode(err)) + return + } + sendAPIResponse(w, r, nil, fmt.Sprintf("Directory %q created", name), http.StatusCreated) +} + +func deleteUserDir(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + err = connection.RemoveAll(name) + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to delete directory %q", name), getMappedStatusCode(err)) + return + } + sendAPIResponse(w, r, nil, fmt.Sprintf("Directory %q deleted", name), http.StatusOK) +} + +func renameUserFsEntry(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + newName := 
connection.User.GetCleanedPath(r.URL.Query().Get("target")) + if !connection.IsSameResource(oldName, newName) { + if err := connection.Copy(oldName, newName); err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Cannot perform copy step to rename %q -> %q", oldName, newName), + getMappedStatusCode(err)) + return + } + if err := connection.RemoveAll(oldName); err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Cannot perform remove step to rename %q -> %q", oldName, newName), + getMappedStatusCode(err)) + return + } + } else { + if err := connection.Rename(oldName, newName); err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to rename %q => %q", oldName, newName), + getMappedStatusCode(err)) + return + } + } + sendAPIResponse(w, r, nil, fmt.Sprintf("%q renamed to %q", oldName, newName), http.StatusOK) +} + +func copyUserFsEntry(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + source := r.URL.Query().Get("path") + target := r.URL.Query().Get("target") + copyFromSource := strings.HasSuffix(source, "/") + copyInTarget := strings.HasSuffix(target, "/") + source = connection.User.GetCleanedPath(source) + target = connection.User.GetCleanedPath(target) + if copyFromSource { + source += "/" + } + if copyInTarget { + target += "/" + } + err = connection.Copy(source, target) + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to copy %q => %q", source, target), + getMappedStatusCode(err)) + return + } + sendAPIResponse(w, r, nil, fmt.Sprintf("%q copied to %q", source, target), http.StatusOK) +} + +func getUserFile(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + name := 
connection.User.GetCleanedPath(r.URL.Query().Get("path")) + if name == "/" { + sendAPIResponse(w, r, nil, "Please set the path to a valid file", http.StatusBadRequest) + return + } + info, err := connection.Stat(name, 0) + if err != nil { + sendAPIResponse(w, r, err, "Unable to stat the requested file", getMappedStatusCode(err)) + return + } + if info.IsDir() { + sendAPIResponse(w, r, nil, fmt.Sprintf("Please set the path to a valid file, %q is a directory", name), http.StatusBadRequest) + return + } + + inline := r.URL.Query().Get("inline") != "" + if status, err := downloadFile(w, r, connection, name, info, inline, nil); err != nil { + resp := apiResponse{ + Error: err.Error(), + Message: http.StatusText(status), + } + ctx := r.Context() + if status != 0 { + ctx = context.WithValue(ctx, render.StatusCtxKey, status) + } + render.JSON(w, r.WithContext(ctx), resp) + } +} + +func setFileDirMetadata(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + metadata := make(map[string]int64) + err := render.DecodeJSON(r.Body, &metadata) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + mTime, ok := metadata["modification_time"] + if !ok || !r.URL.Query().Has("path") { + sendAPIResponse(w, r, errors.New("please set a modification_time and a path"), "", http.StatusBadRequest) + return + } + + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + attrs := common.StatAttributes{ + Flags: common.StatAttrTimes, + Atime: util.GetTimeFromMsecSinceEpoch(mTime), + Mtime: util.GetTimeFromMsecSinceEpoch(mTime), + } + err = connection.SetStat(name, &attrs) + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to set metadata for path %q", name), getMappedStatusCode(err)) + return + } + sendAPIResponse(w, r, nil, "OK", http.StatusOK) +} + 
+func uploadUserFile(w http.ResponseWriter, r *http.Request) { + if maxUploadFileSize > 0 { + r.Body = http.MaxBytesReader(w, r.Body, maxUploadFileSize) + } + + if !r.URL.Query().Has("path") { + sendAPIResponse(w, r, errors.New("please set a file path"), "", http.StatusBadRequest) + return + } + + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + connection.User.CheckFsRoot(connection.ID) //nolint:errcheck + filePath := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + if getBoolQueryParam(r, "mkdir_parents") { + if err = connection.CheckParentDirs(path.Dir(filePath)); err != nil { + sendAPIResponse(w, r, err, "Error checking parent directories", getMappedStatusCode(err)) + return + } + } + doUploadFile(w, r, connection, filePath) //nolint:errcheck +} + +func doUploadFile(w http.ResponseWriter, r *http.Request, connection *Connection, filePath string) error { + writer, err := connection.getFileWriter(filePath) + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to write file %q", filePath), getMappedStatusCode(err)) + return err + } + _, err = io.Copy(writer, r.Body) + if err != nil { + writer.Close() //nolint:errcheck + sendAPIResponse(w, r, err, fmt.Sprintf("Error saving file %q", filePath), getMappedStatusCode(err)) + return err + } + err = writer.Close() + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Error closing file %q", filePath), getMappedStatusCode(err)) + return err + } + setModificationTimeFromHeader(r, connection, filePath) + sendAPIResponse(w, r, nil, "Upload completed", http.StatusCreated) + return nil +} + +func uploadUserFiles(w http.ResponseWriter, r *http.Request) { + if maxUploadFileSize > 0 { + r.Body = http.MaxBytesReader(w, r.Body, maxUploadFileSize) + } + + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + if err := 
common.Connections.IsNewTransferAllowed(connection.User.Username); err != nil { + connection.Log(logger.LevelInfo, "denying file write due to number of transfer limits") + sendAPIResponse(w, r, err, "Denying file write due to transfer count limits", + http.StatusConflict) + return + } + + transferQuota := connection.GetTransferQuota() + if !transferQuota.HasUploadSpace() { + connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits") + sendAPIResponse(w, r, common.ErrQuotaExceeded, "Denying file write due to transfer quota limits", + http.StatusRequestEntityTooLarge) + return + } + + t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection) + r.Body = t + err = r.ParseMultipartForm(maxMultipartMem) + if err != nil { + connection.RemoveTransfer(t) + sendAPIResponse(w, r, err, "Unable to parse multipart form", http.StatusBadRequest) + return + } + connection.RemoveTransfer(t) + defer r.MultipartForm.RemoveAll() //nolint:errcheck + + parentDir := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + files := r.MultipartForm.File["filenames"] + if len(files) == 0 { + sendAPIResponse(w, r, nil, "No files uploaded!", http.StatusBadRequest) + return + } + connection.User.CheckFsRoot(connection.ID) //nolint:errcheck + if getBoolQueryParam(r, "mkdir_parents") { + if err = connection.CheckParentDirs(parentDir); err != nil { + sendAPIResponse(w, r, err, "Error checking parent directories", getMappedStatusCode(err)) + return + } + } + doUploadFiles(w, r, connection, parentDir, files) +} + +func doUploadFiles(w http.ResponseWriter, r *http.Request, connection *Connection, parentDir string, + files []*multipart.FileHeader, +) int { + uploaded := 0 + connection.User.UploadBandwidth = 0 + for _, f := range files { + file, err := f.Open() + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to read uploaded file %q", f.Filename), getMappedStatusCode(err)) + return uploaded + } + defer file.Close() + + filePath := 
path.Join(parentDir, path.Base(util.CleanPath(f.Filename))) + writer, err := connection.getFileWriter(filePath) + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to write file %q", f.Filename), getMappedStatusCode(err)) + return uploaded + } + _, err = io.Copy(writer, file) + if err != nil { + writer.Close() //nolint:errcheck + sendAPIResponse(w, r, err, fmt.Sprintf("Error saving file %q", f.Filename), getMappedStatusCode(err)) + return uploaded + } + err = writer.Close() + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Error closing file %q", f.Filename), getMappedStatusCode(err)) + return uploaded + } + uploaded++ + } + sendAPIResponse(w, r, nil, "Upload completed", http.StatusCreated) + return uploaded +} + +func deleteUserFile(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + fs, p, err := connection.GetFsAndResolvedPath(name) + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to delete file %q", name), getMappedStatusCode(err)) + return + } + + var fi os.FileInfo + if fi, err = fs.Lstat(p); err != nil { + connection.Log(logger.LevelError, "failed to remove file %q: stat error: %+v", p, err) + err = connection.GetFsError(fs, err) + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to delete file %q", name), getMappedStatusCode(err)) + return + } + + if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 { + connection.Log(logger.LevelDebug, "cannot remove %q is not a file/symlink", p) + sendAPIResponse(w, r, err, fmt.Sprintf("Unable delete %q, it is not a file/symlink", name), http.StatusBadRequest) + return + } + err = connection.RemoveFile(fs, p, name, fi) + if err != nil { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to delete file %q", name), getMappedStatusCode(err)) + 
return + } + sendAPIResponse(w, r, nil, fmt.Sprintf("File %q deleted", name), http.StatusOK) +} + +func getUserFilesAsZipStream(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + var filesList []string + err = render.DecodeJSON(r.Body, &filesList) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + + baseDir := "/" + for idx := range filesList { + filesList[idx] = util.CleanPath(filesList[idx]) + } + + filesList = util.RemoveDuplicates(filesList, false) + + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", + getCompressedFileName(connection.GetUsername(), filesList))) + renderCompressedFiles(w, connection, baseDir, filesList, nil) +} + +func getUserProfile(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + user, err := dataprovider.UserExists(claims.Username, "") + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + resp := userProfile{ + baseProfile: baseProfile{ + Email: user.Email, + Description: user.Description, + AllowAPIKeyAuth: user.Filters.AllowAPIKeyAuth, + }, + AdditionalEmails: user.Filters.AdditionalEmails, + PublicKeys: user.PublicKeys, + TLSCerts: user.Filters.TLSCerts, + } + render.JSON(w, r, resp) +} + +func updateUserProfile(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + var req userProfile + err = 
render.DecodeJSON(r.Body, &req) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + user, userMerged, err := dataprovider.GetUserVariants(claims.Username, "") + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if !userMerged.CanUpdateProfile() { + sendAPIResponse(w, r, nil, "You are not allowed to change anything", http.StatusForbidden) + return + } + if userMerged.CanManagePublicKeys() { + user.PublicKeys = req.PublicKeys + } + if userMerged.CanManageTLSCerts() { + user.Filters.TLSCerts = req.TLSCerts + } + if userMerged.CanChangeAPIKeyAuth() { + user.Filters.AllowAPIKeyAuth = req.AllowAPIKeyAuth + } + if userMerged.CanChangeInfo() { + user.Email = req.Email + user.Filters.AdditionalEmails = req.AdditionalEmails + user.Description = req.Description + } + if err := dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), user.Role); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Profile updated", http.StatusOK) +} + +func changeUserPassword(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + var pwd pwdChange + err := render.DecodeJSON(r.Body, &pwd) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + err = doChangeUserPassword(r, pwd.CurrentPassword, pwd.NewPassword, pwd.NewPassword) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + invalidateToken(r) + sendAPIResponse(w, r, err, "Password updated", http.StatusOK) +} + +func doChangeUserPassword(r *http.Request, currentPassword, newPassword, confirmNewPassword string) error { + if currentPassword == "" || newPassword == "" || confirmNewPassword == "" { + return util.NewI18nError( + util.NewValidationError("please provide the current password and the new one two times"), + util.I18nErrorChangePwdRequiredFields, + 
) + } + if newPassword != confirmNewPassword { + return util.NewI18nError(util.NewValidationError("the two password fields do not match"), util.I18nErrorChangePwdNoMatch) + } + if currentPassword == newPassword { + return util.NewI18nError( + util.NewValidationError("the new password must be different from the current one"), + util.I18nErrorChangePwdNoDifferent, + ) + } + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + return util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken) + } + _, err = dataprovider.CheckUserAndPass(claims.Username, currentPassword, util.GetIPFromRemoteAddress(r.RemoteAddr), + getProtocolFromRequest(r)) + if err != nil { + return util.NewI18nError(util.NewValidationError("current password does not match"), util.I18nErrorChangePwdCurrentNoMatch) + } + + return dataprovider.UpdateUserPassword(claims.Username, newPassword, dataprovider.ActionExecutorSelf, + util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) +} + +func setModificationTimeFromHeader(r *http.Request, c *Connection, filePath string) { + mTimeString := r.Header.Get(mTimeHeader) + if mTimeString != "" { + // we don't return an error here if we fail to set the modification time + mTime, err := strconv.ParseInt(mTimeString, 10, 64) + if err == nil { + attrs := common.StatAttributes{ + Flags: common.StatAttrTimes, + Atime: util.GetTimeFromMsecSinceEpoch(mTime), + Mtime: util.GetTimeFromMsecSinceEpoch(mTime), + } + err = c.SetStat(filePath, &attrs) + c.Log(logger.LevelDebug, "requested modification time %v for file %q, error: %v", + attrs.Mtime, filePath, err) + } else { + c.Log(logger.LevelInfo, "invalid modification time header was ignored: %v", mTimeString) + } + } +} diff --git a/internal/httpd/api_iplist.go b/internal/httpd/api_iplist.go new file mode 100644 index 00000000..d74b68d8 --- /dev/null +++ b/internal/httpd/api_iplist.go @@ -0,0 +1,158 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: 
you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "strconv" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func getIPListEntries(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + limit, _, order, err := getSearchFilters(w, r) + if err != nil { + return + } + listType, _, err := getIPListPathParams(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + entries, err := dataprovider.GetIPListEntries(listType, r.URL.Query().Get("filter"), r.URL.Query().Get("from"), + order, limit) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + render.JSON(w, r, entries) +} + +func getIPListEntry(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + listType, ipOrNet, err := getIPListPathParams(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + + entry, err := dataprovider.IPListEntryExists(ipOrNet, listType) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + render.JSON(w, r, entry) +} + +func addIPListEntry(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + 
claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + var entry dataprovider.IPListEntry + err = render.DecodeJSON(r.Body, &entry) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + err = dataprovider.AddIPListEntry(&entry, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + w.Header().Add("Location", fmt.Sprintf("%s/%d/%s", ipListsPath, entry.Type, url.PathEscape(entry.IPOrNet))) + sendAPIResponse(w, r, nil, "Entry added", http.StatusCreated) +} + +func updateIPListEntry(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + listType, ipOrNet, err := getIPListPathParams(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + entry, err := dataprovider.IPListEntryExists(ipOrNet, listType) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + var updatedEntry dataprovider.IPListEntry + err = render.DecodeJSON(r.Body, &updatedEntry) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + updatedEntry.Type = entry.Type + updatedEntry.IPOrNet = entry.IPOrNet + err = dataprovider.UpdateIPListEntry(&updatedEntry, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, nil, "Entry updated", http.StatusOK) +} + +func deleteIPListEntry(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + claims, err := 
jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + listType, ipOrNet, err := getIPListPathParams(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + err = dataprovider.DeleteIPListEntry(ipOrNet, listType, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), + claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Entry deleted", http.StatusOK) +} + +func getIPListPathParams(r *http.Request) (dataprovider.IPListType, string, error) { + listTypeString := chi.URLParam(r, "type") + listType, err := strconv.Atoi(listTypeString) + if err != nil { + return dataprovider.IPListType(listType), "", errors.New("invalid list type") + } + if err := dataprovider.CheckIPListType(dataprovider.IPListType(listType)); err != nil { + return dataprovider.IPListType(listType), "", err + } + return dataprovider.IPListType(listType), getURLParam(r, "ipornet"), nil +} diff --git a/internal/httpd/api_keys.go b/internal/httpd/api_keys.go new file mode 100644 index 00000000..a9a3bcb2 --- /dev/null +++ b/internal/httpd/api_keys.go @@ -0,0 +1,135 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/go-chi/render" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func getAPIKeys(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + limit, offset, order, err := getSearchFilters(w, r) + if err != nil { + return + } + + apiKeys, err := dataprovider.GetAPIKeys(limit, offset, order) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + render.JSON(w, r, apiKeys) +} + +func getAPIKeyByID(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + keyID := getURLParam(r, "id") + apiKey, err := dataprovider.APIKeyExists(keyID) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + apiKey.HideConfidentialData() + + render.JSON(w, r, apiKey) +} + +func addAPIKey(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + var apiKey dataprovider.APIKey + err = render.DecodeJSON(r.Body, &apiKey) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + apiKey.ID = 0 + apiKey.KeyID = "" + apiKey.Key = "" + apiKey.LastUseAt = 0 + err = dataprovider.AddAPIKey(&apiKey, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + response := make(map[string]string) + response["message"] = "API key created. This is the only time the API key is visible, please save it." 
+ response["key"] = apiKey.DisplayKey() + w.Header().Add("Location", fmt.Sprintf("%s/%s", apiKeysPath, url.PathEscape(apiKey.KeyID))) + w.Header().Add("X-Object-ID", apiKey.KeyID) + ctx := context.WithValue(r.Context(), render.StatusCtxKey, http.StatusCreated) + render.JSON(w, r.WithContext(ctx), response) +} + +func updateAPIKey(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + keyID := getURLParam(r, "id") + apiKey, err := dataprovider.APIKeyExists(keyID) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + + var updatedAPIKey dataprovider.APIKey + err = render.DecodeJSON(r.Body, &updatedAPIKey) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + + updatedAPIKey.KeyID = keyID + updatedAPIKey.Key = apiKey.Key + err = dataprovider.UpdateAPIKey(&updatedAPIKey, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, nil, "API key updated", http.StatusOK) +} + +func deleteAPIKey(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + keyID := getURLParam(r, "id") + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + + err = dataprovider.DeleteAPIKey(keyID, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "API key deleted", http.StatusOK) +} diff --git a/internal/httpd/api_maintenance.go b/internal/httpd/api_maintenance.go new file mode 100644 index 
00000000..560c702e --- /dev/null +++ b/internal/httpd/api_maintenance.go @@ -0,0 +1,563 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/go-chi/render" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +func validateBackupFile(outputFile string) (string, error) { + if outputFile == "" { + return "", errors.New("invalid or missing output-file") + } + if filepath.IsAbs(outputFile) { + return "", fmt.Errorf("invalid output-file %q: it must be a relative path", outputFile) + } + if strings.Contains(outputFile, "..") { + return "", fmt.Errorf("invalid output-file %q", outputFile) + } + outputFile = filepath.Join(dataprovider.GetBackupsPath(), outputFile) + return outputFile, nil +} + +func dumpData(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + var outputFile, outputData, indent string + var scopes []string + if _, ok := r.URL.Query()["output-file"]; ok { + outputFile = strings.TrimSpace(r.URL.Query().Get("output-file")) + } + if _, ok := 
r.URL.Query()["output-data"]; ok { + outputData = strings.TrimSpace(r.URL.Query().Get("output-data")) + } + if _, ok := r.URL.Query()["indent"]; ok { + indent = strings.TrimSpace(r.URL.Query().Get("indent")) + } + if _, ok := r.URL.Query()["scopes"]; ok { + scopes = getCommaSeparatedQueryParam(r, "scopes") + } + + if outputData != "1" { + var err error + outputFile, err = validateBackupFile(outputFile) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + + err = os.MkdirAll(filepath.Dir(outputFile), 0700) + if err != nil { + logger.Error(logSender, "", "dumping data error: %v, output file: %q", err, outputFile) + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + logger.Debug(logSender, "", "dumping data to: %q", outputFile) + } + + backup, err := dataprovider.DumpData(scopes) + if err != nil { + logger.Error(logSender, "", "dumping data error: %v, output file: %q", err, outputFile) + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + + if outputData == "1" { + w.Header().Set("Content-Disposition", "attachment; filename=\"sftpgo-backup.json\"") + render.JSON(w, r, backup) + return + } + + var dump []byte + if indent == "1" { + dump, err = json.MarshalIndent(backup, "", " ") + } else { + dump, err = json.Marshal(backup) + } + if err == nil { + err = os.WriteFile(outputFile, dump, 0600) + } + if err != nil { + logger.Warn(logSender, "", "dumping data error: %v, output file: %q", err, outputFile) + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + logger.Debug(logSender, "", "dumping data completed, output file: %q, error: %v", outputFile, err) + sendAPIResponse(w, r, err, "Data saved", http.StatusOK) +} + +func loadDataFromRequest(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, MaxRestoreSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", 
http.StatusBadRequest) + return + } + _, scanQuota, mode, err := getLoaddataOptions(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + + content, err := io.ReadAll(r.Body) + if err != nil || len(content) == 0 { + if len(content) == 0 { + err = util.NewValidationError("request body is required") + } + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if err := restoreBackup(content, "", scanQuota, mode, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Data restored", http.StatusOK) +} + +func loadData(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + inputFile, scanQuota, mode, err := getLoaddataOptions(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + if !filepath.IsAbs(inputFile) { + sendAPIResponse(w, r, fmt.Errorf("invalid input_file %q: it must be an absolute path", inputFile), "", + http.StatusBadRequest) + return + } + fi, err := os.Stat(inputFile) + if err != nil { + sendAPIResponse(w, r, fmt.Errorf("invalid input_file %q", inputFile), "", http.StatusBadRequest) + return + } + if fi.Size() > MaxRestoreSize { + sendAPIResponse(w, r, err, fmt.Sprintf("Unable to restore input file: %q size too big: %d/%d bytes", + inputFile, fi.Size(), MaxRestoreSize), http.StatusBadRequest) + return + } + + content, err := os.ReadFile(inputFile) + if err != nil { + sendAPIResponse(w, r, fmt.Errorf("invalid input_file %q", inputFile), "", http.StatusBadRequest) + return + } + if err := restoreBackup(content, inputFile, scanQuota, mode, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil { + 
sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Data restored", http.StatusOK) +} + +func restoreBackup(content []byte, inputFile string, scanQuota, mode int, executor, ipAddress, role string) error { + dump, err := dataprovider.ParseDumpData(content) + if err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("invalid input_file %q", inputFile)), + util.I18nErrorBackupFile, + ) + } + + if err = RestoreConfigs(dump.Configs, mode, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreIPListEntries(dump.IPLists, inputFile, mode, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreRoles(dump.Roles, inputFile, mode, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreFolders(dump.Folders, inputFile, mode, scanQuota, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreGroups(dump.Groups, inputFile, mode, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreUsers(dump.Users, inputFile, mode, scanQuota, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreAdmins(dump.Admins, inputFile, mode, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreAPIKeys(dump.APIKeys, inputFile, mode, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreShares(dump.Shares, inputFile, mode, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreEventActions(dump.EventActions, inputFile, mode, executor, ipAddress, role); err != nil { + return err + } + + if err = RestoreEventRules(dump.EventRules, inputFile, mode, executor, ipAddress, role, dump.Version); err != nil { + return err + } + logger.Debug(logSender, "", "backup restored") + + return nil +} + +func getLoaddataOptions(r *http.Request) (string, int, int, error) { + var inputFile string + var err error + scanQuota := 0 + restoreMode := 0 + if _, 
ok := r.URL.Query()["input-file"]; ok { + inputFile = strings.TrimSpace(r.URL.Query().Get("input-file")) + } + if _, ok := r.URL.Query()["scan-quota"]; ok { + scanQuota, err = strconv.Atoi(r.URL.Query().Get("scan-quota")) + if err != nil { + err = fmt.Errorf("invalid scan_quota: %v", err) + return inputFile, scanQuota, restoreMode, err + } + } + if _, ok := r.URL.Query()["mode"]; ok { + restoreMode, err = strconv.Atoi(r.URL.Query().Get("mode")) + if err != nil { + err = fmt.Errorf("invalid mode: %v", err) + return inputFile, scanQuota, restoreMode, err + } + } + return inputFile, scanQuota, restoreMode, err +} + +// RestoreFolders restores the specified folders +func RestoreFolders(folders []vfs.BaseVirtualFolder, inputFile string, mode, scanQuota int, executor, ipAddress, role string) error { + for idx := range folders { + folder := folders[idx] + f, err := dataprovider.GetFolderByName(folder.Name) + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing folder %q not updated", folder.Name) + continue + } + folder.ID = f.ID + folder.Name = f.Name + err = dataprovider.UpdateFolder(&folder, f.Users, f.Groups, executor, ipAddress, role) + logger.Debug(logSender, "", "restoring existing folder %q, dump file: %q, error: %v", folder.Name, inputFile, err) + } else { + folder.Users = nil + err = dataprovider.AddFolder(&folder, executor, ipAddress, role) + logger.Debug(logSender, "", "adding new folder %q, dump file: %q, error: %v", folder.Name, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore folder %q: %w", folder.Name, err) + } + if scanQuota >= 1 { + if common.QuotaScans.AddVFolderQuotaScan(folder.Name) { + logger.Debug(logSender, "", "starting quota scan for restored folder: %q", folder.Name) + go doFolderQuotaScan(folder) //nolint:errcheck + } + } + } + return nil +} + +// RestoreShares restores the specified shares +func RestoreShares(shares []dataprovider.Share, inputFile string, mode int, executor, + 
ipAddress, role string, +) error { + for idx := range shares { + share := shares[idx] + share.IsRestore = true + s, err := dataprovider.ShareExists(share.ShareID, "") + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing share %q not updated", share.ShareID) + continue + } + share.ID = s.ID + err = dataprovider.UpdateShare(&share, executor, ipAddress, role) + logger.Debug(logSender, "", "restoring existing share %q, dump file: %q, error: %v", share.ShareID, inputFile, err) + } else { + err = dataprovider.AddShare(&share, executor, ipAddress, role) + logger.Debug(logSender, "", "adding new share %q, dump file: %q, error: %v", share.ShareID, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore share %q: %w", share.ShareID, err) + } + } + return nil +} + +// RestoreEventActions restores the specified event actions +func RestoreEventActions(actions []dataprovider.BaseEventAction, inputFile string, mode int, executor, ipAddress, role string) error { + for idx := range actions { + action := actions[idx] + a, err := dataprovider.EventActionExists(action.Name) + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing event action %q not updated", a.Name) + continue + } + action.ID = a.ID + err = dataprovider.UpdateEventAction(&action, executor, ipAddress, role) + logger.Debug(logSender, "", "restoring event action %q, dump file: %q, error: %v", action.Name, inputFile, err) + } else { + err = dataprovider.AddEventAction(&action, executor, ipAddress, role) + logger.Debug(logSender, "", "adding new event action %q, dump file: %q, error: %v", action.Name, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore event action %q: %w", action.Name, err) + } + } + return nil +} + +// RestoreEventRules restores the specified event rules +func RestoreEventRules(rules []dataprovider.EventRule, inputFile string, mode int, executor, ipAddress, + role string, dumpVersion 
int, +) error { + for idx := range rules { + rule := rules[idx] + if dumpVersion < 15 { + rule.Status = 1 + } + r, err := dataprovider.EventRuleExists(rule.Name) + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing event rule %q not updated", r.Name) + continue + } + rule.ID = r.ID + err = dataprovider.UpdateEventRule(&rule, executor, ipAddress, role) + logger.Debug(logSender, "", "restoring event rule %q, dump file: %q, error: %v", rule.Name, inputFile, err) + } else { + err = dataprovider.AddEventRule(&rule, executor, ipAddress, role) + logger.Debug(logSender, "", "adding new event rule %q, dump file: %q, error: %v", rule.Name, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore event rule %q: %w", rule.Name, err) + } + } + return nil +} + +// RestoreAPIKeys restores the specified API keys +func RestoreAPIKeys(apiKeys []dataprovider.APIKey, inputFile string, mode int, executor, ipAddress, role string) error { + for idx := range apiKeys { + apiKey := apiKeys[idx] + if apiKey.Key == "" { + logger.Warn(logSender, "", "cannot restore empty API key") + return fmt.Errorf("cannot restore an empty API key: %+v", apiKey) + } + k, err := dataprovider.APIKeyExists(apiKey.KeyID) + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing API key %q not updated", apiKey.KeyID) + continue + } + apiKey.ID = k.ID + err = dataprovider.UpdateAPIKey(&apiKey, executor, ipAddress, role) + logger.Debug(logSender, "", "restoring existing API key %q, dump file: %q, error: %v", apiKey.KeyID, inputFile, err) + } else { + err = dataprovider.AddAPIKey(&apiKey, executor, ipAddress, role) + logger.Debug(logSender, "", "adding new API key %q, dump file: %q, error: %v", apiKey.KeyID, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore API key %q: %w", apiKey.KeyID, err) + } + } + return nil +} + +// RestoreAdmins restores the specified admins +func RestoreAdmins(admins 
[]dataprovider.Admin, inputFile string, mode int, executor, ipAddress, role string) error { + for idx := range admins { + admin := admins[idx] + a, err := dataprovider.AdminExists(admin.Username) + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing admin %q not updated", a.Username) + continue + } + admin.ID = a.ID + admin.Username = a.Username + err = dataprovider.UpdateAdmin(&admin, executor, ipAddress, role) + logger.Debug(logSender, "", "restoring existing admin %q, dump file: %q, error: %v", admin.Username, inputFile, err) + } else { + err = dataprovider.AddAdmin(&admin, executor, ipAddress, role) + logger.Debug(logSender, "", "adding new admin %q, dump file: %q, error: %v", admin.Username, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore admin %q: %w", admin.Username, err) + } + } + + return nil +} + +// RestoreConfigs restores the specified provider configs +func RestoreConfigs(configs *dataprovider.Configs, mode int, executor, ipAddress, + executorRole string, +) error { + if configs == nil { + return nil + } + c, err := dataprovider.GetConfigs() + if err != nil { + return fmt.Errorf("unable to restore configs, error loading existing from db: %w", err) + } + if c.UpdatedAt > 0 { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing configs not updated") + return nil + } + } + return dataprovider.UpdateConfigs(configs, executor, ipAddress, executorRole) +} + +// RestoreIPListEntries restores the specified IP list entries +func RestoreIPListEntries(entries []dataprovider.IPListEntry, inputFile string, mode int, executor, ipAddress, + executorRole string, +) error { + for idx := range entries { + entry := entries[idx] + e, err := dataprovider.IPListEntryExists(entry.IPOrNet, entry.Type) + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing IP list entry %s-%s not updated", + e.Type.AsString(), e.IPOrNet) + continue + } + err = 
dataprovider.UpdateIPListEntry(&entry, executor, ipAddress, executorRole) + logger.Debug(logSender, "", "restoring existing IP list entry: %s-%s, dump file: %q, error: %v", + entry.Type.AsString(), entry.IPOrNet, inputFile, err) + } else { + err = dataprovider.AddIPListEntry(&entry, executor, ipAddress, executorRole) + logger.Debug(logSender, "", "adding new IP list entry %s-%s, dump file: %q, error: %v", + entry.Type.AsString(), entry.IPOrNet, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore IP list entry %s-%s: %w", entry.Type.AsString(), entry.IPOrNet, err) + } + } + return nil +} + +// RestoreRoles restores the specified roles +func RestoreRoles(roles []dataprovider.Role, inputFile string, mode int, executor, ipAddress, executorRole string) error { + for idx := range roles { + role := roles[idx] + r, err := dataprovider.RoleExists(role.Name) + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing role %q not updated", r.Name) + continue + } + role.ID = r.ID + err = dataprovider.UpdateRole(&role, executor, ipAddress, executorRole) + logger.Debug(logSender, "", "restoring existing role: %q, dump file: %q, error: %v", role.Name, inputFile, err) + } else { + err = dataprovider.AddRole(&role, executor, ipAddress, executorRole) + logger.Debug(logSender, "", "adding new role: %q, dump file: %q, error: %v", role.Name, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore role %q: %w", role.Name, err) + } + } + return nil +} + +// RestoreGroups restores the specified groups +func RestoreGroups(groups []dataprovider.Group, inputFile string, mode int, executor, ipAddress, role string) error { + for idx := range groups { + group := groups[idx] + g, err := dataprovider.GroupExists(group.Name) + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing group %q not updated", g.Name) + continue + } + group.ID = g.ID + group.Name = g.Name + err = 
dataprovider.UpdateGroup(&group, g.Users, executor, ipAddress, role) + logger.Debug(logSender, "", "restoring existing group: %q, dump file: %q, error: %v", group.Name, inputFile, err) + } else { + err = dataprovider.AddGroup(&group, executor, ipAddress, role) + logger.Debug(logSender, "", "adding new group: %q, dump file: %q, error: %v", group.Name, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore group %q: %w", group.Name, err) + } + } + return nil +} + +// RestoreUsers restores the specified users +func RestoreUsers(users []dataprovider.User, inputFile string, mode, scanQuota int, executor, ipAddress, role string) error { + for idx := range users { + user := users[idx] + u, err := dataprovider.UserExists(user.Username, "") + if err == nil { + if mode == 1 { + logger.Debug(logSender, "", "loaddata mode 1, existing user %q not updated", u.Username) + continue + } + user.ID = u.ID + user.Username = u.Username + err = dataprovider.UpdateUser(&user, executor, ipAddress, role) + logger.Debug(logSender, "", "restoring existing user: %q, dump file: %q, error: %v", user.Username, inputFile, err) + if mode == 2 && err == nil { + disconnectUser(user.Username, executor, role) + } + } else { + err = dataprovider.AddUser(&user, executor, ipAddress, role) + logger.Debug(logSender, "", "adding new user: %q, dump file: %q, error: %v", user.Username, inputFile, err) + } + if err != nil { + return fmt.Errorf("unable to restore user %q: %w", user.Username, err) + } + if scanQuota == 1 || (scanQuota == 2 && user.HasQuotaRestrictions()) { + user, err = dataprovider.GetUserWithGroupSettings(user.Username, "") + if err == nil && common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) { + logger.Debug(logSender, "", "starting quota scan for restored user: %q", user.Username) + go doUserQuotaScan(&user) //nolint:errcheck + } + } + } + return nil +} diff --git a/internal/httpd/api_mfa.go b/internal/httpd/api_mfa.go new file mode 100644 index 
00000000..73df7593 --- /dev/null +++ b/internal/httpd/api_mfa.go @@ -0,0 +1,323 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/http" + "slices" + "strconv" + "strings" + + "github.com/go-chi/render" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + errRecoveryCodeForbidden = errors.New("recovery codes are not available with two-factor authentication disabled") +) + +type generateTOTPRequest struct { + ConfigName string `json:"config_name"` +} + +type generateTOTPResponse struct { + ConfigName string `json:"config_name"` + Issuer string `json:"issuer"` + Secret string `json:"secret"` + URL string `json:"url"` + QRCode []byte `json:"qr_code"` +} + +type validateTOTPRequest struct { + ConfigName string `json:"config_name"` + Passcode string `json:"passcode"` + Secret string `json:"secret"` +} + +type recoveryCode struct { + Code string `json:"code"` + Used bool `json:"used"` +} + +func getTOTPConfigs(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + render.JSON(w, r, mfa.GetAvailableTOTPConfigs()) +} + +func generateTOTPSecret(w http.ResponseWriter, r *http.Request) { + r.Body 
= http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + var accountName string + if hasUserAudience(claims) { + accountName = fmt.Sprintf("User %q", claims.Username) + } else { + accountName = fmt.Sprintf("Admin %q", claims.Username) + } + + var req generateTOTPRequest + err = render.DecodeJSON(r.Body, &req) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + configName, key, qrCode, err := mfa.GenerateTOTPSecret(req.ConfigName, accountName) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + render.JSON(w, r, generateTOTPResponse{ + ConfigName: configName, + Issuer: key.Issuer(), + Secret: key.Secret(), + URL: key.URL(), + QRCode: qrCode, + }) +} + +func getQRCode(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + img, err := mfa.GenerateQRCodeFromURL(r.URL.Query().Get("url"), 400, 400) + if err != nil { + sendAPIResponse(w, r, nil, "unable to generate qr code", http.StatusInternalServerError) + return + } + imgSize := int64(len(img)) + w.Header().Set("Content-Length", strconv.FormatInt(imgSize, 10)) + w.Header().Set("Content-Type", "image/png") + io.CopyN(w, bytes.NewBuffer(img), imgSize) //nolint:errcheck +} + +func saveTOTPConfig(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + recoveryCodes := make([]dataprovider.RecoveryCode, 0, 12) + for i := 0; i < 12; i++ { + code := getNewRecoveryCode() + recoveryCodes = append(recoveryCodes, dataprovider.RecoveryCode{Secret: kms.NewPlainSecret(code)}) + } + baseURL := webBaseClientPath + if 
hasUserAudience(claims) { + if err := saveUserTOTPConfig(claims.Username, r, recoveryCodes); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + } else { + if err := saveAdminTOTPConfig(claims.Username, r, recoveryCodes); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + baseURL = webBasePath + } + if claims.MustSetTwoFactorAuth { + // force logout + defer func() { + removeCookie(w, r, baseURL) + }() + } + + sendAPIResponse(w, r, nil, "TOTP configuration saved", http.StatusOK) +} + +func validateTOTPPasscode(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + var req validateTOTPRequest + err := render.DecodeJSON(r.Body, &req) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + match, err := mfa.ValidateTOTPPasscode(req.ConfigName, req.Passcode, req.Secret) + if !match || err != nil { + sendAPIResponse(w, r, err, "Invalid passcode", http.StatusBadRequest) + return + } + sendAPIResponse(w, r, nil, "Passcode successfully validated", http.StatusOK) +} + +func getRecoveryCodes(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + recoveryCodes := make([]recoveryCode, 0, 12) + var accountRecoveryCodes []dataprovider.RecoveryCode + if hasUserAudience(claims) { + user, err := dataprovider.UserExists(claims.Username, "") + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if !user.Filters.TOTPConfig.Enabled { + sendAPIResponse(w, r, errRecoveryCodeForbidden, "", http.StatusForbidden) + return + } + accountRecoveryCodes = user.Filters.RecoveryCodes + } else { + admin, err := dataprovider.AdminExists(claims.Username) + if err != nil { + sendAPIResponse(w, r, err, "", 
getRespStatus(err)) + return + } + if !admin.Filters.TOTPConfig.Enabled { + sendAPIResponse(w, r, errRecoveryCodeForbidden, "", http.StatusForbidden) + return + } + accountRecoveryCodes = admin.Filters.RecoveryCodes + } + + for _, code := range accountRecoveryCodes { + if err := code.Secret.Decrypt(); err != nil { + sendAPIResponse(w, r, err, "Unable to decrypt recovery codes", getRespStatus(err)) + return + } + recoveryCodes = append(recoveryCodes, recoveryCode{ + Code: code.Secret.GetPayload(), + Used: code.Used, + }) + } + render.JSON(w, r, recoveryCodes) +} + +func generateRecoveryCodes(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + recoveryCodes := make([]string, 0, 12) + accountRecoveryCodes := make([]dataprovider.RecoveryCode, 0, 12) + for i := 0; i < 12; i++ { + code := getNewRecoveryCode() + recoveryCodes = append(recoveryCodes, code) + accountRecoveryCodes = append(accountRecoveryCodes, dataprovider.RecoveryCode{Secret: kms.NewPlainSecret(code)}) + } + if hasUserAudience(claims) { + user, err := dataprovider.UserExists(claims.Username, "") + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if !user.Filters.TOTPConfig.Enabled { + sendAPIResponse(w, r, errRecoveryCodeForbidden, "", http.StatusForbidden) + return + } + user.Filters.RecoveryCodes = accountRecoveryCodes + if err := dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), user.Role); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + } else { + admin, err := dataprovider.AdminExists(claims.Username) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if !admin.Filters.TOTPConfig.Enabled { + sendAPIResponse(w, r, 
errRecoveryCodeForbidden, "", http.StatusForbidden) + return + } + admin.Filters.RecoveryCodes = accountRecoveryCodes + if err := dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), admin.Role); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + } + + render.JSON(w, r, recoveryCodes) +} + +func getNewRecoveryCode() string { + return fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID())) +} + +func saveUserTOTPConfig(username string, r *http.Request, recoveryCodes []dataprovider.RecoveryCode) error { + user, userMerged, err := dataprovider.GetUserVariants(username, "") + if err != nil { + return err + } + currentTOTPSecret := user.Filters.TOTPConfig.Secret + user.Filters.TOTPConfig.Secret = nil + err = render.DecodeJSON(r.Body, &user.Filters.TOTPConfig) + if err != nil { + return util.NewValidationError(fmt.Sprintf("unable to decode JSON body: %v", err)) + } + if !user.Filters.TOTPConfig.Enabled && len(userMerged.Filters.TwoFactorAuthProtocols) > 0 { + return util.NewValidationError("two-factor authentication must be enabled") + } + for _, p := range userMerged.Filters.TwoFactorAuthProtocols { + if !slices.Contains(user.Filters.TOTPConfig.Protocols, p) { + return util.NewValidationError(fmt.Sprintf("totp: the following protocols are required: %q", + strings.Join(userMerged.Filters.TwoFactorAuthProtocols, ", "))) + } + } + if user.Filters.TOTPConfig.Secret == nil || !user.Filters.TOTPConfig.Secret.IsPlain() { + user.Filters.TOTPConfig.Secret = currentTOTPSecret + } + if user.Filters.TOTPConfig.Enabled { + if user.CountUnusedRecoveryCodes() < 5 && user.Filters.TOTPConfig.Enabled { + user.Filters.RecoveryCodes = recoveryCodes + } + } else { + user.Filters.RecoveryCodes = nil + } + return dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), user.Role) +} + +func saveAdminTOTPConfig(username string, r *http.Request, 
recoveryCodes []dataprovider.RecoveryCode) error { + admin, err := dataprovider.AdminExists(username) + if err != nil { + return err + } + currentTOTPSecret := admin.Filters.TOTPConfig.Secret + admin.Filters.TOTPConfig.Secret = nil + err = render.DecodeJSON(r.Body, &admin.Filters.TOTPConfig) + if err != nil { + return util.NewValidationError(fmt.Sprintf("unable to decode JSON body: %v", err)) + } + if !admin.Filters.TOTPConfig.Enabled && admin.Filters.RequireTwoFactor { + return util.NewValidationError("two-factor authentication must be enabled") + } + if admin.Filters.TOTPConfig.Enabled { + if admin.CountUnusedRecoveryCodes() < 5 && admin.Filters.TOTPConfig.Enabled { + admin.Filters.RecoveryCodes = recoveryCodes + } + } else { + admin.Filters.RecoveryCodes = nil + } + if admin.Filters.TOTPConfig.Secret == nil || !admin.Filters.TOTPConfig.Secret.IsPlain() { + admin.Filters.TOTPConfig.Secret = currentTOTPSecret + } + return dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, util.GetIPFromRemoteAddress(r.RemoteAddr), admin.Role) +} diff --git a/internal/httpd/api_quota.go b/internal/httpd/api_quota.go new file mode 100644 index 00000000..db339e39 --- /dev/null +++ b/internal/httpd/api_quota.go @@ -0,0 +1,283 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "errors" + "fmt" + "net/http" + + "github.com/go-chi/render" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + quotaUpdateModeAdd = "add" + quotaUpdateModeReset = "reset" +) + +type quotaUsage struct { + UsedQuotaSize int64 `json:"used_quota_size"` + UsedQuotaFiles int `json:"used_quota_files"` +} + +type transferQuotaUsage struct { + UsedUploadDataTransfer int64 `json:"used_upload_data_transfer"` + UsedDownloadDataTransfer int64 `json:"used_download_data_transfer"` +} + +func getUsersQuotaScans(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + render.JSON(w, r, common.QuotaScans.GetUsersQuotaScans(claims.Role)) +} + +func getFoldersQuotaScans(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + render.JSON(w, r, common.QuotaScans.GetVFoldersQuotaScans()) +} + +func updateUserQuotaUsage(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + var usage quotaUsage + err := render.DecodeJSON(r.Body, &usage) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + doUpdateUserQuotaUsage(w, r, getURLParam(r, "username"), usage) +} + +func updateFolderQuotaUsage(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + var usage quotaUsage + err := render.DecodeJSON(r.Body, &usage) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + doUpdateFolderQuotaUsage(w, r, getURLParam(r, "name"), usage) +} + 
+func startUserQuotaScan(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + doStartUserQuotaScan(w, r, getURLParam(r, "username")) +} + +func startFolderQuotaScan(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + doStartFolderQuotaScan(w, r, getURLParam(r, "name")) +} + +func updateUserTransferQuotaUsage(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + var usage transferQuotaUsage + err = render.DecodeJSON(r.Body, &usage) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + if usage.UsedUploadDataTransfer < 0 || usage.UsedDownloadDataTransfer < 0 { + sendAPIResponse(w, r, errors.New("invalid used transfer quota parameters, negative values are not allowed"), + "", http.StatusBadRequest) + return + } + mode, err := getQuotaUpdateMode(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + user, err := dataprovider.GetUserWithGroupSettings(getURLParam(r, "username"), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if mode == quotaUpdateModeAdd && !user.HasTransferQuotaRestrictions() && dataprovider.GetQuotaTracking() == 2 { + sendAPIResponse(w, r, errors.New("this user has no transfer quota restrictions, only reset mode is supported"), + "", http.StatusBadRequest) + return + } + err = dataprovider.UpdateUserTransferQuota(&user, usage.UsedUploadDataTransfer, usage.UsedDownloadDataTransfer, + mode == quotaUpdateModeReset) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Quota updated", http.StatusOK) +} + +func doUpdateUserQuotaUsage(w http.ResponseWriter, r 
*http.Request, username string, usage quotaUsage) { + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + if usage.UsedQuotaFiles < 0 || usage.UsedQuotaSize < 0 { + sendAPIResponse(w, r, errors.New("invalid used quota parameters, negative values are not allowed"), + "", http.StatusBadRequest) + return + } + mode, err := getQuotaUpdateMode(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + user, err := dataprovider.GetUserWithGroupSettings(username, claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if mode == quotaUpdateModeAdd && !user.HasQuotaRestrictions() && dataprovider.GetQuotaTracking() == 2 { + sendAPIResponse(w, r, errors.New("this user has no quota restrictions, only reset mode is supported"), + "", http.StatusBadRequest) + return + } + if !common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) { + sendAPIResponse(w, r, err, "A quota scan is in progress for this user", http.StatusConflict) + return + } + defer common.QuotaScans.RemoveUserQuotaScan(user.Username) + err = dataprovider.UpdateUserQuota(&user, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Quota updated", http.StatusOK) +} + +func doUpdateFolderQuotaUsage(w http.ResponseWriter, r *http.Request, name string, usage quotaUsage) { + if usage.UsedQuotaFiles < 0 || usage.UsedQuotaSize < 0 { + sendAPIResponse(w, r, errors.New("invalid used quota parameters, negative values are not allowed"), + "", http.StatusBadRequest) + return + } + mode, err := getQuotaUpdateMode(r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + folder, err := dataprovider.GetFolderByName(name) + if err != nil { + sendAPIResponse(w, r, 
err, "", getRespStatus(err)) + return + } + if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) { + sendAPIResponse(w, r, err, "A quota scan is in progress for this folder", http.StatusConflict) + return + } + defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name) + err = dataprovider.UpdateVirtualFolderQuota(&folder, usage.UsedQuotaFiles, usage.UsedQuotaSize, mode == quotaUpdateModeReset) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + } else { + sendAPIResponse(w, r, err, "Quota updated", http.StatusOK) + } +} + +func doStartUserQuotaScan(w http.ResponseWriter, r *http.Request, username string) { + if dataprovider.GetQuotaTracking() == 0 { + sendAPIResponse(w, r, nil, "Quota tracking is disabled!", http.StatusForbidden) + return + } + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + user, err := dataprovider.GetUserWithGroupSettings(username, claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if !common.QuotaScans.AddUserQuotaScan(user.Username, user.Role) { + sendAPIResponse(w, r, nil, fmt.Sprintf("Another scan is already in progress for user %q", username), + http.StatusConflict) + return + } + go doUserQuotaScan(&user) //nolint:errcheck + sendAPIResponse(w, r, err, "Scan started", http.StatusAccepted) +} + +func doStartFolderQuotaScan(w http.ResponseWriter, r *http.Request, name string) { + if dataprovider.GetQuotaTracking() == 0 { + sendAPIResponse(w, r, nil, "Quota tracking is disabled!", http.StatusForbidden) + return + } + folder, err := dataprovider.GetFolderByName(name) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if !common.QuotaScans.AddVFolderQuotaScan(folder.Name) { + sendAPIResponse(w, r, err, fmt.Sprintf("Another scan is already in progress for folder %q", name), + http.StatusConflict) + return + } + 
go doFolderQuotaScan(folder) //nolint:errcheck + sendAPIResponse(w, r, err, "Scan started", http.StatusAccepted) +} + +func doUserQuotaScan(user *dataprovider.User) error { + defer common.QuotaScans.RemoveUserQuotaScan(user.Username) + numFiles, size, err := user.ScanQuota() + if err != nil { + logger.Warn(logSender, "", "error scanning user quota %q: %v", user.Username, err) + return err + } + err = dataprovider.UpdateUserQuota(user, numFiles, size, true) + logger.Debug(logSender, "", "user quota scanned, user: %q, error: %v", user.Username, err) + return err +} + +func doFolderQuotaScan(folder vfs.BaseVirtualFolder) error { + defer common.QuotaScans.RemoveVFolderQuotaScan(folder.Name) + f := vfs.VirtualFolder{ + BaseVirtualFolder: folder, + VirtualPath: "/", + } + numFiles, size, err := f.ScanQuota() + if err != nil { + logger.Warn(logSender, "", "error scanning folder %q: %v", folder.Name, err) + return err + } + err = dataprovider.UpdateVirtualFolderQuota(&folder, numFiles, size, true) + logger.Debug(logSender, "", "virtual folder %q scanned, error: %v", folder.Name, err) + return err +} + +func getQuotaUpdateMode(r *http.Request) (string, error) { + mode := quotaUpdateModeReset + if _, ok := r.URL.Query()["mode"]; ok { + mode = r.URL.Query().Get("mode") + if mode != quotaUpdateModeReset && mode != quotaUpdateModeAdd { + return "", errors.New("invalid mode") + } + } + return mode, nil +} diff --git a/internal/httpd/api_retention.go b/internal/httpd/api_retention.go new file mode 100644 index 00000000..1502a327 --- /dev/null +++ b/internal/httpd/api_retention.go @@ -0,0 +1,34 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "net/http" + + "github.com/go-chi/render" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/jwt" +) + +func getRetentionChecks(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + render.JSON(w, r, common.RetentionChecks.Get(claims.Role)) +} diff --git a/internal/httpd/api_role.go b/internal/httpd/api_role.go new file mode 100644 index 00000000..d8840155 --- /dev/null +++ b/internal/httpd/api_role.go @@ -0,0 +1,135 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/go-chi/render" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func getRoles(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + limit, offset, order, err := getSearchFilters(w, r) + if err != nil { + return + } + + roles, err := dataprovider.GetRoles(limit, offset, order, false) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusInternalServerError) + return + } + render.JSON(w, r, roles) +} + +func addRole(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + + var role dataprovider.Role + err = render.DecodeJSON(r.Body, &role) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + err = dataprovider.AddRole(&role, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + } else { + w.Header().Add("Location", fmt.Sprintf("%s/%s", rolesPath, url.PathEscape(role.Name))) + renderRole(w, r, role.Name, http.StatusCreated) + } +} + +func updateRole(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + + name := getURLParam(r, "name") + role, err := dataprovider.RoleExists(name) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + + var updatedRole dataprovider.Role + err = render.DecodeJSON(r.Body, &updatedRole) + if 
err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + + updatedRole.ID = role.ID + updatedRole.Name = role.Name + err = dataprovider.UpdateRole(&updatedRole, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, nil, "Role updated", http.StatusOK) +} + +func renderRole(w http.ResponseWriter, r *http.Request, name string, status int) { + role, err := dataprovider.RoleExists(name) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + if status != http.StatusOK { + ctx := context.WithValue(r.Context(), render.StatusCtxKey, status) + render.JSON(w, r.WithContext(ctx), role) + } else { + render.JSON(w, r, role) + } +} + +func getRoleByName(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + name := getURLParam(r, "name") + renderRole(w, r, name, http.StatusOK) +} + +func deleteRole(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + name := getURLParam(r, "name") + err = dataprovider.DeleteRole(name, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + sendAPIResponse(w, r, err, "Role deleted", http.StatusOK) +} diff --git a/internal/httpd/api_shares.go b/internal/httpd/api_shares.go new file mode 100644 index 00000000..1ea1542e --- /dev/null +++ b/internal/httpd/api_shares.go @@ -0,0 +1,600 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software 
Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "path" + "slices" + "strings" + "time" + + "github.com/go-chi/render" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func getShares(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + limit, offset, order, err := getSearchFilters(w, r) + if err != nil { + return + } + + shares, err := dataprovider.GetShares(limit, offset, order, claims.Username) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + render.JSON(w, r, shares) +} + +func getShareByID(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + shareID := getURLParam(r, "id") + share, err := dataprovider.ShareExists(shareID, claims.Username) + if err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + share.HideConfidentialData() + + render.JSON(w, r, share) +} + 
// addShare creates a new share owned by the authenticated user.
// A default expiration is applied from the user's settings before decoding,
// so a client-supplied value overrides it. Shares without a password are
// rejected when the user's permissions forbid password-less sharing.
func addShare(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		sendAPIResponse(w, r, err, "Unable to retrieve your user", getRespStatus(err))
		return
	}
	var share dataprovider.Share
	// Pre-set the default expiration: DecodeJSON below overwrites it if the
	// request body specifies ExpiresAt.
	if user.Filters.DefaultSharesExpiration > 0 {
		share.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(user.Filters.DefaultSharesExpiration)))
	}
	err = render.DecodeJSON(r.Body, &share)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	if err := user.CheckMaxShareExpiration(util.GetTimeFromMsecSinceEpoch(share.ExpiresAt)); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	// Server-controlled fields: never trust client-provided values for these.
	share.ID = 0
	share.ShareID = util.GenerateUniqueID()
	share.LastUseAt = 0
	share.Username = claims.Username
	if share.Name == "" {
		share.Name = share.ShareID
	}
	if share.Password == "" {
		if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
			sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password",
				http.StatusForbidden)
			return
		}
	}
	err = dataprovider.AddShare(&share, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", userSharesPath, url.PathEscape(share.ShareID)))
	w.Header().Add("X-Object-ID", share.ShareID)
	sendAPIResponse(w, r, nil, "Share created", http.StatusCreated)
}

// updateShare modifies an existing share owned by the authenticated user.
// A redacted password placeholder in the payload keeps the stored password.
func updateShare(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		sendAPIResponse(w, r, err, "Unable to retrieve your user", getRespStatus(err))
		return
	}
	shareID := getURLParam(r, "id")
	// Ownership check: the share must exist for this username.
	share, err := dataprovider.ShareExists(shareID, claims.Username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var updatedShare dataprovider.Share
	err = render.DecodeJSON(r.Body, &updatedShare)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}

	updatedShare.ShareID = shareID
	updatedShare.Username = claims.Username
	// The API renders passwords redacted; keep the old one if unchanged.
	if updatedShare.Password == redactedSecret {
		updatedShare.Password = share.Password
	}
	if updatedShare.Password == "" {
		if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
			sendAPIResponse(w, r, nil, "You are not authorized to share files/folders without a password",
				http.StatusForbidden)
			return
		}
	}
	if err := user.CheckMaxShareExpiration(util.GetTimeFromMsecSinceEpoch(updatedShare.ExpiresAt)); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	err = dataprovider.UpdateShare(&updatedShare, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "Share updated", http.StatusOK)
}

// deleteShare removes a share owned by the authenticated user.
func deleteShare(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	shareID := getURLParam(r, "id")
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	err = dataprovider.DeleteShare(shareID, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "Share deleted", http.StatusOK)
}

// readBrowsableShareContents lists a directory inside a browsable (single
// path, directory-backed) public share.
func (s *httpdServer) readBrowsableShareContents(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		// checkPublicShare has already written the error response.
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	if err = common.Connections.Add(connection); err != nil {
		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	lister, err := connection.ReadDir(name)
	if err != nil {
		sendAPIResponse(w, r, err, "Unable to get directory lister", getMappedStatusCode(err))
		return
	}
	renderAPIDirContents(w, lister, true)
}

// downloadBrowsableSharedFile streams a single file from a browsable public
// share. The share usage counter is incremented before the download and
// rolled back if it fails.
func (s *httpdServer) downloadBrowsableSharedFile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	if err = common.Connections.Add(connection); err != nil {
		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	info, err := connection.Stat(name, 1)
	if err != nil {
		sendAPIResponse(w, r, err, "Unable to stat the requested file", getMappedStatusCode(err))
		return
	}
	if info.IsDir() {
		sendAPIResponse(w, r, nil, fmt.Sprintf("Please set the path to a valid file, %q is a directory", name),
			http.StatusBadRequest)
		return
	}

	inline := r.URL.Query().Get("inline") != ""
	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
	if status, err := downloadFile(w, r, connection, name, info, inline, &share); err != nil {
		// Download failed: undo the usage increment.
		dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
		resp := apiResponse{
			Error:   err.Error(),
			Message: http.StatusText(status),
		}
		ctx := r.Context()
		if status != 0 {
			ctx = context.WithValue(ctx, render.StatusCtxKey, status)
		}
		render.JSON(w, r.WithContext(ctx), resp)
	}
}

// downloadFromShare downloads a public share, as a zip archive by default,
// or as a plain file when the share is a single regular file and the client
// requested compress=false.
func (s *httpdServer) downloadFromShare(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}

	if err = common.Connections.Add(connection); err != nil {
		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	compress := true
	var info os.FileInfo
	if len(share.Paths) == 1 {
		info, err = connection.Stat(share.Paths[0], 1)
		if err != nil {
			sendAPIResponse(w, r, err, "", getRespStatus(err))
			return
		}
		// Uncompressed download is only possible for a single regular file.
		if info.Mode().IsRegular() && r.URL.Query().Get("compress") == "false" {
			compress = false
		}
	}

	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
	if compress {
		transferQuota := connection.GetTransferQuota()
		if !transferQuota.HasDownloadSpace() {
			err = connection.GetReadQuotaExceededError()
			connection.Log(logger.LevelInfo, "denying share read due to quota limits")
			sendAPIResponse(w, r, err, "", getMappedStatusCode(err))
			dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
			return
		}
		baseDir := "/"
		// For a single shared directory, zip its contents relative to it.
		if info != nil && info.IsDir() {
			baseDir = share.Paths[0]
			share.Paths[0] = "/"
		}
		w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"share-%v.zip\"", share.Name))
		renderCompressedFiles(w, connection, baseDir, share.Paths, &share)
		return
	}
	if status, err := downloadFile(w, r, connection, share.Paths[0], info, false, &share); err != nil {
		dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
		resp := apiResponse{
			Error:   err.Error(),
			Message: http.StatusText(status),
		}
		ctx := r.Context()
		if status != 0 {
			ctx = context.WithValue(ctx, render.StatusCtxKey, status)
		}
		render.JSON(w, r.WithContext(ctx), resp)
	}
}

// uploadFileToShare uploads a single named file into a write-enabled public
// share, rejecting paths that escape the shared directory.
func (s *httpdServer) uploadFileToShare(w http.ResponseWriter, r *http.Request) {
	if maxUploadFileSize > 0 {
		r.Body = http.MaxBytesReader(w, r.Body, maxUploadFileSize)
	}
	name := getURLParam(r, "name")
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeWrite, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	// Path-traversal guard: the cleaned target must stay under the share root.
	filePath := util.CleanPath(path.Join(share.Paths[0], name))
	expectedPrefix := share.Paths[0]
	if !strings.HasSuffix(expectedPrefix, "/") {
		expectedPrefix += "/"
	}
	if !strings.HasPrefix(filePath, expectedPrefix) {
		sendAPIResponse(w, r, err, "Uploading outside the share is not allowed", http.StatusForbidden)
		return
	}
	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck

	if err = common.Connections.Add(connection); err != nil {
		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	connection.User.CheckFsRoot(connection.ID) //nolint:errcheck
	if getBoolQueryParam(r, "mkdir_parents") {
		if err = connection.CheckParentDirs(path.Dir(filePath)); err != nil {
			sendAPIResponse(w, r, err, "Error checking parent directories", getMappedStatusCode(err))
			return
		}
	}
	if err := doUploadFile(w, r, connection, filePath); err != nil {
		// Upload failed: undo the usage increment.
		dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
	}
}

// uploadFilesToShare uploads multiple files (multipart form, field
// "filenames") into a write-enabled public share, enforcing transfer-count,
// transfer-quota and share-token limits.
func (s *httpdServer) uploadFilesToShare(w http.ResponseWriter, r *http.Request) {
	if maxUploadFileSize > 0 {
		r.Body = http.MaxBytesReader(w, r.Body, maxUploadFileSize)
	}
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeWrite, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if err := common.Connections.IsNewTransferAllowed(connection.User.Username); err != nil {
		connection.Log(logger.LevelInfo, "denying file write due to number of transfer limits")
		sendAPIResponse(w, r, err, "Denying file write due to transfer count limits",
			http.StatusConflict)
		return
	}

	transferQuota := connection.GetTransferQuota()
	if !transferQuota.HasUploadSpace() {
		connection.Log(logger.LevelInfo, "denying file write due to transfer quota limits")
		sendAPIResponse(w, r, common.ErrQuotaExceeded, "Denying file write due to transfer quota limits",
			http.StatusRequestEntityTooLarge)
		return
	}

	if err = common.Connections.Add(connection); err != nil {
		sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	// Throttle form parsing with the user's upload bandwidth limit.
	t := newThrottledReader(r.Body, connection.User.UploadBandwidth, connection)
	r.Body = t
	err = r.ParseMultipartForm(maxMultipartMem)
	if err != nil {
		connection.RemoveTransfer(t)
		sendAPIResponse(w, r, err, "Unable to parse multipart form", http.StatusBadRequest)
		return
	}
	connection.RemoveTransfer(t)
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	files := r.MultipartForm.File["filenames"]
	if len(files) == 0 {
		sendAPIResponse(w, r, nil, "No files uploaded!", http.StatusBadRequest)
		return
	}
	if share.MaxTokens > 0 {
		if len(files) > (share.MaxTokens - share.UsedTokens) {
			sendAPIResponse(w, r, nil, "Allowed usage exceeded", http.StatusBadRequest)
			return
		}
	}
	dataprovider.UpdateShareLastUse(&share, len(files)) //nolint:errcheck

	connection.User.CheckFsRoot(connection.ID) //nolint:errcheck
	numUploads := doUploadFiles(w, r, connection, share.Paths[0], files)
	// Roll back the usage counter for files that failed to upload.
	if numUploads != len(files) {
		dataprovider.UpdateShareLastUse(&share, numUploads-len(files)) //nolint:errcheck
	}
}

// getShareClaims validates the web-share cookie token for the given share ID
// (audience, invalidation list, client IP binding, share match) and returns a
// context carrying the verified token.
func (s *httpdServer) getShareClaims(r *http.Request, shareID string) (context.Context, *jwt.Claims, error) {
	token, err := jwt.VerifyRequest(s.tokenAuth, r, jwt.TokenFromCookie)
	if err != nil || token == nil {
		return nil, nil, errInvalidToken
	}
	tokenString := jwt.TokenFromCookie(r)
	if tokenString == "" || invalidatedJWTTokens.Get(tokenString) {
		return nil, nil, errInvalidToken
	}
	if !token.Audience.Contains(tokenAudienceWebShare) {
		logger.Debug(logSender, "", "invalid token audience for share %q", shareID)
		return nil, nil, errInvalidToken
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := validateIPForToken(token, ipAddr); err != nil {
		logger.Debug(logSender, "", "token for share %q is not valid for the ip address %q", shareID, ipAddr)
		return nil, nil, err
	}
	// Web-share tokens store the share ID in the username claim.
	if token.Username != shareID {
		logger.Debug(logSender, "", "token not valid for share %q", shareID)
		return nil, nil, errInvalidToken
	}
	ctx := jwt.NewContext(r.Context(), token, nil)
	return ctx, token, nil
}

// checkWebClientShareCredentials verifies the share token for a WebClient
// request, redirecting to the share login page on failure.
func (s *httpdServer) checkWebClientShareCredentials(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) error {
	doRedirect := func() {
		redirectURL := path.Join(webClientPubSharesPath, share.ShareID, fmt.Sprintf("login?next=%s", url.QueryEscape(r.RequestURI)))
		http.Redirect(w, r, redirectURL, http.StatusFound)
	}

	if _, _, err := s.getShareClaims(r, share.ShareID); err != nil {
		doRedirect()
		return err
	}
	return nil
}

// checkPublicShare validates a public share request end to end: existence,
// scope, usability (expiration/tokens/IP), password (cookie token for web
// clients, HTTP Basic for the API) and the owner's account state. On success
// it returns the share and a new connection impersonating the share owner;
// on failure it writes the error response itself.
func (s *httpdServer) checkPublicShare(w http.ResponseWriter, r *http.Request, validScopes []dataprovider.ShareScope,
) (dataprovider.Share, *Connection, error) {
	isWebClient := isWebClientRequest(r)
	renderError := func(err error, message string, statusCode int) {
		if isWebClient {
			s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, statusCode, err, message)
		} else {
			sendAPIResponse(w, r, err, message, statusCode)
		}
	}

	shareID := getURLParam(r, "id")
	share, err := dataprovider.ShareExists(shareID, "")
	if err != nil {
		statusCode := getRespStatus(err)
		if statusCode == http.StatusNotFound {
			err = util.NewI18nError(errors.New("share does not exist"), util.I18nError404Message)
		}
		renderError(err, "", statusCode)
		return share, nil, err
	}
	if !slices.Contains(validScopes, share.Scope) {
		err := errors.New("invalid share scope")
		renderError(util.NewI18nError(err, util.I18nErrorShareScope), "", http.StatusForbidden)
		return share, nil, err
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	ok, err := share.IsUsable(ipAddr)
	if !ok || err != nil {
		renderError(err, "", getRespStatus(err))
		return share, nil, err
	}
	if share.Password != "" {
		if isWebClient {
			if err := s.checkWebClientShareCredentials(w, r, &share); err != nil {
				return share, nil, dataprovider.ErrInvalidCredentials
			}
		} else {
			_, password, ok := r.BasicAuth()
			if !ok {
				w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
				renderError(dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return share, nil, dataprovider.ErrInvalidCredentials
			}
			match, err := share.CheckCredentials(password)
			if !match || err != nil {
				// Report the failed auth to the defender before replying 401.
				handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
				w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
				renderError(dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return share, nil, dataprovider.ErrInvalidCredentials
			}
		}
		common.DelayLogin(nil)
	}
	user, err := getUserForShare(share)
	if err != nil {
		renderError(err, "", getRespStatus(err))
		return share, nil, err
	}
	connID := xid.New().String()
	baseConn := common.NewBaseConnection(connID, common.ProtocolHTTPShare, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
	connection := newConnection(baseConn, w, r)

	return share, connection, nil
}

// getUserForShare loads the share owner and checks that the account may
// still serve this share (sharing permission, password policy, 2FA).
func getUserForShare(share dataprovider.Share) (dataprovider.User, error) {
	user, err := dataprovider.GetUserWithGroupSettings(share.Username, "")
	if err != nil {
		return user, err
	}
	if !user.CanManageShares() {
		return user, util.NewI18nError(util.NewRecordNotFoundError("this share does not exist"), util.I18nError404Message)
	}
	if share.Password == "" && slices.Contains(user.Filters.WebClient, sdk.WebClientShareNoPasswordDisabled) {
		return user, util.NewI18nError(
			fmt.Errorf("sharing without a password was disabled: %w", os.ErrPermission),
			util.I18nError403Message,
		)
	}
	if user.MustSetSecondFactorForProtocol(common.ProtocolHTTP) {
		return user, util.NewI18nError(
			util.NewMethodDisabledError("two-factor authentication requirements not met"),
			util.I18nError403Message,
		)
	}
	return user, nil
}

// validateBrowsableShare checks that a share can be browsed: exactly one
// shared path and that path resolves to a directory.
func validateBrowsableShare(share dataprovider.Share, connection *Connection) error {
	if len(share.Paths) != 1 {
		return util.NewI18nError(
			util.NewValidationError("a share with multiple paths is not browsable"),
			util.I18nErrorShareBrowsePaths,
		)
	}
	basePath := share.Paths[0]
	info, err := connection.Stat(basePath, 0)
	if err != nil {
		connection.CloseFS() //nolint:errcheck
		return util.NewI18nError(
			fmt.Errorf("unable to check the share directory: %w", err),
			util.I18nErrorShareInvalidPath,
		)
	}
	if !info.IsDir() {
		return util.NewI18nError(
			util.NewValidationError("the shared object is not a directory and so it is not browsable"),
			util.I18nErrorShareBrowseNoDir,
		)
	}
	return nil
}

// getBrowsableSharedPath joins the client-supplied "path" query parameter to
// the share base path and rejects results that escape the share.
func getBrowsableSharedPath(shareBasePath string, r *http.Request) (string, error) {
	name := util.CleanPath(path.Join(shareBasePath, r.URL.Query().Get("path")))
	if shareBasePath == "/" {
		return name, nil
	}
	if name != shareBasePath && !strings.HasPrefix(name, shareBasePath+"/") {
		return "", util.NewI18nError(
			util.NewValidationError(fmt.Sprintf("Invalid path %q", r.URL.Query().Get("path"))),
			util.I18nErrorPathInvalid,
		)
	}
	return name, nil
}
diff --git a/internal/httpd/api_user.go b/internal/httpd/api_user.go new file mode 100644 index 00000000..4af45f30 --- /dev/null +++ b/internal/httpd/api_user.go @@ -0,0 +1,330 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package httpd

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"strconv"
	"time"

	"github.com/go-chi/render"
	"github.com/sftpgo/sdk"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// getUsers returns the users list, paginated and restricted to the
// requesting admin's role.
func getUsers(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	limit, offset, order, err := getSearchFilters(w, r)
	if err != nil {
		// getSearchFilters has already written the error response.
		return
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	users, err := dataprovider.GetUsers(limit, offset, order, claims.Role)
	if err == nil {
		render.JSON(w, r, users)
	} else {
		sendAPIResponse(w, r, err, "", http.StatusInternalServerError)
	}
}

// getUserByUsername returns a single user by username.
func getUserByUsername(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	username := getURLParam(r, "username")
	renderUser(w, r, username, claims, http.StatusOK)
}

// renderUser writes the JSON representation of a user, hiding confidential
// data (secrets, hashes) unless the caller is allowed to see it.
func renderUser(w http.ResponseWriter, r *http.Request, username string, claims *jwt.Claims, status int) {
	user, err := dataprovider.UserExists(username, claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if hideConfidentialData(claims, r) {
		user.PrepareForRendering()
	}
	if status != http.StatusOK {
		ctx := context.WithValue(r.Context(), render.StatusCtxKey, status)
		render.JSON(w, r.WithContext(ctx), user)
	} else {
		render.JSON(w, r, user)
	}
}

// addUser creates a new user. The admin's default expiration preference is
// applied before decoding so the request body can override it, and
// server-controlled fields (role, password-change time, 2FA state) are reset.
func addUser(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	var user dataprovider.User
	if admin.Filters.Preferences.DefaultUsersExpiration > 0 {
		user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
	}
	err = render.DecodeJSON(r.Body, &user)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	// Role-scoped admins can only create users within their own role.
	if claims.Role != "" {
		user.Role = claims.Role
	}
	user.LastPasswordChange = 0
	// 2FA cannot be pre-provisioned via the API: reset any submitted config.
	user.Filters.RecoveryCodes = nil
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled: false,
	}
	err = dataprovider.AddUser(&user, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	w.Header().Add("Location", fmt.Sprintf("%s/%s", userPath, url.PathEscape(user.Username)))
	renderUser(w, r, user.Username, claims, http.StatusCreated)
}

// disableUser2FA removes a user's TOTP configuration and recovery codes.
func disableUser2FA(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	username := getURLParam(r, "username")
	user, err := dataprovider.UserExists(username, claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	if !user.Filters.TOTPConfig.Enabled {
		sendAPIResponse(w, r, nil, "two-factor authentication is not enabled", http.StatusBadRequest)
		return
	}
	user.Filters.RecoveryCodes = nil
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled: false,
	}
	if err := dataprovider.UpdateUser(&user, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role); err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, nil, "2FA disabled", http.StatusOK)
}

// updateUser modifies an existing user, preserving server-managed state
// (password when omitted, 2FA config, recovery codes, encrypted fs secrets)
// and optionally disconnecting the user's active sessions (?disconnect=1).
func updateUser(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}

	username := getURLParam(r, "username")
	disconnect := 0
	if _, ok := r.URL.Query()["disconnect"]; ok {
		disconnect, err = strconv.Atoi(r.URL.Query().Get("disconnect"))
		if err != nil {
			err = fmt.Errorf("invalid disconnect parameter: %v", err)
			sendAPIResponse(w, r, err, "", http.StatusBadRequest)
			return
		}
	}
	user, err := dataprovider.UserExists(username, claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	var updatedUser dataprovider.User
	// Keep the current password unless the payload sets a new one.
	updatedUser.Password = user.Password
	err = render.DecodeJSON(r.Body, &updatedUser)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	updatedUser.ID = user.ID
	updatedUser.Username = user.Username
	updatedUser.Filters.RecoveryCodes = user.Filters.RecoveryCodes
	updatedUser.Filters.TOTPConfig = user.Filters.TOTPConfig
	updatedUser.LastPasswordChange = user.LastPasswordChange
	updatedUser.SetEmptySecretsIfNil()
	// Keep existing encrypted secrets when the payload sends them redacted.
	updateEncryptedSecrets(&updatedUser.FsConfig, &user.FsConfig)
	if claims.Role != "" {
		updatedUser.Role = claims.Role
	}
	err = dataprovider.UpdateUser(&updatedUser, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "User updated", http.StatusOK)
	if disconnect == 1 {
		disconnectUser(user.Username, claims.Username, claims.Role)
	}
}

// deleteUser removes a user and disconnects any active session.
func deleteUser(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	username := getURLParam(r, "username")
	err = dataprovider.DeleteUser(username, claims.Username, util.GetIPFromRemoteAddress(r.RemoteAddr), claims.Role)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "User deleted", http.StatusOK)
	disconnectUser(dataprovider.ConvertName(username), claims.Username, claims.Role)
}

// forgotUserPassword starts the password-reset flow by emailing a
// confirmation code; it requires a working SMTP configuration.
func forgotUserPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	if !smtp.IsEnabled() {
		sendAPIResponse(w, r, nil, "No SMTP configuration", http.StatusBadRequest)
		return
	}

	err := handleForgotPassword(r, getURLParam(r, "username"), false)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}

	sendAPIResponse(w, r, err, "Check your email for the confirmation code", http.StatusOK)
}

// resetUserPassword completes the password-reset flow using the emailed code.
func resetUserPassword(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	var req pwdReset
	err := render.DecodeJSON(r.Body, &req)
	if err != nil {
		sendAPIResponse(w, r, err, "", http.StatusBadRequest)
		return
	}
	_, _, err = handleResetPassword(r, req.Code, req.Password, req.Password, false)
	if err != nil {
		sendAPIResponse(w, r, err, "", getRespStatus(err))
		return
	}
	sendAPIResponse(w, r, err, "Password reset successful", http.StatusOK)
}

// disconnectUser closes the user's active connections on this node and, in
// multi-node deployments, asks every other node to do the same.
func disconnectUser(username, admin, role string) {
	for _, stat := range common.Connections.GetStats("") {
		if stat.Username == username {
			common.Connections.Close(stat.ConnectionID, "")
		}
	}
	for _, stat := range getNodesConnections(admin, role) {
		if stat.Username == username {
			n, err := dataprovider.GetNodeByName(stat.Node)
			if err != nil {
				logger.Warn(logSender, "", "unable to disconnect user %q, error getting node %q: %v", username, stat.Node, err)
				continue
			}
			perms := []string{dataprovider.PermAdminCloseConnections}
			uri := fmt.Sprintf("%s/%s", activeConnectionsPath, stat.ConnectionID)
			if err := n.SendDeleteRequest(admin, role, uri, perms); err != nil {
				logger.Warn(logSender, "", "unable to disconnect user %q from node %q, error: %v", username, n.Name, err)
			}
		}
	}
}

// updateEncryptedSecrets restores the stored encrypted secrets for any
// filesystem secret that the update payload left redacted/encrypted.
func updateEncryptedSecrets(fsConfig *vfs.Filesystem, currentFsConfig *vfs.Filesystem) {
	// we use the new access secret if plain or empty, otherwise the old value
	switch fsConfig.Provider {
	case sdk.S3FilesystemProvider:
		if fsConfig.S3Config.AccessSecret.IsNotPlainAndNotEmpty() {
			fsConfig.S3Config.AccessSecret = currentFsConfig.S3Config.AccessSecret
		}
		if fsConfig.S3Config.SSECustomerKey.IsNotPlainAndNotEmpty() {
			fsConfig.S3Config.SSECustomerKey = currentFsConfig.S3Config.SSECustomerKey
		}
	case sdk.AzureBlobFilesystemProvider:
		if fsConfig.AzBlobConfig.AccountKey.IsNotPlainAndNotEmpty() {
			fsConfig.AzBlobConfig.AccountKey = currentFsConfig.AzBlobConfig.AccountKey
		}
		if fsConfig.AzBlobConfig.SASURL.IsNotPlainAndNotEmpty() {
			fsConfig.AzBlobConfig.SASURL = currentFsConfig.AzBlobConfig.SASURL
		}
	case sdk.GCSFilesystemProvider:
		// for GCS credentials will be cleared if we enable automatic credentials
		// so keep the old credentials here if no new credentials are provided
		if !fsConfig.GCSConfig.Credentials.IsPlain() {
			fsConfig.GCSConfig.Credentials = currentFsConfig.GCSConfig.Credentials
		}
	case sdk.CryptedFilesystemProvider:
		if fsConfig.CryptConfig.Passphrase.IsNotPlainAndNotEmpty() {
			fsConfig.CryptConfig.Passphrase = currentFsConfig.CryptConfig.Passphrase
		}
	case sdk.SFTPFilesystemProvider:
		updateSFTPFsEncryptedSecrets(fsConfig, currentFsConfig)
	case sdk.HTTPFilesystemProvider:
		updateHTTPFsEncryptedSecrets(fsConfig, currentFsConfig)
	}
}

// updateSFTPFsEncryptedSecrets preserves SFTPfs secrets not re-supplied in
// plain text by the update payload.
func updateSFTPFsEncryptedSecrets(fsConfig *vfs.Filesystem, currentFsConfig *vfs.Filesystem) {
	if fsConfig.SFTPConfig.Password.IsNotPlainAndNotEmpty() {
		fsConfig.SFTPConfig.Password = currentFsConfig.SFTPConfig.Password
	}
	if fsConfig.SFTPConfig.PrivateKey.IsNotPlainAndNotEmpty() {
		fsConfig.SFTPConfig.PrivateKey = currentFsConfig.SFTPConfig.PrivateKey
	}
	if fsConfig.SFTPConfig.KeyPassphrase.IsNotPlainAndNotEmpty() {
		fsConfig.SFTPConfig.KeyPassphrase = currentFsConfig.SFTPConfig.KeyPassphrase
	}
}

// updateHTTPFsEncryptedSecrets preserves HTTPfs secrets not re-supplied in
// plain text by the update payload.
func updateHTTPFsEncryptedSecrets(fsConfig *vfs.Filesystem, currentFsConfig *vfs.Filesystem) {
	if fsConfig.HTTPConfig.Password.IsNotPlainAndNotEmpty() {
		fsConfig.HTTPConfig.Password = currentFsConfig.HTTPConfig.Password
	}
	if fsConfig.HTTPConfig.APIKey.IsNotPlainAndNotEmpty() {
		fsConfig.HTTPConfig.APIKey = currentFsConfig.HTTPConfig.APIKey
	}
}
diff --git a/internal/httpd/api_utils.go b/internal/httpd/api_utils.go new file mode 100644 index 00000000..cef57b53 --- /dev/null +++ b/internal/httpd/api_utils.go @@ -0,0 +1,961 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package httpd

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"mime"
	"net/http"
	"net/url"
	"os"
	"path"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/go-chi/render"
	"github.com/klauspost/compress/zip"
	"github.com/rs/xid"
	"github.com/sftpgo/sdk/plugin/notifier"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/metric"
	"github.com/drakkan/sftpgo/v2/internal/plugin"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// pwdChange is the request payload for changing the current password.
type pwdChange struct {
	CurrentPassword string `json:"current_password"`
	NewPassword     string `json:"new_password"`
}

// pwdReset is the request payload for completing a password reset.
type pwdReset struct {
	Code     string `json:"code"`
	Password string `json:"password"`
}

// baseProfile holds profile fields shared by admins and users.
type baseProfile struct {
	Email           string `json:"email,omitempty"`
	Description     string `json:"description,omitempty"`
	AllowAPIKeyAuth bool   `json:"allow_api_key_auth"`
}

// adminProfile is the self-service profile payload for admins.
type adminProfile struct {
	baseProfile
}

// userProfile is the self-service profile payload for users.
type userProfile struct {
	baseProfile
	AdditionalEmails []string `json:"additional_emails,omitempty"`
	PublicKeys       []string `json:"public_keys,omitempty"`
	TLSCerts         []string `json:"tls_certs,omitempty"`
}

// sendAPIResponse writes the standard JSON API response with the given HTTP
// status code. util.ErrNotFound is normalized to the generic 404 text.
func sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) {
	var errorString string
	if errors.Is(err, util.ErrNotFound) {
		errorString = http.StatusText(http.StatusNotFound)
	} else if err != nil {
		errorString = err.Error()
	}
	resp := apiResponse{
		Error:   errorString,
		Message: message,
	}
	ctx := context.WithValue(r.Context(), render.StatusCtxKey, code)
	render.JSON(w, r.WithContext(ctx), resp)
}

// getRespStatus maps provider/validation errors to HTTP status codes;
// unrecognized errors become 500.
func getRespStatus(err error) int {
	if errors.Is(err, util.ErrValidation) {
		return http.StatusBadRequest
	}
	if errors.Is(err, util.ErrMethodDisabled) {
		return http.StatusForbidden
	}
	if errors.Is(err, util.ErrNotFound) {
		return http.StatusNotFound
	}
	if errors.Is(err, fs.ErrNotExist) {
		return http.StatusBadRequest
	}
	if errors.Is(err, fs.ErrPermission) || errors.Is(err, dataprovider.ErrLoginNotAllowedFromIP) {
		return http.StatusForbidden
	}
	if errors.Is(err, plugin.ErrNoSearcher) || errors.Is(err, dataprovider.ErrNotImplemented) {
		return http.StatusNotImplemented
	}
	if errors.Is(err, dataprovider.ErrDuplicatedKey) || errors.Is(err, dataprovider.ErrForeignKeyViolated) {
		return http.StatusConflict
	}
	return http.StatusInternalServerError
}

// mapping between fs errors for HTTP protocol and HTTP response status codes
func getMappedStatusCode(err error) int {
	var statusCode int
	switch {
	case errors.Is(err, fs.ErrPermission):
		statusCode = http.StatusForbidden
	case errors.Is(err, common.ErrReadQuotaExceeded):
		statusCode = http.StatusForbidden
	case errors.Is(err, fs.ErrNotExist):
		statusCode = http.StatusNotFound
	case errors.Is(err, common.ErrQuotaExceeded):
		statusCode = http.StatusRequestEntityTooLarge
	case errors.Is(err, common.ErrOpUnsupported):
		statusCode = http.StatusBadRequest
	default:
		// A request body larger than MaxBytesReader's limit surfaces as
		// *http.MaxBytesError.
		if _, ok := err.(*http.MaxBytesError); ok {
			statusCode = http.StatusRequestEntityTooLarge
		} else {
			statusCode = http.StatusInternalServerError
		}
	}
	return statusCode
}

// getURLParam returns the named chi route parameter, URL-unescaped when
// possible (the raw value is returned if unescaping fails).
func getURLParam(r *http.Request, key string) string {
	v := chi.URLParam(r, key)
	unescaped, err := url.PathUnescape(v)
	if err != nil {
		return v
	}
	return unescaped
}

// getURLPath returns the routing path for the request, preferring chi's
// route context over the raw URL path.
func getURLPath(r *http.Request) string {
	rctx := chi.RouteContext(r.Context())
	if rctx != nil && rctx.RoutePath != "" {
		return rctx.RoutePath
	}
	return r.URL.Path
}

// getCommaSeparatedQueryParam parses the named query parameter as a
// comma-separated list, trimming blanks and dropping empty/duplicate values.
func getCommaSeparatedQueryParam(r *http.Request, key string) []string {
	var result []string

	for val := range strings.SplitSeq(r.URL.Query().Get(key), ",") {
		val = strings.TrimSpace(val)
		if val != "" {
			result = append(result, val)
		}
	}

	return util.RemoveDuplicates(result, false)
}

// getBoolQueryParam reports whether the named query parameter equals "true".
func getBoolQueryParam(r *http.Request, param string) bool {
	return r.URL.Query().Get(param) == "true"
}

// getActiveConnections returns the active connections visible to the admin,
// including those from other cluster nodes when not a node-scoped token.
func getActiveConnections(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	stats := common.Connections.GetStats(claims.Role)
	if claims.NodeID == "" {
		stats = append(stats, getNodesConnections(claims.Username, claims.Role)...)
	}
	render.JSON(w, r, stats)
}

// handleCloseConnection closes an active connection, locally or by
// forwarding the request to the owning cluster node.
func handleCloseConnection(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest)
		return
	}
	connectionID := getURLParam(r, "connectionID")
	if connectionID == "" {
		sendAPIResponse(w, r, nil, "connectionID is mandatory", http.StatusBadRequest)
		return
	}
	node := r.URL.Query().Get("node")
	if node == "" || node == dataprovider.GetNodeName() {
		if common.Connections.Close(connectionID, claims.Role) {
			sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK)
		} else {
			sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound)
		}
		return
	}
	n, err := dataprovider.GetNodeByName(node)
	if err != nil {
		logger.Warn(logSender, "", "unable to get node with name %q: %v", node, err)
		status := getRespStatus(err)
		sendAPIResponse(w, r, nil, http.StatusText(status), status)
		return
	}
	perms := []string{dataprovider.PermAdminCloseConnections}
	uri := fmt.Sprintf("%s/%s", activeConnectionsPath, connectionID)
	if err := 
n.SendDeleteRequest(claims.Username, claims.Role, uri, perms); err != nil { + logger.Warn(logSender, "", "unable to delete connection id %q from node %q: %v", connectionID, n.Name, err) + sendAPIResponse(w, r, nil, "Not Found", http.StatusNotFound) + return + } + sendAPIResponse(w, r, nil, "Connection closed", http.StatusOK) +} + +// getNodesConnections returns the active connections from other nodes. +// Errors are silently ignored +func getNodesConnections(admin, role string) []common.ConnectionStatus { + nodes, err := dataprovider.GetNodes() + if err != nil || len(nodes) == 0 { + return nil + } + var results []common.ConnectionStatus + var mu sync.Mutex + var wg sync.WaitGroup + + for _, n := range nodes { + wg.Add(1) + + go func(node dataprovider.Node) { + defer wg.Done() + + var stats []common.ConnectionStatus + perms := []string{dataprovider.PermAdminViewConnections} + if err := node.SendGetRequest(admin, role, activeConnectionsPath, perms, &stats); err != nil { + logger.Warn(logSender, "", "unable to get connections from node %s: %v", node.Name, err) + return + } + + mu.Lock() + results = append(results, stats...) 
+ mu.Unlock() + }(n) + } + wg.Wait() + + return results +} + +func getSearchFilters(w http.ResponseWriter, r *http.Request) (int, int, string, error) { + var err error + limit := 100 + offset := 0 + order := dataprovider.OrderASC + if _, ok := r.URL.Query()["limit"]; ok { + limit, err = strconv.Atoi(r.URL.Query().Get("limit")) + if err != nil { + err = errors.New("invalid limit") + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return limit, offset, order, err + } + if limit > 500 { + limit = 500 + } + } + if _, ok := r.URL.Query()["offset"]; ok { + offset, err = strconv.Atoi(r.URL.Query().Get("offset")) + if err != nil { + err = errors.New("invalid offset") + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return limit, offset, order, err + } + } + if _, ok := r.URL.Query()["order"]; ok { + order = r.URL.Query().Get("order") + if order != dataprovider.OrderASC && order != dataprovider.OrderDESC { + err = errors.New("invalid order") + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return limit, offset, order, err + } + } + + return limit, offset, order, err +} + +func renderAPIDirContents(w http.ResponseWriter, lister vfs.DirLister, omitNonRegularFiles bool) { + defer lister.Close() + + dataGetter := func(limit, _ int) ([]byte, int, error) { + contents, err := lister.Next(limit) + if errors.Is(err, io.EOF) { + err = nil + } + if err != nil { + return nil, 0, err + } + results := make([]map[string]any, 0, len(contents)) + for _, info := range contents { + if omitNonRegularFiles && !info.Mode().IsDir() && !info.Mode().IsRegular() { + continue + } + res := make(map[string]any) + res["name"] = info.Name() + if info.Mode().IsRegular() { + res["size"] = info.Size() + } + res["mode"] = info.Mode() + res["last_modified"] = info.ModTime().UTC().Format(time.RFC3339) + results = append(results, res) + } + data, err := json.Marshal(results) + count := limit + if len(results) == 0 { + count = 0 + } + return data, count, err + } + + streamJSONArray(w, 
defaultQueryLimit, dataGetter) +} + +func streamData(w io.Writer, data []byte) { + b := bytes.NewBuffer(data) + _, err := io.CopyN(w, b, int64(len(data))) + if err != nil { + panic(http.ErrAbortHandler) + } +} + +func streamJSONArray(w http.ResponseWriter, chunkSize int, dataGetter func(limit, offset int) ([]byte, int, error)) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Accept-Ranges", "none") + w.WriteHeader(http.StatusOK) + + streamData(w, []byte("[")) + offset := 0 + for { + data, count, err := dataGetter(chunkSize, offset) + if err != nil { + panic(http.ErrAbortHandler) + } + if count == 0 { + break + } + if offset > 0 { + streamData(w, []byte(",")) + } + streamData(w, data[1:len(data)-1]) + if count < chunkSize { + break + } + offset += count + } + streamData(w, []byte("]")) +} + +func renderPNGImage(w http.ResponseWriter, r *http.Request, b []byte) { + if len(b) == 0 { + ctx := context.WithValue(r.Context(), render.StatusCtxKey, http.StatusNotFound) + render.PlainText(w, r.WithContext(ctx), http.StatusText(http.StatusNotFound)) + return + } + w.Header().Set("Content-Type", "image/png") + streamData(w, b) +} + +func getCompressedFileName(username string, files []string) string { + if len(files) == 1 { + name := path.Base(files[0]) + return fmt.Sprintf("%s-%s.zip", username, strings.TrimSuffix(name, path.Ext(name))) + } + return fmt.Sprintf("%s-download.zip", username) +} + +func renderCompressedFiles(w http.ResponseWriter, conn *Connection, baseDir string, files []string, + share *dataprovider.Share, +) { + conn.User.CheckFsRoot(conn.ID) //nolint:errcheck + w.Header().Set("Content-Type", "application/zip") + w.Header().Set("Accept-Ranges", "none") + w.Header().Set("Content-Transfer-Encoding", "binary") + w.WriteHeader(http.StatusOK) + + wr := zip.NewWriter(w) + + for _, file := range files { + fullPath := util.CleanPath(path.Join(baseDir, file)) + if err := addZipEntry(wr, conn, fullPath, baseDir, nil, 0); err != nil { + if share 
!= nil { + dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck + } + panic(http.ErrAbortHandler) + } + } + if err := wr.Close(); err != nil { + conn.Log(logger.LevelError, "unable to close zip file: %v", err) + if share != nil { + dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck + } + panic(http.ErrAbortHandler) + } +} + +func addZipEntry(wr *zip.Writer, conn *Connection, entryPath, baseDir string, info os.FileInfo, recursion int) error { + if recursion >= util.MaxRecursion { + conn.Log(logger.LevelDebug, "unable to add zip entry %q, recursion too depth: %d", entryPath, recursion) + return util.ErrRecursionTooDeep + } + recursion++ + var err error + if info == nil { + info, err = conn.Stat(entryPath, 1) + if err != nil { + conn.Log(logger.LevelDebug, "unable to add zip entry %q, stat error: %v", entryPath, err) + return err + } + } + entryName, err := getZipEntryName(entryPath, baseDir) + if err != nil { + conn.Log(logger.LevelError, "unable to get zip entry name: %v", err) + return err + } + if info.IsDir() { + _, err = wr.CreateHeader(&zip.FileHeader{ + Name: entryName + "/", + Method: zip.Deflate, + Modified: info.ModTime(), + }) + if err != nil { + conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) + return err + } + lister, err := conn.ReadDir(entryPath) + if err != nil { + conn.Log(logger.LevelDebug, "unable to add zip entry %q, get list dir error: %v", entryPath, err) + return err + } + defer lister.Close() + + for { + contents, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + return err + } + for _, info := range contents { + fullPath := util.CleanPath(path.Join(entryPath, info.Name())) + if err := addZipEntry(wr, conn, fullPath, baseDir, info, recursion); err != nil { + return err + } + } + if finished { + return nil + } + } + } + if !info.Mode().IsRegular() { + // we only allow regular files + conn.Log(logger.LevelInfo, "skipping zip entry for 
non regular file %q", entryPath) + return nil + } + return addFileToZipEntry(wr, conn, entryPath, entryName, info) +} + +func addFileToZipEntry(wr *zip.Writer, conn *Connection, entryPath, entryName string, info os.FileInfo) error { + reader, err := conn.getFileReader(entryPath, 0, http.MethodGet) + if err != nil { + conn.Log(logger.LevelDebug, "unable to add zip entry %q, cannot open file: %v", entryPath, err) + return err + } + defer reader.Close() + + f, err := wr.CreateHeader(&zip.FileHeader{ + Name: entryName, + Method: zip.Deflate, + Modified: info.ModTime(), + }) + if err != nil { + conn.Log(logger.LevelError, "unable to create zip entry %q: %v", entryPath, err) + return err + } + _, err = io.Copy(f, reader) + return err +} + +func getZipEntryName(entryPath, baseDir string) (string, error) { + if !strings.HasPrefix(entryPath, baseDir) { + return "", fmt.Errorf("entry path %q is outside base dir %q", entryPath, baseDir) + } + entryPath = strings.TrimPrefix(entryPath, baseDir) + return strings.TrimPrefix(entryPath, "/"), nil +} + +func checkDownloadFileFromShare(share *dataprovider.Share, info os.FileInfo) error { + if share != nil && !info.Mode().IsRegular() { + return util.NewValidationError("non regular files are not supported for shares") + } + return nil +} + +func downloadFile(w http.ResponseWriter, r *http.Request, connection *Connection, name string, + info os.FileInfo, inline bool, share *dataprovider.Share, +) (int, error) { + connection.User.CheckFsRoot(connection.ID) //nolint:errcheck + err := checkDownloadFileFromShare(share, info) + if err != nil { + return http.StatusBadRequest, err + } + rangeHeader := r.Header.Get("Range") + if rangeHeader != "" && checkIfRange(r, info.ModTime()) == condFalse { + rangeHeader = "" + } + offset := int64(0) + size := info.Size() + responseStatus := http.StatusOK + if strings.HasPrefix(rangeHeader, "bytes=") { + if strings.Contains(rangeHeader, ",") { + return http.StatusRequestedRangeNotSatisfiable, 
fmt.Errorf("unsupported range %q", rangeHeader) + } + offset, size, err = parseRangeRequest(rangeHeader[6:], size) + if err != nil { + return http.StatusRequestedRangeNotSatisfiable, err + } + responseStatus = http.StatusPartialContent + } + reader, err := connection.getFileReader(name, offset, r.Method) + if err != nil { + return getMappedStatusCode(err), fmt.Errorf("unable to read file %q: %v", name, err) + } + defer reader.Close() + + w.Header().Set("Last-Modified", info.ModTime().UTC().Format(http.TimeFormat)) + if checkPreconditions(w, r, info.ModTime()) { + return 0, fmt.Errorf("%v", http.StatusText(http.StatusPreconditionFailed)) + } + ctype := mime.TypeByExtension(path.Ext(name)) + if ctype == "" { + ctype = "application/octet-stream" + } + if responseStatus == http.StatusPartialContent { + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, info.Size())) + } + w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + w.Header().Set("Content-Type", ctype) + if !inline { + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", path.Base(name))) + } + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(responseStatus) + if r.Method != http.MethodHead { + _, err = io.CopyN(w, reader, size) + if err != nil { + if share != nil { + dataprovider.UpdateShareLastUse(share, -1) //nolint:errcheck + } + connection.Log(logger.LevelDebug, "error reading file to download: %v", err) + panic(http.ErrAbortHandler) + } + } + return http.StatusOK, nil +} + +func checkPreconditions(w http.ResponseWriter, r *http.Request, modtime time.Time) bool { + if checkIfUnmodifiedSince(r, modtime) == condFalse { + w.WriteHeader(http.StatusPreconditionFailed) + return true + } + if checkIfModifiedSince(r, modtime) == condFalse { + w.WriteHeader(http.StatusNotModified) + return true + } + return false +} + +func checkIfUnmodifiedSince(r *http.Request, modtime time.Time) condResult { + ius := r.Header.Get("If-Unmodified-Since") 
+ if ius == "" || isZeroTime(modtime) { + return condNone + } + t, err := http.ParseTime(ius) + if err != nil { + return condNone + } + + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if modtime.Before(t) || modtime.Equal(t) { + return condTrue + } + return condFalse +} + +func checkIfModifiedSince(r *http.Request, modtime time.Time) condResult { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + return condNone + } + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(modtime) { + return condNone + } + t, err := http.ParseTime(ims) + if err != nil { + return condNone + } + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if modtime.Before(t) || modtime.Equal(t) { + return condFalse + } + return condTrue +} + +func checkIfRange(r *http.Request, modtime time.Time) condResult { + if r.Method != http.MethodGet && r.Method != http.MethodHead { + return condNone + } + ir := r.Header.Get("If-Range") + if ir == "" { + return condNone + } + if modtime.IsZero() { + return condFalse + } + t, err := http.ParseTime(ir) + if err != nil { + return condFalse + } + if modtime.Unix() == t.Unix() { + return condTrue + } + return condFalse +} + +func parseRangeRequest(bytesRange string, size int64) (int64, int64, error) { + var start, end int64 + var err error + + values := strings.Split(bytesRange, "-") + if values[0] == "" { + start = -1 + } else { + start, err = strconv.ParseInt(values[0], 10, 64) + if err != nil { + return start, size, err + } + } + if len(values) >= 2 { + if values[1] != "" { + end, err = strconv.ParseInt(values[1], 10, 64) + if err != nil { + return start, size, err + } + if end >= size { + end = size - 1 + } + } + } + if start == -1 && end == 0 { + return 0, 0, fmt.Errorf("unsupported range %q", bytesRange) + } + + if end > 0 
{ + if start == -1 { + // we have something like -500 + start = size - end + size = end + // start cannot be < 0 here, we did end = size -1 above + } else { + // we have something like 500-600 + size = end - start + 1 + if size < 0 { + return 0, 0, fmt.Errorf("unacceptable range %q", bytesRange) + } + } + return start, size, nil + } + // we have something like 500- + size -= start + if size < 0 { + return 0, 0, fmt.Errorf("unacceptable range %q", bytesRange) + } + return start, size, err +} + +func handleDefenderEventLoginFailed(ipAddr string, err error) error { + event := common.HostEventLoginFailed + if errors.Is(err, util.ErrNotFound) { + event = common.HostEventUserNotFound + err = dataprovider.ErrInvalidCredentials + } + common.AddDefenderEvent(ipAddr, common.ProtocolHTTP, event) + common.DelayLogin(err) + return err +} + +func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err error, r *http.Request) { + metric.AddLoginAttempt(loginMethod) + var protocol string + switch loginMethod { + case dataprovider.LoginMethodIDP: + protocol = common.ProtocolOIDC + default: + protocol = common.ProtocolHTTP + } + if err == nil { + logger.LoginLog(user.Username, ip, loginMethod, protocol, "", r.UserAgent(), r.TLS != nil, "") + plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, protocol, user.Username, ip, "", nil) + common.DelayLogin(nil) + } else if err != common.ErrInternalFailure && err != common.ErrNoCredentials { + logger.ConnectionFailedLog(user.Username, ip, loginMethod, protocol, err.Error()) + err = handleDefenderEventLoginFailed(ip, err) + logEv := notifier.LogEventTypeLoginFailed + if errors.Is(err, util.ErrNotFound) { + logEv = notifier.LogEventTypeLoginNoUser + } + plugin.Handler.NotifyLogEvent(logEv, protocol, user.Username, ip, "", err) + } + metric.AddLoginResult(loginMethod, err) + dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err) +} + +func checkHTTPClientUser(user *dataprovider.User, r *http.Request, 
connectionID string, checkSessions, isOIDCLogin bool) error { + if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) { + logger.Info(logSender, connectionID, "cannot login user %q, protocol HTTP is not allowed", user.Username) + return util.NewI18nError( + fmt.Errorf("protocol HTTP is not allowed for user %q", user.Username), + util.I18nErrorProtocolForbidden, + ) + } + if !isLoggedInWithOIDC(r) && !isOIDCLogin && !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) { + logger.Info(logSender, connectionID, "cannot login user %q, password login method is not allowed", user.Username) + return util.NewI18nError( + fmt.Errorf("login method password is not allowed for user %q", user.Username), + util.I18nErrorPwdLoginForbidden, + ) + } + if checkSessions && user.MaxSessions > 0 { + activeSessions := common.Connections.GetActiveSessions(user.Username) + if activeSessions >= user.MaxSessions { + logger.Info(logSender, connectionID, "authentication refused for user: %q, too many open sessions: %v/%v", user.Username, + activeSessions, user.MaxSessions) + return util.NewI18nError(fmt.Errorf("too many open sessions: %v", activeSessions), util.I18nError429Message) + } + } + if !user.IsLoginFromAddrAllowed(r.RemoteAddr) { + logger.Info(logSender, connectionID, "cannot login user %q, remote address is not allowed: %v", user.Username, r.RemoteAddr) + return util.NewI18nError( + fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, r.RemoteAddr), + util.I18nErrorIPForbidden, + ) + } + return nil +} + +func getActiveAdmin(username, ipAddr string) (dataprovider.Admin, error) { + admin, err := dataprovider.AdminExists(username) + if err != nil { + return admin, err + } + if err := admin.CanLogin(ipAddr); err != nil { + return admin, util.NewRecordNotFoundError(fmt.Sprintf("admin %q cannot login: %v", username, err)) + } + return admin, nil +} + +func getActiveUser(username string, r *http.Request) 
(dataprovider.User, error) { + user, err := dataprovider.GetUserWithGroupSettings(username, "") + if err != nil { + return user, err + } + if err := user.CheckLoginConditions(); err != nil { + return user, util.NewRecordNotFoundError(fmt.Sprintf("user %q cannot login: %v", username, err)) + } + if err := checkHTTPClientUser(&user, r, xid.New().String(), false, false); err != nil { + return user, util.NewRecordNotFoundError(fmt.Sprintf("user %q cannot login: %v", username, err)) + } + return user, nil +} + +func handleForgotPassword(r *http.Request, username string, isAdmin bool) error { + var emails []string + var subject string + var err error + var admin dataprovider.Admin + var user dataprovider.User + + if username == "" { + return util.NewI18nError(util.NewValidationError("username is mandatory"), util.I18nErrorUsernameRequired) + } + if isAdmin { + admin, err = getActiveAdmin(username, util.GetIPFromRemoteAddress(r.RemoteAddr)) + if admin.Email != "" { + emails = []string{admin.Email} + } + subject = fmt.Sprintf("Email Verification Code for admin %q", username) + } else { + user, err = getActiveUser(username, r) + emails = user.GetEmailAddresses() + subject = fmt.Sprintf("Email Verification Code for user %q", username) + if err == nil { + if !isUserAllowedToResetPassword(r, &user) { + return util.NewI18nError( + util.NewValidationError("you are not allowed to reset your password"), + util.I18nErrorPwdResetForbidded, + ) + } + } + } + if err != nil { + if errors.Is(err, util.ErrNotFound) { + handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), err) //nolint:errcheck + logger.Debug(logSender, middleware.GetReqID(r.Context()), + "username %q does not exists or cannot login, reset password request silently ignored, is admin? 
%t, err: %v", + username, isAdmin, err) + return nil + } + return util.NewI18nError(util.NewGenericError("Error retrieving your account, please try again later"), util.I18nErrorGetUser) + } + if len(emails) == 0 { + return util.NewI18nError( + util.NewValidationError("Your account does not have an email address, it is not possible to reset your password by sending an email verification code"), + util.I18nErrorPwdResetNoEmail, + ) + } + c := newResetCode(username, isAdmin) + body := new(bytes.Buffer) + data := make(map[string]string) + data["Code"] = c.Code + if err := smtp.RenderPasswordResetTemplate(body, data); err != nil { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "unable to render password reset template: %v", err) + return util.NewGenericError("Unable to render password reset template") + } + startTime := time.Now() + if err := smtp.SendEmail(emails, nil, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil { + logger.Warn(logSender, middleware.GetReqID(r.Context()), "unable to send password reset code via email: %v, elapsed: %v", + err, time.Since(startTime)) + return util.NewI18nError( + util.NewGenericError(fmt.Sprintf("Error sending confirmation code via email: %v", err)), + util.I18nErrorPwdResetSendEmail, + ) + } + logger.Debug(logSender, middleware.GetReqID(r.Context()), "reset code sent via email to %q, emails: %+v, is admin? 
%v, elapsed: %v", + username, emails, isAdmin, time.Since(startTime)) + return resetCodesMgr.Add(c) +} + +func handleResetPassword(r *http.Request, code, newPassword, confirmPassword string, isAdmin bool) ( + *dataprovider.Admin, *dataprovider.User, error, +) { + var admin dataprovider.Admin + var user dataprovider.User + var err error + + if newPassword == "" { + return &admin, &user, util.NewValidationError("please set a password") + } + if code == "" { + return &admin, &user, util.NewValidationError("please set a confirmation code") + } + if newPassword != confirmPassword { + return &admin, &user, util.NewI18nError(errors.New("the two password fields do not match"), util.I18nErrorChangePwdNoMatch) + } + + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + resetCode, err := resetCodesMgr.Get(code) + if err != nil { + handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck + return &admin, &user, util.NewValidationError("confirmation code not found") + } + if resetCode.IsAdmin != isAdmin { + return &admin, &user, util.NewValidationError("invalid confirmation code") + } + if isAdmin { + admin, err = getActiveAdmin(resetCode.Username, ipAddr) + if err != nil { + return &admin, &user, util.NewValidationError("unable to associate the confirmation code with an existing admin") + } + admin.Password = newPassword + admin.Filters.RequirePasswordChange = false + err = dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, ipAddr, admin.Role) + if err != nil { + return &admin, &user, util.NewGenericError(fmt.Sprintf("unable to set the new password: %v", err)) + } + err = resetCodesMgr.Delete(code) + return &admin, &user, err + } + user, err = getActiveUser(resetCode.Username, r) + if err != nil { + return &admin, &user, util.NewValidationError("Unable to associate the confirmation code with an existing user") + } + if !isUserAllowedToResetPassword(r, &user) { + return &admin, &user, util.NewI18nError( + 
util.NewValidationError("you are not allowed to reset your password"), + util.I18nErrorPwdResetForbidded, + ) + } + err = dataprovider.UpdateUserPassword(user.Username, newPassword, dataprovider.ActionExecutorSelf, + util.GetIPFromRemoteAddress(r.RemoteAddr), user.Role) + if err == nil { + err = resetCodesMgr.Delete(code) + } + user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now()) + user.Filters.RequirePasswordChange = false + return &admin, &user, err +} + +func isUserAllowedToResetPassword(r *http.Request, user *dataprovider.User) bool { + if !user.CanResetPassword() { + return false + } + if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) { + return false + } + if !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) { + return false + } + if !user.IsLoginFromAddrAllowed(r.RemoteAddr) { + return false + } + return true +} + +func getProtocolFromRequest(r *http.Request) string { + if isLoggedInWithOIDC(r) { + return common.ProtocolOIDC + } + return common.ProtocolHTTP +} + +func hideConfidentialData(claims *jwt.Claims, r *http.Request) bool { + if !claims.HasPerm(dataprovider.PermAdminAny) { + return true + } + return r.URL.Query().Get("confidential_data") != "1" +} + +func responseControllerDeadlines(rc *http.ResponseController, read, write time.Time) { + if err := rc.SetReadDeadline(read); err != nil { + logger.Error(logSender, "", "unable to set read timeout to %s: %v", read, err) + } + if err := rc.SetWriteDeadline(write); err != nil { + logger.Error(logSender, "", "unable to set write timeout to %s: %v", write, err) + } +} diff --git a/internal/httpd/auth_utils.go b/internal/httpd/auth_utils.go new file mode 100644 index 00000000..ecd9ce86 --- /dev/null +++ b/internal/httpd/auth_utils.go @@ -0,0 +1,436 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by 
the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "crypto/rand" + "errors" + "fmt" + "net/http" + "time" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +type tokenAudience = string + +const ( + tokenAudienceWebAdmin tokenAudience = "WebAdmin" + tokenAudienceWebClient tokenAudience = "WebClient" + tokenAudienceWebShare tokenAudience = "WebShare" + tokenAudienceWebAdminPartial tokenAudience = "WebAdminPartial" + tokenAudienceWebClientPartial tokenAudience = "WebClientPartial" + tokenAudienceAPI tokenAudience = "API" + tokenAudienceAPIUser tokenAudience = "APIUser" + tokenAudienceCSRF tokenAudience = "CSRF" + tokenAudienceOAuth2 tokenAudience = "OAuth2" + tokenAudienceWebLogin tokenAudience = "WebLogin" +) + +const ( + tokenValidationModeDefault = 0 + tokenValidationModeNoIPMatch = 1 + tokenValidationModeUserSignature = 2 +) + +const ( + basicRealm = "Basic realm=\"SFTPGo\"" +) + +var ( + apiTokenDuration = 20 * time.Minute + cookieTokenDuration = 20 * time.Minute + shareTokenDuration = 2 * time.Hour + // csrf token duration is greater than normal token duration to reduce issues + // with the login form + csrfTokenDuration = 4 * time.Hour + cookieRefreshThreshold = 10 * time.Minute + maxTokenDuration = 12 * time.Hour + tokenValidationMode = tokenValidationModeDefault +) + +func isTokenDurationValid(minutes int) bool { + return minutes >= 1 && minutes <= 720 +} + +func updateTokensDuration(api, cookie, share int) { + 
if isTokenDurationValid(api) { + apiTokenDuration = time.Duration(api) * time.Minute + } + if isTokenDurationValid(cookie) { + cookieTokenDuration = time.Duration(cookie) * time.Minute + cookieRefreshThreshold = cookieTokenDuration / 2 + if cookieTokenDuration > csrfTokenDuration { + csrfTokenDuration = cookieTokenDuration + } + } + if isTokenDurationValid(share) { + shareTokenDuration = time.Duration(share) * time.Minute + } + logger.Debug(logSender, "", "API token duration %s, cookie token duration %s, cookie refresh threshold %s, share token duration %s, csrf token duration %s", + apiTokenDuration, cookieTokenDuration, cookieRefreshThreshold, shareTokenDuration, csrfTokenDuration) +} + +func getTokenDuration(audience tokenAudience) time.Duration { + switch audience { + case tokenAudienceWebShare: + return shareTokenDuration + case tokenAudienceWebLogin, tokenAudienceCSRF: + return csrfTokenDuration + case tokenAudienceAPI, tokenAudienceAPIUser: + return apiTokenDuration + case tokenAudienceWebAdmin, tokenAudienceWebClient: + return cookieTokenDuration + case tokenAudienceWebAdminPartial, tokenAudienceWebClientPartial, tokenAudienceOAuth2: + return 5 * time.Minute + default: + logger.Error(logSender, "", "token duration not handled for audience: %q", audience) + return 20 * time.Minute + } +} + +func getMaxCookieDuration() time.Duration { + result := csrfTokenDuration + if shareTokenDuration > result { + result = shareTokenDuration + } + if cookieTokenDuration > result { + result = cookieTokenDuration + } + return result +} + +func hasUserAudience(claims *jwt.Claims) bool { + return claims.HasAnyAudience([]string{tokenAudienceWebClient, tokenAudienceAPIUser}) +} + +func createAndSetCookie(w http.ResponseWriter, r *http.Request, claims *jwt.Claims, tokenAuth *jwt.Signer, + audience tokenAudience, ip string, +) error { + duration := getTokenDuration(audience) + token, err := tokenAuth.SignWithParams(claims, audience, ip, duration) + if err != nil { + return err + } 
+ resp := claims.BuildTokenResponse(token) + var basePath string + if audience == tokenAudienceWebAdmin || audience == tokenAudienceWebAdminPartial { + basePath = webBaseAdminPath + } else { + basePath = webBaseClientPath + } + setCookie(w, r, basePath, resp.Token, duration) + + return nil +} + +func setCookie(w http.ResponseWriter, r *http.Request, cookiePath, cookieValue string, duration time.Duration) { + http.SetCookie(w, &http.Cookie{ + Name: jwt.CookieKey, + Value: cookieValue, + Path: cookiePath, + Expires: time.Now().Add(duration), + MaxAge: int(duration / time.Second), + HttpOnly: true, + Secure: isTLS(r), + SameSite: http.SameSiteStrictMode, + }) +} + +func removeCookie(w http.ResponseWriter, r *http.Request, cookiePath string) { + invalidateToken(r) + http.SetCookie(w, &http.Cookie{ + Name: jwt.CookieKey, + Value: "", + Path: cookiePath, + Expires: time.Unix(0, 0), + MaxAge: -1, + HttpOnly: true, + Secure: isTLS(r), + SameSite: http.SameSiteStrictMode, + }) + w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`) +} + +func oidcTokenFromContext(r *http.Request) string { + if token, ok := r.Context().Value(oidcGeneratedToken).(string); ok { + return token + } + return "" +} + +func isTLS(r *http.Request) bool { + if r.TLS != nil { + return true + } + if proto, ok := r.Context().Value(forwardedProtoKey).(string); ok { + return proto == "https" //nolint:goconst + } + return false +} + +func isTokenInvalidated(r *http.Request) bool { + var findTokenFns []func(r *http.Request) string + findTokenFns = append(findTokenFns, jwt.TokenFromHeader) + findTokenFns = append(findTokenFns, jwt.TokenFromCookie) + findTokenFns = append(findTokenFns, oidcTokenFromContext) + + isTokenFound := false + for _, fn := range findTokenFns { + token := fn(r) + if token != "" { + isTokenFound = true + if invalidatedJWTTokens.Get(token) { + return true + } + } + } + + return !isTokenFound +} + +func invalidateToken(r *http.Request) { + tokenString := jwt.TokenFromHeader(r) + if 
tokenString != "" { + invalidateTokenString(r, tokenString, apiTokenDuration) + } + tokenString = jwt.TokenFromCookie(r) + if tokenString != "" { + invalidateTokenString(r, tokenString, getMaxCookieDuration()) + } +} + +func invalidateTokenString(r *http.Request, tokenString string, fallbackDuration time.Duration) { + token, err := jwt.FromContext(r.Context()) + if err != nil { + invalidatedJWTTokens.Add(tokenString, time.Now().Add(fallbackDuration).UTC()) + return + } + invalidatedJWTTokens.Add(tokenString, token.Expiry.Time().Add(1*time.Minute).UTC()) +} + +func getUserFromToken(r *http.Request) *dataprovider.User { + user := &dataprovider.User{} + claims, err := jwt.FromContext(r.Context()) + if err != nil { + return user + } + user.Username = claims.Username + user.Filters.WebClient = claims.Permissions + user.Role = claims.Role + return user +} + +func getAdminFromToken(r *http.Request) *dataprovider.Admin { + admin := &dataprovider.Admin{} + claims, err := jwt.FromContext(r.Context()) + if err != nil { + return admin + } + admin.Username = claims.Username + admin.Permissions = claims.Permissions + admin.Filters.Preferences.HideUserPageSections = claims.HideUserPageSections + admin.Role = claims.Role + return admin +} + +func createLoginCookie(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwt.Signer, tokenID, basePath, ip string, +) { + c := jwt.NewClaims(tokenAudienceWebLogin, ip, getTokenDuration(tokenAudienceWebLogin)) + c.ID = tokenID + resp, err := c.GenerateTokenResponse(csrfTokenAuth) + if err != nil { + return + } + setCookie(w, r, basePath, resp.Token, csrfTokenDuration) +} + +func createCSRFToken(w http.ResponseWriter, r *http.Request, csrfTokenAuth *jwt.Signer, tokenID, + basePath string, +) string { + ip := util.GetIPFromRemoteAddress(r.RemoteAddr) + claims := jwt.NewClaims(tokenAudienceCSRF, ip, csrfTokenDuration) + claims.ID = rand.Text() + if tokenID != "" { + createLoginCookie(w, r, csrfTokenAuth, tokenID, basePath, ip) + claims.Ref = 
tokenID + } else { + if c, err := jwt.FromContext(r.Context()); err == nil { + claims.Ref = c.ID + } else { + logger.Error(logSender, "", "unable to add reference to CSRF token: %v", err) + } + } + tokenString, err := csrfTokenAuth.Sign(claims) + if err != nil { + logger.Debug(logSender, "", "unable to create CSRF token: %v", err) + return "" + } + return tokenString +} + +func verifyCSRFToken(r *http.Request, csrfTokenAuth *jwt.Signer) error { + tokenString := r.Form.Get(csrfFormToken) + token, err := jwt.VerifyToken(csrfTokenAuth, tokenString) + if err != nil || token == nil { + logger.Debug(logSender, "", "error validating CSRF token %q: %v", tokenString, err) + return fmt.Errorf("unable to verify form token: %v", err) + } + + if !token.Audience.Contains(tokenAudienceCSRF) { + logger.Debug(logSender, "", "error validating CSRF token audience") + return errors.New("the form token is not valid") + } + + if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { + logger.Debug(logSender, "", "error validating CSRF token IP audience") + return errors.New("the form token is not valid") + } + return checkCSRFTokenRef(r, token) +} + +func checkCSRFTokenRef(r *http.Request, token *jwt.Claims) error { + claims, err := jwt.FromContext(r.Context()) + if err != nil { + logger.Debug(logSender, "", "error getting token claims for CSRF validation: %v", err) + return err + } + if token.ID == "" { + logger.Debug(logSender, "", "error validating CSRF token, missing reference") + return errors.New("the form token is not valid") + } + if claims.ID != token.Ref { + logger.Debug(logSender, "", "error validating CSRF reference, id %q, reference %q", claims.ID, token.ID) + return errors.New("unexpected form token") + } + + return nil +} + +func verifyLoginCookie(r *http.Request) error { + token, err := jwt.FromContext(r.Context()) + if err != nil { + logger.Debug(logSender, "", "error getting login token: %v", err) + return errInvalidToken + } + if 
isTokenInvalidated(r) { + logger.Debug(logSender, "", "the login token has been invalidated") + return errInvalidToken + } + if !token.Audience.Contains(tokenAudienceWebLogin) { + logger.Debug(logSender, "", "the token with id %q is not valid for audience %q", token.ID, tokenAudienceWebLogin) + return errInvalidToken + } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := validateIPForToken(token, ipAddr); err != nil { + return err + } + return nil +} + +func verifyLoginCookieAndCSRFToken(r *http.Request, csrfTokenAuth *jwt.Signer) error { + if err := verifyLoginCookie(r); err != nil { + return err + } + if err := verifyCSRFToken(r, csrfTokenAuth); err != nil { + return err + } + return nil +} + +func createOAuth2Token(csrfTokenAuth *jwt.Signer, state, ip string) string { + claims := jwt.NewClaims(tokenAudienceOAuth2, ip, getTokenDuration(tokenAudienceOAuth2)) + claims.ID = state + + tokenString, err := csrfTokenAuth.Sign(claims) + if err != nil { + logger.Debug(logSender, "", "unable to create OAuth2 token: %v", err) + return "" + } + return tokenString +} + +func verifyOAuth2Token(csrfTokenAuth *jwt.Signer, tokenString, ip string) (string, error) { + token, err := jwt.VerifyToken(csrfTokenAuth, tokenString) + if err != nil || token == nil { + logger.Debug(logSender, "", "error validating OAuth2 token %q: %v", tokenString, err) + return "", util.NewI18nError( + fmt.Errorf("unable to verify OAuth2 state: %v", err), + util.I18nOAuth2ErrorVerifyState, + ) + } + + if !token.Audience.Contains(tokenAudienceOAuth2) { + logger.Debug(logSender, "", "error validating OAuth2 token audience") + return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) + } + + if err := validateIPForToken(token, ip); err != nil { + logger.Debug(logSender, "", "error validating OAuth2 token IP audience") + return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) + } + if token.ID != "" { + return token.ID, 
nil + } + logger.Debug(logSender, "", "jti not found in OAuth2 token") + return "", util.NewI18nError(errors.New("invalid OAuth2 state"), util.I18nOAuth2InvalidState) +} + +func validateIPForToken(token *jwt.Claims, ip string) error { + if tokenValidationMode&tokenValidationModeNoIPMatch == 0 { + if !token.Audience.Contains(ip) { + return errInvalidToken + } + } + return nil +} + +func checkTokenSignature(r *http.Request, token *jwt.Claims) error { + if _, ok := r.Context().Value(oidcTokenKey).(string); ok { + return nil + } + var err error + if tokenValidationMode&tokenValidationModeUserSignature != 0 { + for _, audience := range token.Audience { + switch audience { + case tokenAudienceAPI, tokenAudienceWebAdmin: + err = validateSignatureForToken(token, dataprovider.GetAdminSignature) + case tokenAudienceAPIUser, tokenAudienceWebClient: + err = validateSignatureForToken(token, dataprovider.GetUserSignature) + } + } + } + if err != nil { + invalidateToken(r) + } + return err +} + +func validateSignatureForToken(token *jwt.Claims, getter func(string) (string, error)) error { + signature, err := getter(token.Username) + if err != nil { + logger.Debug(logSender, "", "unable to get signature for username %q: %v", token.Username, err) + return errInvalidToken + } + if signature != "" && signature == token.Subject { + return nil + } + logger.Debug(logSender, "", "signature mismatch for username %q, signature %q, token signature %q", + token.Username, signature, token.Subject) + return errInvalidToken +} diff --git a/internal/httpd/file.go b/internal/httpd/file.go new file mode 100644 index 00000000..daa70b49 --- /dev/null +++ b/internal/httpd/file.go @@ -0,0 +1,144 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "io" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +type httpdFile struct { + *common.BaseTransfer + writer io.WriteCloser + reader io.ReadCloser + isFinished bool +} + +func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader) *httpdFile { + var writer io.WriteCloser + var reader io.ReadCloser + if baseTransfer.File != nil { + writer = baseTransfer.File + reader = baseTransfer.File + } else if pipeWriter != nil { + writer = pipeWriter + } else if pipeReader != nil { + reader = pipeReader + } + return &httpdFile{ + BaseTransfer: baseTransfer, + writer: writer, + reader: reader, + isFinished: false, + } +} + +// Read reads the contents to downloads. 
+func (f *httpdFile) Read(p []byte) (n int, err error) { + if f.AbortTransfer.Load() { + err := f.GetAbortError() + f.TransferError(err) + return 0, err + } + + f.Connection.UpdateLastActivity() + + n, err = f.reader.Read(p) + f.BytesSent.Add(int64(n)) + + if err == nil { + err = f.CheckRead() + } + if err != nil && err != io.EOF { + f.TransferError(err) + err = f.ConvertError(err) + return + } + f.HandleThrottle() + return +} + +// Write writes the contents to upload +func (f *httpdFile) Write(p []byte) (n int, err error) { + if f.AbortTransfer.Load() { + err := f.GetAbortError() + f.TransferError(err) + return 0, err + } + + f.Connection.UpdateLastActivity() + + n, err = f.writer.Write(p) + f.BytesReceived.Add(int64(n)) + + if err == nil { + err = f.CheckWrite() + } + if err != nil { + f.TransferError(err) + err = f.ConvertError(err) + return + } + f.HandleThrottle() + return +} + +// Close closes the current transfer +func (f *httpdFile) Close() error { + if err := f.setFinished(); err != nil { + return err + } + err := f.closeIO() + errBaseClose := f.BaseTransfer.Close() + if errBaseClose != nil { + err = errBaseClose + } + + return f.Connection.GetFsError(f.Fs, err) +} + +func (f *httpdFile) closeIO() error { + var err error + if f.File != nil { + err = f.File.Close() + } else if f.writer != nil { + err = f.writer.Close() + f.Lock() + // we set ErrTransfer here so quota is not updated, in this case the uploads are atomic + if err != nil && f.ErrTransfer == nil { + f.ErrTransfer = err + } + f.Unlock() + } else if f.reader != nil { + err = f.reader.Close() + if metadater, ok := f.reader.(vfs.Metadater); ok { + f.SetMetadata(metadater.Metadata()) + } + } + return err +} + +func (f *httpdFile) setFinished() error { + f.Lock() + defer f.Unlock() + + if f.isFinished { + return common.ErrTransferClosed + } + f.isFinished = true + return nil +} diff --git a/internal/httpd/flash.go b/internal/httpd/flash.go new file mode 100644 index 00000000..484f67dc --- /dev/null 
+++ b/internal/httpd/flash.go @@ -0,0 +1,95 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "time" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + flashCookieName = "message" +) + +func newFlashMessage(errorStrig, i18nMessage string) flashMessage { + return flashMessage{ + ErrorString: errorStrig, + I18nMessage: i18nMessage, + } +} + +type flashMessage struct { + ErrorString string `json:"error"` + I18nMessage string `json:"message"` +} + +func (m *flashMessage) getI18nError() *util.I18nError { + if m.ErrorString == "" && m.I18nMessage == "" { + return nil + } + return util.NewI18nError( + util.NewGenericError(m.ErrorString), + m.I18nMessage, + ) +} + +func setFlashMessage(w http.ResponseWriter, r *http.Request, message flashMessage) { + value, err := json.Marshal(message) + if err != nil { + return + } + http.SetCookie(w, &http.Cookie{ + Name: flashCookieName, + Value: base64.URLEncoding.EncodeToString(value), + Path: "/", + Expires: time.Now().Add(60 * time.Second), + MaxAge: 60, + HttpOnly: true, + Secure: isTLS(r), + SameSite: http.SameSiteLaxMode, + }) + w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`) +} + +func getFlashMessage(w http.ResponseWriter, r *http.Request) flashMessage { + var msg flashMessage + cookie, err := r.Cookie(flashCookieName) + if err != nil { + return msg + } + 
http.SetCookie(w, &http.Cookie{ + Name: flashCookieName, + Value: "", + Path: "/", + Expires: time.Unix(0, 0), + MaxAge: -1, + HttpOnly: true, + Secure: isTLS(r), + SameSite: http.SameSiteLaxMode, + }) + value, err := base64.URLEncoding.DecodeString(cookie.Value) + if err != nil { + return msg + } + err = json.Unmarshal(value, &msg) + if err != nil { + return flashMessage{} + } + return msg +} diff --git a/internal/httpd/flash_test.go b/internal/httpd/flash_test.go new file mode 100644 index 00000000..4a347b45 --- /dev/null +++ b/internal/httpd/flash_test.go @@ -0,0 +1,52 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func TestFlashMessages(t *testing.T) { + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/url", nil) + require.NoError(t, err) + message := flashMessage{ + ErrorString: "error", + I18nMessage: util.I18nChangePwdTitle, + } + setFlashMessage(rr, req, message) + value, err := json.Marshal(message) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("%v=%v", flashCookieName, base64.URLEncoding.EncodeToString(value))) + msg := getFlashMessage(rr, req) + assert.Equal(t, message, msg) + assert.Equal(t, util.I18nChangePwdTitle, msg.getI18nError().Message) + req.Header.Set("Cookie", fmt.Sprintf("%v=%v", flashCookieName, "a")) + msg = getFlashMessage(rr, req) + assert.Empty(t, msg) + req.Header.Set("Cookie", fmt.Sprintf("%v=%v", flashCookieName, "YQ==")) + msg = getFlashMessage(rr, req) + assert.Empty(t, msg) +} diff --git a/internal/httpd/handler.go b/internal/httpd/handler.go new file mode 100644 index 00000000..15b085fe --- /dev/null +++ b/internal/httpd/handler.go @@ -0,0 +1,376 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
package httpd

import (
	"io"
	"net/http"
	"os"
	"path"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// Connection details for a HTTP connection used to interact with an SFTPGo filesystem
type Connection struct {
	*common.BaseConnection
	request *http.Request
	rc      *http.ResponseController
}

// newConnection wraps a BaseConnection for an HTTP request. Read/write
// deadlines are cleared (zero time) so long transfers are not cut off by
// server-level timeouts; Disconnect re-arms them to force termination.
func newConnection(conn *common.BaseConnection, w http.ResponseWriter, r *http.Request) *Connection {
	rc := http.NewResponseController(w)
	responseControllerDeadlines(rc, time.Time{}, time.Time{})
	return &Connection{
		BaseConnection: conn,
		request:        r,
		rc:             rc,
	}
}

// GetClientVersion returns the connected client's version.
func (c *Connection) GetClientVersion() string {
	if c.request != nil {
		return c.request.UserAgent()
	}
	return ""
}

// GetLocalAddress returns local connection address
func (c *Connection) GetLocalAddress() string {
	return util.GetHTTPLocalAddress(c.request)
}

// GetRemoteAddress returns the connected client's address
func (c *Connection) GetRemoteAddress() string {
	if c.request != nil {
		return c.request.RemoteAddr
	}
	return ""
}

// Disconnect closes the active transfer.
// A 5 second deadline is set on the response controller so the connection
// cannot linger, then any in-flight transfers are signalled to abort.
func (c *Connection) Disconnect() (err error) {
	if c.rc != nil {
		responseControllerDeadlines(c.rc, time.Now().Add(5*time.Second), time.Now().Add(5*time.Second))
	}
	return c.SignalTransfersAbort()
}

// GetCommand returns the request method
func (c *Connection) GetCommand() string {
	if c.request != nil {
		return strings.ToUpper(c.request.Method)
	}
	return ""
}

// Stat returns a FileInfo describing the named file/directory, or an error,
// if any happens. The list permission is checked on the parent directory.
func (c *Connection) Stat(name string, mode int) (os.FileInfo, error) {
	c.UpdateLastActivity()

	if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(name)) {
		return nil, c.GetPermissionDeniedError()
	}

	fi, err := c.DoStat(name, mode, true)
	if err != nil {
		return nil, err
	}
	return fi, err
}

// ReadDir returns a list of directory entries
func (c *Connection) ReadDir(name string) (vfs.DirLister, error) {
	c.UpdateLastActivity()

	return c.ListDir(name)
}

// getFileReader opens the named file for download at the given offset.
// It enforces, in order: per-user transfer count limits, download quota,
// the download permission on the parent dir, file pattern filters, and
// (except for HEAD requests) the pre-download action hook.
func (c *Connection) getFileReader(name string, offset int64, method string) (io.ReadCloser, error) {
	c.UpdateLastActivity()

	if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil {
		c.Log(logger.LevelInfo, "denying file read due to transfer count limits")
		return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message)
	}

	transferQuota := c.GetTransferQuota()
	if !transferQuota.HasDownloadSpace() {
		c.Log(logger.LevelInfo, "denying file read due to quota limits")
		return nil, util.NewI18nError(c.GetReadQuotaExceededError(), util.I18nErrorQuotaRead)
	}

	if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(name)) {
		return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message)
	}

	if ok, policy := c.User.IsFileAllowed(name); !ok {
		c.Log(logger.LevelWarn, "reading file %q is not allowed", name)
		return nil, util.NewI18nError(c.GetErrorForDeniedFile(policy), util.I18nError403Message)
	}

	fs, p, err := c.GetFsAndResolvedPath(name)
	if err != nil {
		return nil, err
	}

	// HEAD requests don't transfer data, so the pre-download hook is skipped
	if method != http.MethodHead {
		if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreDownload, p, name, 0, 0); err != nil {
			c.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", name, err)
			return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message)
		}
	}

	file, r, cancelFn, err := fs.Open(p, offset)
	if err != nil {
		c.Log(logger.LevelError, "could not open file %q for reading: %+v", p, err)
		return nil, c.GetFsError(fs, err)
	}

	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, name, common.TransferDownload,
		0, 0, 0, 0, false, fs, transferQuota)
	return newHTTPDFile(baseTransfer, nil, r), nil
}

// getFileWriter opens the named file for upload. New files (or symlinks,
// which are overwritten as new files) require the upload permission;
// overwriting an existing regular file requires the overwrite permission.
// When atomic upload is enabled and supported, the existing file is first
// renamed to a temporary atomic-upload path.
func (c *Connection) getFileWriter(name string) (io.WriteCloser, error) {
	c.UpdateLastActivity()

	if ok, _ := c.User.IsFileAllowed(name); !ok {
		c.Log(logger.LevelWarn, "writing file %q is not allowed", name)
		return nil, c.GetPermissionDeniedError()
	}

	fs, p, err := c.GetFsAndResolvedPath(name)
	if err != nil {
		return nil, err
	}
	filePath := p
	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
		filePath = fs.GetAtomicUploadPath(p)
	}

	stat, statErr := fs.Lstat(p)
	// a symlink target or a missing file is treated as a brand new upload
	if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) {
		if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(name)) {
			return nil, c.GetPermissionDeniedError()
		}
		return c.handleUploadFile(fs, p, filePath, name, true, 0)
	}

	if statErr != nil {
		c.Log(logger.LevelError, "error performing file stat %q: %+v", p, statErr)
		return nil, c.GetFsError(fs, statErr)
	}

	// This happen if we upload a file that has the same name of an existing directory
	if stat.IsDir() {
		c.Log(logger.LevelError, "attempted to open a directory for writing to: %q", p)
		return nil, c.GetOpUnsupportedError()
	}

	if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(name)) {
		return nil, c.GetPermissionDeniedError()
	}

	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
		// move the existing file aside; it is restored/removed by the transfer logic
		_, _, err = fs.Rename(p, filePath, 0)
		if err != nil {
			c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v",
				p, filePath, err)
			return nil, c.GetFsError(fs, err)
		}
	}

	return c.handleUploadFile(fs, p, filePath, name, false, stat.Size())
}

// handleUploadFile performs the shared part of an upload: transfer-count and
// quota checks, the pre-upload action hook, file creation, and quota
// adjustment for overwritten files. fileSize is the size of the file being
// replaced (0 for new files).
func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, requestPath string, isNewFile bool, fileSize int64) (io.WriteCloser, error) {
	if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil {
		c.Log(logger.LevelInfo, "denying file write due to transfer count limits")
		return nil, util.NewI18nError(c.GetPermissionDeniedError(), util.I18nError403Message)
	}
	diskQuota, transferQuota := c.HasSpace(isNewFile, false, requestPath)
	if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
		c.Log(logger.LevelInfo, "denying file write due to quota limits")
		return nil, common.ErrQuotaExceeded
	}
	_, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, fileSize, os.O_TRUNC)
	if err != nil {
		c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err)
		return nil, c.GetPermissionDeniedError()
	}

	maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported())

	file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.GetCreateChecks(requestPath, isNewFile, false))
	if err != nil {
		c.Log(logger.LevelError, "error opening existing file, source: %q, err: %+v", filePath, err)
		return nil, c.GetFsError(fs, err)
	}

	initialSize := int64(0)
	truncatedSize := int64(0) // bytes truncated and not included in quota
	if !isNewFile {
		if vfs.HasTruncateSupport(fs) {
			// the old size is subtracted from quota now; the new size is added
			// back when the transfer completes
			vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
			if err == nil {
				dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
			} else {
				dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
			}
		} else {
			initialSize = fileSize
			truncatedSize = fileSize
		}
		if maxWriteSize > 0 {
			maxWriteSize += fileSize
		}
	}

	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())

	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs, transferQuota)
	return newHTTPDFile(baseTransfer, w, nil), nil
}

// newThrottledReader wraps a request body reader with bandwidth limiting
// (limit in KB/s, 0 = unlimited) and registers it as an active transfer on
// the connection.
func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttledReader {
	t := &throttledReader{
		id:    conn.GetTransferID(),
		limit: limit,
		r:     r,
		start: time.Now(),
		conn:  conn,
	}
	t.bytesRead.Store(0)
	t.abortTransfer.Store(false)
	conn.AddTransfer(t)
	return t
}

// throttledReader implements the common transfer interface for a request
// body so that uploads in progress are visible and abortable; most of the
// interface methods are no-ops because there is no target file yet.
type throttledReader struct {
	bytesRead     atomic.Int64
	id            int64
	limit         int64 // bandwidth limit in KB/s, 0 means unlimited
	r             io.ReadCloser
	abortTransfer atomic.Bool
	start         time.Time
	conn          *Connection
	mu            sync.Mutex // guards errAbort
	errAbort      error
}

// GetID returns the transfer ID.
func (t *throttledReader) GetID() int64 {
	return t.id
}

// GetType reports this transfer as an upload.
func (t *throttledReader) GetType() int {
	return common.TransferUpload
}

// GetSize returns the number of bytes read so far.
func (t *throttledReader) GetSize() int64 {
	return t.bytesRead.Load()
}

// GetDownloadedSize is always 0: this is an upload-side reader.
func (t *throttledReader) GetDownloadedSize() int64 {
	return 0
}

// GetUploadedSize returns the number of bytes read so far.
func (t *throttledReader) GetUploadedSize() int64 {
	return t.bytesRead.Load()
}

// GetVirtualPath returns a placeholder: no target path exists while the
// request body is being read.
func (t *throttledReader) GetVirtualPath() string {
	return "**reading request body**"
}

// GetStartTime returns the transfer start time.
func (t *throttledReader) GetStartTime() time.Time {
	return t.start
}

// GetAbortError returns the error set via SignalClose, or the generic
// transfer-aborted error.
func (t *throttledReader) GetAbortError() error {
	t.mu.Lock()
	defer t.mu.Unlock()

	if t.errAbort != nil {
		return t.errAbort
	}
	return common.ErrTransferAborted
}

// SignalClose requests the transfer to abort with the given error.
func (t *throttledReader) SignalClose(err error) {
	t.mu.Lock()
	t.errAbort = err
	t.mu.Unlock()
	t.abortTransfer.Store(true)
}

// GetTruncatedSize is always 0 for this reader.
func (t *throttledReader) GetTruncatedSize() int64 {
	return 0
}

// HasSizeLimit is always false for this reader.
func (t *throttledReader) HasSizeLimit() bool {
	return false
}

// Truncate is unsupported on a request body.
func (t *throttledReader) Truncate(_ string, _ int64) (int64, error) {
	return 0, vfs.ErrVfsUnsupported
}

// GetRealFsPath is empty: there is no filesystem path.
func (t *throttledReader) GetRealFsPath(_ string) string {
	return ""
}

// GetFsPath is empty: there is no filesystem path.
func (t *throttledReader) GetFsPath() string {
	return ""
}

// SetTimes is unsupported on a request body.
func (t *throttledReader) SetTimes(_ string, _ time.Time, _ time.Time) bool {
	return false
}

// Read reads from the wrapped body, honoring abort requests and sleeping as
// needed to keep the average rate at or below the configured limit.
func (t *throttledReader) Read(p []byte) (n int, err error) {
	if t.abortTransfer.Load() {
		return 0, t.GetAbortError()
	}

	t.conn.UpdateLastActivity()
	n, err = t.r.Read(p)
	if t.limit > 0 {
		t.bytesRead.Add(int64(n))
		trasferredBytes := t.bytesRead.Load()
		elapsed := time.Since(t.start).Nanoseconds() / 1000000
		// milliseconds the transfer *should* have taken at the limit (KB/s)
		wantedElapsed := 1000 * (trasferredBytes / 1024) / t.limit
		if wantedElapsed > elapsed {
			toSleep := time.Duration(wantedElapsed - elapsed)
			time.Sleep(toSleep * time.Millisecond)
		}
	}
	return
}

// Close closes the wrapped body reader.
func (t *throttledReader) Close() error {
	return t.r.Close()
}
diff --git a/internal/httpd/httpd.go b/internal/httpd/httpd.go new file mode 100644 index 00000000..13e39206 --- /dev/null +++ b/internal/httpd/httpd.go @@ -0,0 +1,1460 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package httpd implements REST API and Web interface for SFTPGo.
+// The OpenAPI 3 schema for the supported API can be found inside the source tree: +// https://github.com/drakkan/sftpgo/blob/main/openapi/openapi.yaml +package httpd + +import ( + "crypto/sha256" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "github.com/go-chi/chi/v5" + + "github.com/drakkan/sftpgo/v2/internal/acme" + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/ftpd" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +const ( + logSender = "httpd" + tokenPath = "/api/v2/token" + logoutPath = "/api/v2/logout" + userTokenPath = "/api/v2/user/token" + userLogoutPath = "/api/v2/user/logout" + activeConnectionsPath = "/api/v2/connections" + quotasBasePath = "/api/v2/quotas" + userPath = "/api/v2/users" + versionPath = "/api/v2/version" + folderPath = "/api/v2/folders" + groupPath = "/api/v2/groups" + serverStatusPath = "/api/v2/status" + dumpDataPath = "/api/v2/dumpdata" + loadDataPath = "/api/v2/loaddata" + defenderHosts = "/api/v2/defender/hosts" + adminPath = "/api/v2/admins" + adminPwdPath = "/api/v2/admin/changepwd" + adminProfilePath = "/api/v2/admin/profile" + userPwdPath = "/api/v2/user/changepwd" + userDirsPath = "/api/v2/user/dirs" + userFilesPath = "/api/v2/user/files" + userFileActionsPath = "/api/v2/user/file-actions" + userStreamZipPath = "/api/v2/user/streamzip" + userUploadFilePath = "/api/v2/user/files/upload" + userFilesDirsMetadataPath = "/api/v2/user/files/metadata" + apiKeysPath = "/api/v2/apikeys" + adminTOTPConfigsPath = "/api/v2/admin/totp/configs" + adminTOTPGeneratePath = "/api/v2/admin/totp/generate" + adminTOTPValidatePath = "/api/v2/admin/totp/validate" + 
adminTOTPSavePath = "/api/v2/admin/totp/save" + admin2FARecoveryCodesPath = "/api/v2/admin/2fa/recoverycodes" + userTOTPConfigsPath = "/api/v2/user/totp/configs" + userTOTPGeneratePath = "/api/v2/user/totp/generate" + userTOTPValidatePath = "/api/v2/user/totp/validate" + userTOTPSavePath = "/api/v2/user/totp/save" + user2FARecoveryCodesPath = "/api/v2/user/2fa/recoverycodes" + userProfilePath = "/api/v2/user/profile" + userSharesPath = "/api/v2/user/shares" + retentionChecksPath = "/api/v2/retention/users/checks" + fsEventsPath = "/api/v2/events/fs" + providerEventsPath = "/api/v2/events/provider" + logEventsPath = "/api/v2/events/logs" + sharesPath = "/api/v2/shares" + eventActionsPath = "/api/v2/eventactions" + eventRulesPath = "/api/v2/eventrules" + rolesPath = "/api/v2/roles" + ipListsPath = "/api/v2/iplists" + healthzPath = "/healthz" + webRootPathDefault = "/" + webBasePathDefault = "/web" + webBasePathAdminDefault = "/web/admin" + webBasePathClientDefault = "/web/client" + webAdminSetupPathDefault = "/web/admin/setup" + webAdminLoginPathDefault = "/web/admin/login" + webAdminOIDCLoginPathDefault = "/web/admin/oidclogin" + webOIDCRedirectPathDefault = "/web/oidc/redirect" + webOAuth2RedirectPathDefault = "/web/oauth2/redirect" + webOAuth2TokenPathDefault = "/web/admin/oauth2/token" + webAdminTwoFactorPathDefault = "/web/admin/twofactor" + webAdminTwoFactorRecoveryPathDefault = "/web/admin/twofactor-recovery" + webLogoutPathDefault = "/web/admin/logout" + webUsersPathDefault = "/web/admin/users" + webUserPathDefault = "/web/admin/user" + webConnectionsPathDefault = "/web/admin/connections" + webFoldersPathDefault = "/web/admin/folders" + webFolderPathDefault = "/web/admin/folder" + webGroupsPathDefault = "/web/admin/groups" + webGroupPathDefault = "/web/admin/group" + webStatusPathDefault = "/web/admin/status" + webAdminsPathDefault = "/web/admin/managers" + webAdminPathDefault = "/web/admin/manager" + webMaintenancePathDefault = "/web/admin/maintenance" + 
webBackupPathDefault = "/web/admin/backup" + webRestorePathDefault = "/web/admin/restore" + webScanVFolderPathDefault = "/web/admin/quotas/scanfolder" + webQuotaScanPathDefault = "/web/admin/quotas/scanuser" + webChangeAdminPwdPathDefault = "/web/admin/changepwd" + webAdminForgotPwdPathDefault = "/web/admin/forgot-password" + webAdminResetPwdPathDefault = "/web/admin/reset-password" + webAdminProfilePathDefault = "/web/admin/profile" + webAdminMFAPathDefault = "/web/admin/mfa" + webAdminEventRulesPathDefault = "/web/admin/eventrules" + webAdminEventRulePathDefault = "/web/admin/eventrule" + webAdminEventActionsPathDefault = "/web/admin/eventactions" + webAdminEventActionPathDefault = "/web/admin/eventaction" + webAdminRolesPathDefault = "/web/admin/roles" + webAdminRolePathDefault = "/web/admin/role" + webAdminTOTPGeneratePathDefault = "/web/admin/totp/generate" + webAdminTOTPValidatePathDefault = "/web/admin/totp/validate" + webAdminTOTPSavePathDefault = "/web/admin/totp/save" + webAdminRecoveryCodesPathDefault = "/web/admin/recoverycodes" + webTemplateUserDefault = "/web/admin/template/user" + webTemplateFolderDefault = "/web/admin/template/folder" + webDefenderPathDefault = "/web/admin/defender" + webIPListsPathDefault = "/web/admin/ip-lists" + webIPListPathDefault = "/web/admin/ip-list" + webDefenderHostsPathDefault = "/web/admin/defender/hosts" + webEventsPathDefault = "/web/admin/events" + webEventsFsSearchPathDefault = "/web/admin/events/fs" + webEventsProviderSearchPathDefault = "/web/admin/events/provider" + webEventsLogSearchPathDefault = "/web/admin/events/logs" + webConfigsPathDefault = "/web/admin/configs" + webClientLoginPathDefault = "/web/client/login" + webClientOIDCLoginPathDefault = "/web/client/oidclogin" + webClientTwoFactorPathDefault = "/web/client/twofactor" + webClientTwoFactorRecoveryPathDefault = "/web/client/twofactor-recovery" + webClientFilesPathDefault = "/web/client/files" + webClientFilePathDefault = "/web/client/file" + 
webClientFileActionsPathDefault = "/web/client/file-actions" + webClientSharesPathDefault = "/web/client/shares" + webClientSharePathDefault = "/web/client/share" + webClientEditFilePathDefault = "/web/client/editfile" + webClientDirsPathDefault = "/web/client/dirs" + webClientDownloadZipPathDefault = "/web/client/downloadzip" + webClientProfilePathDefault = "/web/client/profile" + webClientPingPathDefault = "/web/client/ping" + webClientMFAPathDefault = "/web/client/mfa" + webClientTOTPGeneratePathDefault = "/web/client/totp/generate" + webClientTOTPValidatePathDefault = "/web/client/totp/validate" + webClientTOTPSavePathDefault = "/web/client/totp/save" + webClientRecoveryCodesPathDefault = "/web/client/recoverycodes" + webChangeClientPwdPathDefault = "/web/client/changepwd" + webClientLogoutPathDefault = "/web/client/logout" + webClientPubSharesPathDefault = "/web/client/pubshares" + webClientForgotPwdPathDefault = "/web/client/forgot-password" + webClientResetPwdPathDefault = "/web/client/reset-password" + webClientViewPDFPathDefault = "/web/client/viewpdf" + webClientGetPDFPathDefault = "/web/client/getpdf" + webClientExistPathDefault = "/web/client/exist" + webClientTasksPathDefault = "/web/client/tasks" + webStaticFilesPathDefault = "/static" + webOpenAPIPathDefault = "/openapi" + // MaxRestoreSize defines the max size for the loaddata input file + MaxRestoreSize = 20 * 1048576 // 20 MB + maxRequestSize = 1048576 // 1MB + maxLoginBodySize = 262144 // 256 KB + httpdMaxEditFileSize = 2 * 1048576 // 2 MB + maxMultipartMem = 10 * 1048576 // 10 MB + osWindows = "windows" + otpHeaderCode = "X-SFTPGO-OTP" + mTimeHeader = "X-SFTPGO-MTIME" + acmeChallengeURI = "/.well-known/acme-challenge/" +) + +var ( + certMgr *common.CertManager + cleanupTicker *time.Ticker + cleanupDone chan bool + invalidatedJWTTokens tokenManager + webRootPath string + webBasePath string + webBaseAdminPath string + webBaseClientPath string + webOIDCRedirectPath string + webOAuth2RedirectPath 
string + webOAuth2TokenPath string + webAdminSetupPath string + webAdminOIDCLoginPath string + webAdminLoginPath string + webAdminTwoFactorPath string + webAdminTwoFactorRecoveryPath string + webLogoutPath string + webUsersPath string + webUserPath string + webConnectionsPath string + webFoldersPath string + webFolderPath string + webGroupsPath string + webGroupPath string + webStatusPath string + webAdminsPath string + webAdminPath string + webMaintenancePath string + webBackupPath string + webRestorePath string + webScanVFolderPath string + webQuotaScanPath string + webAdminProfilePath string + webAdminMFAPath string + webAdminEventRulesPath string + webAdminEventRulePath string + webAdminEventActionsPath string + webAdminEventActionPath string + webAdminRolesPath string + webAdminRolePath string + webAdminTOTPGeneratePath string + webAdminTOTPValidatePath string + webAdminTOTPSavePath string + webAdminRecoveryCodesPath string + webChangeAdminPwdPath string + webAdminForgotPwdPath string + webAdminResetPwdPath string + webTemplateUser string + webTemplateFolder string + webDefenderPath string + webIPListPath string + webIPListsPath string + webEventsPath string + webEventsFsSearchPath string + webEventsProviderSearchPath string + webEventsLogSearchPath string + webConfigsPath string + webDefenderHostsPath string + webClientLoginPath string + webClientOIDCLoginPath string + webClientTwoFactorPath string + webClientTwoFactorRecoveryPath string + webClientFilesPath string + webClientFilePath string + webClientFileActionsPath string + webClientSharesPath string + webClientSharePath string + webClientEditFilePath string + webClientDirsPath string + webClientDownloadZipPath string + webClientProfilePath string + webClientPingPath string + webChangeClientPwdPath string + webClientMFAPath string + webClientTOTPGeneratePath string + webClientTOTPValidatePath string + webClientTOTPSavePath string + webClientRecoveryCodesPath string + webClientPubSharesPath string + 
webClientLogoutPath string + webClientForgotPwdPath string + webClientResetPwdPath string + webClientViewPDFPath string + webClientGetPDFPath string + webClientExistPath string + webClientTasksPath string + webStaticFilesPath string + webOpenAPIPath string + // max upload size for http clients, 1GB by default + maxUploadFileSize = int64(1048576000) + hideSupportLink bool + installationCode string + installationCodeHint string + fnInstallationCodeResolver FnInstallationCodeResolver + configurationDir string + dbBrandingConfig brandingCache +) + +func init() { + updateWebAdminURLs("") + updateWebClientURLs("") + acme.SetReloadHTTPDCertsFn(ReloadCertificateMgr) + common.SetUpdateBrandingFn(dbBrandingConfig.Set) +} + +type brandingCache struct { + mu sync.RWMutex + configs *dataprovider.BrandingConfigs +} + +func (b *brandingCache) Set(configs *dataprovider.BrandingConfigs) { + b.mu.Lock() + defer b.mu.Unlock() + + b.configs = configs +} + +func (b *brandingCache) getWebAdminLogo() []byte { + b.mu.RLock() + defer b.mu.RUnlock() + + return b.configs.WebAdmin.Logo +} + +func (b *brandingCache) getWebAdminFavicon() []byte { + b.mu.RLock() + defer b.mu.RUnlock() + + return b.configs.WebAdmin.Favicon +} + +func (b *brandingCache) getWebClientLogo() []byte { + b.mu.RLock() + defer b.mu.RUnlock() + + return b.configs.WebClient.Logo +} + +func (b *brandingCache) getWebClientFavicon() []byte { + b.mu.RLock() + defer b.mu.RUnlock() + + return b.configs.WebClient.Favicon +} + +func (b *brandingCache) mergeBrandingConfig(branding UIBranding, isWebClient bool) UIBranding { + b.mu.RLock() + defer b.mu.RUnlock() + + var urlPrefix string + var cfg dataprovider.BrandingConfig + if isWebClient { + cfg = b.configs.WebClient + urlPrefix = "webclient" + } else { + cfg = b.configs.WebAdmin + urlPrefix = "webadmin" + } + if cfg.Name != "" { + branding.Name = cfg.Name + } + if cfg.ShortName != "" { + branding.ShortName = cfg.ShortName + } + if cfg.DisclaimerName != "" { + 
branding.DisclaimerName = cfg.DisclaimerName + } + if cfg.DisclaimerURL != "" { + branding.DisclaimerPath = cfg.DisclaimerURL + } + if len(cfg.Logo) > 0 { + branding.LogoPath = path.Join("/", "branding", urlPrefix, "logo.png") + } + if len(cfg.Favicon) > 0 { + branding.FaviconPath = path.Join("/", "branding", urlPrefix, "favicon.png") + } + return branding +} + +// FnInstallationCodeResolver defines a method to get the installation code. +// If the installation code cannot be resolved the provided default must be returned +type FnInstallationCodeResolver func(defaultInstallationCode string) string + +// HTTPSProxyHeader defines an HTTPS proxy header as key/value. +// For example Key could be "X-Forwarded-Proto" and Value "https" +type HTTPSProxyHeader struct { + Key string + Value string +} + +// SecurityConf allows to add some security related headers to HTTP responses and to restrict allowed hosts +type SecurityConf struct { + // Set to true to enable the security configurations + Enabled bool `json:"enabled" mapstructure:"enabled"` + // AllowedHosts is a list of fully qualified domain names that are allowed. + // Default is empty list, which allows any and all host names. + AllowedHosts []string `json:"allowed_hosts" mapstructure:"allowed_hosts"` + // AllowedHostsAreRegex determines if the provided allowed hosts contains valid regular expressions + AllowedHostsAreRegex bool `json:"allowed_hosts_are_regex" mapstructure:"allowed_hosts_are_regex"` + // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. + HostsProxyHeaders []string `json:"hosts_proxy_headers" mapstructure:"hosts_proxy_headers"` + // Set to true to redirect HTTP requests to HTTPS + HTTPSRedirect bool `json:"https_redirect" mapstructure:"https_redirect"` + // HTTPSHost defines the host name that is used to redirect HTTP requests to HTTPS. + // Default is "", which indicates to use the same host. 
+ HTTPSHost string `json:"https_host" mapstructure:"https_host"` + // HTTPSProxyHeaders is a list of header keys with associated values that would indicate a valid https request. + HTTPSProxyHeaders []HTTPSProxyHeader `json:"https_proxy_headers" mapstructure:"https_proxy_headers"` + // STSSeconds is the max-age of the Strict-Transport-Security header. + // Default is 0, which would NOT include the header. + STSSeconds int64 `json:"sts_seconds" mapstructure:"sts_seconds"` + // If STSIncludeSubdomains is set to true, the "includeSubdomains" will be appended to the + // Strict-Transport-Security header. Default is false. + STSIncludeSubdomains bool `json:"sts_include_subdomains" mapstructure:"sts_include_subdomains"` + // If STSPreload is set to true, the `preload` flag will be appended to the + // Strict-Transport-Security header. Default is false. + STSPreload bool `json:"sts_preload" mapstructure:"sts_preload"` + // If ContentTypeNosniff is true, adds the X-Content-Type-Options header with the value "nosniff". Default is false. + ContentTypeNosniff bool `json:"content_type_nosniff" mapstructure:"content_type_nosniff"` + // ContentSecurityPolicy allows to set the Content-Security-Policy header value. Default is "". + ContentSecurityPolicy string `json:"content_security_policy" mapstructure:"content_security_policy"` + // PermissionsPolicy allows to set the Permissions-Policy header value. Default is "". + PermissionsPolicy string `json:"permissions_policy" mapstructure:"permissions_policy"` + // CrossOriginOpenerPolicy allows to set the Cross-Origin-Opener-Policy header value. Default is "". + CrossOriginOpenerPolicy string `json:"cross_origin_opener_policy" mapstructure:"cross_origin_opener_policy"` + // CrossOriginResourcePolicy allows to set the Cross-Origin-Resource-Policy header value. Default is "". 
+ CrossOriginResourcePolicy string `json:"cross_origin_resource_policy" mapstructure:"cross_origin_resource_policy"` + // CrossOriginEmbedderPolicy allows to set the Cross-Origin-Embedder-Policy header value. Default is "". + CrossOriginEmbedderPolicy string `json:"cross_origin_embedder_policy" mapstructure:"cross_origin_embedder_policy"` + // CacheControl allows to set the Cache-Control header value. + CacheControl string `json:"cache_control" mapstructure:"cache_control"` + // ReferrerPolicy allows to set the Referrer-Policy header values. + ReferrerPolicy string `json:"referrer_policy" mapstructure:"referrer_policy"` + proxyHeaders []string +} + +func (s *SecurityConf) updateProxyHeaders() { + if !s.Enabled { + s.proxyHeaders = nil + return + } + s.proxyHeaders = s.HostsProxyHeaders + for _, httpsProxyHeader := range s.HTTPSProxyHeaders { + s.proxyHeaders = append(s.proxyHeaders, httpsProxyHeader.Key) + } +} + +func (s *SecurityConf) getHTTPSProxyHeaders() map[string]string { + headers := make(map[string]string) + for _, httpsProxyHeader := range s.HTTPSProxyHeaders { + headers[httpsProxyHeader.Key] = httpsProxyHeader.Value + } + return headers +} + +func (s *SecurityConf) redirectHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !isTLS(r) && !strings.HasPrefix(r.RequestURI, acmeChallengeURI) { + url := r.URL + url.Scheme = "https" + if s.HTTPSHost != "" { + url.Host = s.HTTPSHost + } else { + host := r.Host + for _, header := range s.HostsProxyHeaders { + if h := r.Header.Get(header); h != "" { + host = h + break + } + } + url.Host = host + } + http.Redirect(w, r, url.String(), http.StatusTemporaryRedirect) + return + } + next.ServeHTTP(w, r) + }) +} + +// UIBranding defines the supported customizations for the web UIs +type UIBranding struct { + // Name defines the text to show at the login page and as HTML title + Name string `json:"name" mapstructure:"name"` + // ShortName defines the name 
to show next to the logo image + ShortName string `json:"short_name" mapstructure:"short_name"` + // Path to your logo relative to "static_files_path". + // For example, if you create a directory named "branding" inside the static dir and + // put the "mylogo.png" file in it, you must set "/branding/mylogo.png" as logo path. + LogoPath string `json:"logo_path" mapstructure:"logo_path"` + // Path to your favicon relative to "static_files_path" + FaviconPath string `json:"favicon_path" mapstructure:"favicon_path"` + // DisclaimerName defines the name for the link to your optional disclaimer + DisclaimerName string `json:"disclaimer_name" mapstructure:"disclaimer_name"` + // Path to the HTML page for your disclaimer relative to "static_files_path" + // or an absolute http/https URL. + DisclaimerPath string `json:"disclaimer_path" mapstructure:"disclaimer_path"` + // Path to custom CSS files, relative to "static_files_path", which replaces + // the default CSS files + DefaultCSS []string `json:"default_css" mapstructure:"default_css"` + // Additional CSS file paths, relative to "static_files_path", to include + ExtraCSS []string `json:"extra_css" mapstructure:"extra_css"` + DefaultLogoPath string `json:"-" mapstructure:"-"` + DefaultFaviconPath string `json:"-" mapstructure:"-"` +} + +func (b *UIBranding) check() { + b.DefaultLogoPath = "/img/logo.png" + b.DefaultFaviconPath = "/favicon.png" + if b.LogoPath != "" { + b.LogoPath = util.CleanPath(b.LogoPath) + } else { + b.LogoPath = b.DefaultLogoPath + } + if b.FaviconPath != "" { + b.FaviconPath = util.CleanPath(b.FaviconPath) + } else { + b.FaviconPath = b.DefaultFaviconPath + } + if b.DisclaimerPath != "" { + if !strings.HasPrefix(b.DisclaimerPath, "https://") && !strings.HasPrefix(b.DisclaimerPath, "http://") { + b.DisclaimerPath = path.Join(webStaticFilesPath, util.CleanPath(b.DisclaimerPath)) + } + } + if len(b.DefaultCSS) > 0 { + for idx := range b.DefaultCSS { + b.DefaultCSS[idx] = 
util.CleanPath(b.DefaultCSS[idx]) + } + } else { + b.DefaultCSS = []string{ + "/assets/plugins/global/plugins.bundle.css", + "/assets/css/style.bundle.css", + } + } + for idx := range b.ExtraCSS { + b.ExtraCSS[idx] = util.CleanPath(b.ExtraCSS[idx]) + } +} + +// Branding defines the branding-related customizations supported +type Branding struct { + WebAdmin UIBranding `json:"web_admin" mapstructure:"web_admin"` + WebClient UIBranding `json:"web_client" mapstructure:"web_client"` +} + +// WebClientIntegration defines the configuration for an external Web Client integration +type WebClientIntegration struct { + // Files with these extensions can be sent to the configured URL + FileExtensions []string `json:"file_extensions" mapstructure:"file_extensions"` + // URL that will receive the files + URL string `json:"url" mapstructure:"url"` +} + +// Binding defines the configuration for a network listener +type Binding struct { + // The address to listen on. A blank value means listen on all available network interfaces. + Address string `json:"address" mapstructure:"address"` + // The port used for serving requests + Port int `json:"port" mapstructure:"port"` + // Enable the built-in admin interface. + // You have to define TemplatesPath and StaticFilesPath for this to work + EnableWebAdmin bool `json:"enable_web_admin" mapstructure:"enable_web_admin"` + // Enable the built-in client interface. 
+ // You have to define TemplatesPath and StaticFilesPath for this to work + EnableWebClient bool `json:"enable_web_client" mapstructure:"enable_web_client"` + // Enable REST API + EnableRESTAPI bool `json:"enable_rest_api" mapstructure:"enable_rest_api"` + // Defines the login methods available for the WebAdmin and WebClient UIs: + // + // - 0 means any configured method: username/password login form and OIDC, if enabled + // - 1 means OIDC for the WebAdmin UI + // - 2 means OIDC for the WebClient UI + // - 4 means login form for the WebAdmin UI + // - 8 means login form for the WebClient UI + // + // You can combine the values. For example 3 means that you can only login using OIDC on + // both WebClient and WebAdmin UI. + // Deprecated because it is not extensible, use DisabledLoginMethods + EnabledLoginMethods int `json:"enabled_login_methods" mapstructure:"enabled_login_methods"` + // Defines the login methods disabled for the WebAdmin and WebClient UIs: + // + // - 1 means OIDC for the WebAdmin UI + // - 2 means OIDC for the WebClient UI + // - 4 means login form for the WebAdmin UI + // - 8 means login form for the WebClient UI + // - 16 means basic auth for admin REST API + // - 32 means basic auth for user REST API + // - 64 means API key auth for admins + // - 128 means API key auth for users + // You can combine the values. For example 12 means that you can only login using OIDC on + // both WebClient and WebAdmin UI. 
+ DisabledLoginMethods int `json:"disabled_login_methods" mapstructure:"disabled_login_methods"` + // you also need to provide a certificate for enabling HTTPS + EnableHTTPS bool `json:"enable_https" mapstructure:"enable_https"` + // Certificate and matching private key for this specific binding, if empty the global + // ones will be used, if any + CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` + CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` + // Defines the minimum TLS version. 13 means TLS 1.3, default is TLS 1.2 + MinTLSVersion int `json:"min_tls_version" mapstructure:"min_tls_version"` + // set to 1 to require client certificate authentication in addition to basic auth. + // You need to define at least a certificate authority for this to work + ClientAuthType int `json:"client_auth_type" mapstructure:"client_auth_type"` + // TLSCipherSuites is a list of supported cipher suites for TLS version 1.2. + // If CipherSuites is nil/empty, a default list of secure cipher suites + // is used, with a preference order based on hardware performance. + // Note that TLS 1.3 ciphersuites are not configurable. + // The supported ciphersuites names are defined here: + // + // https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53 + // + // any invalid name will be silently ignored. + // The order matters, the ciphers listed first will be the preferred ones. + TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"` + // HTTP protocols in preference order. Supported values: http/1.1, h2 + Protocols []string `json:"tls_protocols" mapstructure:"tls_protocols"` + // Defines whether to use the common proxy protocol configuration or the + // binding-specific proxy header configuration. 
+ ProxyMode int `json:"proxy_mode" mapstructure:"proxy_mode"` + // List of IP addresses and IP ranges allowed to set client IP proxy headers and + // X-Forwarded-Proto header. + ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"` + // Allowed client IP proxy header such as "X-Forwarded-For", "X-Real-IP" + ClientIPProxyHeader string `json:"client_ip_proxy_header" mapstructure:"client_ip_proxy_header"` + // Some client IP headers such as "X-Forwarded-For" can contain multiple IP address, this setting + // define the position to trust starting from the right. For example if we have: + // "10.0.0.1,11.0.0.1,12.0.0.1,13.0.0.1" and the depth is 0, SFTPGo will use "13.0.0.1" + // as client IP, if depth is 1, "12.0.0.1" will be used and so on + ClientIPHeaderDepth int `json:"client_ip_header_depth" mapstructure:"client_ip_header_depth"` + // If both web admin and web client are enabled each login page will show a link + // to the other one. This setting allows to hide this link: + // - 0 login links are displayed on both admin and client login page. This is the default + // - 1 the login link to the web client login page is hidden on admin login page + // - 2 the login link to the web admin login page is hidden on client login page + // The flags can be combined, for example 3 will disable both login links. + HideLoginURL int `json:"hide_login_url" mapstructure:"hide_login_url"` + // Enable the built-in OpenAPI renderer + RenderOpenAPI bool `json:"render_openapi" mapstructure:"render_openapi"` + // BaseURL defines the external base URL for generating public links + // (currently share access link), bypassing the default browser-based + // detection. + BaseURL string `json:"base_url" mapstructure:"base_url"` + // Languages defines the list of enabled translations for the WebAdmin and WebClient UI. 
+ Languages []string `json:"languages" mapstructure:"languages"` + // Defining an OIDC configuration the web admin and web client UI will use OpenID to authenticate users. + OIDC OIDC `json:"oidc" mapstructure:"oidc"` + // Security defines security headers to add to HTTP responses and allows to restrict allowed hosts + Security SecurityConf `json:"security" mapstructure:"security"` + // Branding defines customizations to suit your brand + Branding Branding `json:"branding" mapstructure:"branding"` + allowHeadersFrom []func(net.IP) bool +} + +func (b *Binding) checkBranding() { + b.Branding.WebAdmin.check() + b.Branding.WebClient.check() + if b.Branding.WebAdmin.Name == "" { + b.Branding.WebAdmin.Name = "SFTPGo WebAdmin" + } + if b.Branding.WebAdmin.ShortName == "" { + b.Branding.WebAdmin.ShortName = "WebAdmin" + } + if b.Branding.WebClient.Name == "" { + b.Branding.WebClient.Name = "SFTPGo WebClient" + } + if b.Branding.WebClient.ShortName == "" { + b.Branding.WebClient.ShortName = "WebClient" + } +} + +func (b *Binding) webAdminBranding() UIBranding { + return dbBrandingConfig.mergeBrandingConfig(b.Branding.WebAdmin, false) +} + +func (b *Binding) webClientBranding() UIBranding { + return dbBrandingConfig.mergeBrandingConfig(b.Branding.WebClient, true) +} + +func (b *Binding) languages() []string { + return b.Languages +} + +func (b *Binding) validateBaseURL() error { + if b.BaseURL == "" { + return nil + } + u, err := url.ParseRequestURI(b.BaseURL) + if err != nil { + return err + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("invalid base URL schema %s", b.BaseURL) + } + if u.Host == "" { + return fmt.Errorf("invalid base URL host %s", b.BaseURL) + } + b.BaseURL = strings.TrimRight(u.String(), "/") + return nil +} + +func (b *Binding) parseAllowedProxy() error { + if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 { + // unix domain socket + b.allowHeadersFrom = []func(net.IP) bool{func(_ net.IP) bool { return true }} + return nil + 
} + allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed) + if err != nil { + return err + } + b.allowHeadersFrom = allowedFuncs + return nil +} + +// GetAddress returns the binding address +func (b *Binding) GetAddress() string { + return fmt.Sprintf("%s:%d", b.Address, b.Port) +} + +// IsValid returns true if the binding is valid +func (b *Binding) IsValid() bool { + if !b.EnableRESTAPI && !b.EnableWebAdmin && !b.EnableWebClient { + return false + } + if b.Port > 0 { + return true + } + if filepath.IsAbs(b.Address) && runtime.GOOS != osWindows { + return true + } + return false +} + +func (b *Binding) check() error { + if err := b.parseAllowedProxy(); err != nil { + return err + } + if err := b.validateBaseURL(); err != nil { + return err + } + b.checkBranding() + b.Security.updateProxyHeaders() + return nil +} + +func (b *Binding) isWebAdminOIDCLoginDisabled() bool { + if b.EnableWebAdmin { + return b.DisabledLoginMethods&1 != 0 + } + return false +} + +func (b *Binding) isWebClientOIDCLoginDisabled() bool { + if b.EnableWebClient { + return b.DisabledLoginMethods&2 != 0 + } + return false +} + +func (b *Binding) isWebAdminLoginFormDisabled() bool { + if b.EnableWebAdmin { + return b.DisabledLoginMethods&4 != 0 + } + return false +} + +func (b *Binding) isWebClientLoginFormDisabled() bool { + if b.EnableWebClient { + return b.DisabledLoginMethods&8 != 0 + } + return false +} + +func (b *Binding) isAdminTokenEndpointDisabled() bool { + return b.DisabledLoginMethods&16 != 0 +} + +func (b *Binding) isUserTokenEndpointDisabled() bool { + return b.DisabledLoginMethods&32 != 0 +} + +func (b *Binding) isAdminAPIKeyAuthDisabled() bool { + return b.DisabledLoginMethods&64 != 0 +} + +func (b *Binding) isUserAPIKeyAuthDisabled() bool { + return b.DisabledLoginMethods&128 != 0 +} + +func (b *Binding) hasLoginForAPI() bool { + return !b.isAdminTokenEndpointDisabled() || !b.isUserTokenEndpointDisabled() || + !b.isAdminAPIKeyAuthDisabled() || 
!b.isUserAPIKeyAuthDisabled() +} + +// convertLoginMethods checks if the deprecated EnabledLoginMethods is set and +// convert the value to DisabledLoginMethods. +func (b *Binding) convertLoginMethods() { + if b.DisabledLoginMethods > 0 || b.EnabledLoginMethods == 0 { + // DisabledLoginMethods already in use or EnabledLoginMethods not set. + return + } + if b.EnabledLoginMethods&1 == 0 { + b.DisabledLoginMethods++ + } + if b.EnabledLoginMethods&2 == 0 { + b.DisabledLoginMethods += 2 + } + if b.EnabledLoginMethods&4 == 0 { + b.DisabledLoginMethods += 4 + } + if b.EnabledLoginMethods&8 == 0 { + b.DisabledLoginMethods += 8 + } +} + +func (b *Binding) checkLoginMethods() error { + b.convertLoginMethods() + if b.isWebAdminLoginFormDisabled() && b.isWebAdminOIDCLoginDisabled() { + return errors.New("no login method available for WebAdmin UI") + } + if !b.isWebAdminOIDCLoginDisabled() { + if b.isWebAdminLoginFormDisabled() && !b.OIDC.hasRoles() { + return errors.New("no login method available for WebAdmin UI") + } + } + if b.isWebClientLoginFormDisabled() && b.isWebClientOIDCLoginDisabled() { + return errors.New("no login method available for WebClient UI") + } + if !b.isWebClientOIDCLoginDisabled() { + if b.isWebClientLoginFormDisabled() && !b.OIDC.isEnabled() { + return errors.New("no login method available for WebClient UI") + } + } + if b.EnableRESTAPI && !b.hasLoginForAPI() { + return errors.New("no login method available for REST API") + } + return nil +} + +func (b *Binding) showAdminLoginURL() bool { + if !b.EnableWebAdmin { + return false + } + if b.HideLoginURL&2 != 0 { + return false + } + return true +} + +func (b *Binding) showClientLoginURL() bool { + if !b.EnableWebClient { + return false + } + if b.HideLoginURL&1 != 0 { + return false + } + return true +} + +func (b *Binding) isMutualTLSEnabled() bool { + return b.ClientAuthType == 1 +} + +func (b *Binding) listenerWrapper() func(net.Listener) (net.Listener, error) { + if b.ProxyMode == 1 { + return 
common.Config.GetProxyListener + } + return nil +} + +type defenderStatus struct { + IsActive bool `json:"is_active"` +} + +type allowListStatus struct { + IsActive bool `json:"is_active"` +} + +type rateLimiters struct { + IsActive bool `json:"is_active"` + Protocols []string `json:"protocols"` +} + +// GetProtocolsAsString returns the enabled protocols as comma separated string +func (r *rateLimiters) GetProtocolsAsString() string { + return strings.Join(r.Protocols, ", ") +} + +// ServicesStatus keep the state of the running services +type ServicesStatus struct { + SSH sftpd.ServiceStatus `json:"ssh"` + FTP ftpd.ServiceStatus `json:"ftp"` + WebDAV webdavd.ServiceStatus `json:"webdav"` + DataProvider dataprovider.ProviderStatus `json:"data_provider"` + Defender defenderStatus `json:"defender"` + MFA mfa.ServiceStatus `json:"mfa"` + AllowList allowListStatus `json:"allow_list"` + RateLimiters rateLimiters `json:"rate_limiters"` +} + +// SetupConfig defines the configuration parameters for the initial web admin setup +type SetupConfig struct { + // Installation code to require when creating the first admin account. + // As for the other configurations, this value is read at SFTPGo startup and not at runtime + // even if set using an environment variable. 
+ // This is not a license key or similar, the purpose here is to prevent anyone who can access + // to the initial setup screen from creating an admin user + InstallationCode string `json:"installation_code" mapstructure:"installation_code"` + // Description for the installation code input field + InstallationCodeHint string `json:"installation_code_hint" mapstructure:"installation_code_hint"` +} + +// CorsConfig defines the CORS configuration +type CorsConfig struct { + AllowedOrigins []string `json:"allowed_origins" mapstructure:"allowed_origins"` + AllowedMethods []string `json:"allowed_methods" mapstructure:"allowed_methods"` + AllowedHeaders []string `json:"allowed_headers" mapstructure:"allowed_headers"` + ExposedHeaders []string `json:"exposed_headers" mapstructure:"exposed_headers"` + AllowCredentials bool `json:"allow_credentials" mapstructure:"allow_credentials"` + Enabled bool `json:"enabled" mapstructure:"enabled"` + MaxAge int `json:"max_age" mapstructure:"max_age"` + OptionsPassthrough bool `json:"options_passthrough" mapstructure:"options_passthrough"` + OptionsSuccessStatus int `json:"options_success_status" mapstructure:"options_success_status"` + AllowPrivateNetwork bool `json:"allow_private_network" mapstructure:"allow_private_network"` +} + +// Conf httpd daemon configuration +type Conf struct { + // Addresses and ports to bind to + Bindings []Binding `json:"bindings" mapstructure:"bindings"` + // Path to the HTML web templates. This can be an absolute path or a path relative to the config dir + TemplatesPath string `json:"templates_path" mapstructure:"templates_path"` + // Path to the static files for the web interface. This can be an absolute path or a path relative to the config dir. + // If both TemplatesPath and StaticFilesPath are empty the built-in web interface will be disabled + StaticFilesPath string `json:"static_files_path" mapstructure:"static_files_path"` + // Path to the backup directory. 
This can be an absolute path or a path relative to the config dir + //BackupsPath string `json:"backups_path" mapstructure:"backups_path"` + // Path to the directory that contains the OpenAPI schema and the default renderer. + // This can be an absolute path or a path relative to the config dir + OpenAPIPath string `json:"openapi_path" mapstructure:"openapi_path"` + // Defines a base URL for the web admin and client interfaces. If empty web admin and client resources will + // be available at the root ("/") URI. If defined it must be an absolute URI or it will be ignored. + WebRoot string `json:"web_root" mapstructure:"web_root"` + // If files containing a certificate and matching private key for the server are provided you can enable + // HTTPS connections for the configured bindings. + // Certificate and key files can be reloaded on demand sending a "SIGHUP" signal on Unix based systems and a + // "paramchange" request to the running service on Windows. + CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` + CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` + // CACertificates defines the set of root certificate authorities to be used to verify client certificates. + CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"` + // CARevocationLists defines a set a revocation lists, one for each root CA, to be used to check + // if a client certificate has been revoked + CARevocationLists []string `json:"ca_revocation_lists" mapstructure:"ca_revocation_lists"` + // SigningPassphrase defines the passphrase to use to derive the signing key for JWT and CSRF tokens. + // If empty a random signing key will be generated each time SFTPGo starts. 
If you set a + // signing passphrase you should consider rotating it periodically for added security + SigningPassphrase string `json:"signing_passphrase" mapstructure:"signing_passphrase"` + SigningPassphraseFile string `json:"signing_passphrase_file" mapstructure:"signing_passphrase_file"` + // TokenValidation allows to define how to validate JWT tokens, cookies and CSRF tokens. + // By default all the available security checks are enabled. Set to 1 to disable the requirement + // that a token must be used by the same IP for which it was issued. + TokenValidation int `json:"token_validation" mapstructure:"token_validation"` + // CookieLifetime defines the duration of cookies for WebAdmin and WebClient + CookieLifetime int `json:"cookie_lifetime" mapstructure:"cookie_lifetime"` + // ShareCookieLifetime defines the duration of cookies for public shares + ShareCookieLifetime int `json:"share_cookie_lifetime" mapstructure:"share_cookie_lifetime"` + // JWTLifetime defines the duration of JWT tokens used in REST API + JWTLifetime int `json:"jwt_lifetime" mapstructure:"jwt_lifetime"` + // MaxUploadFileSize defines the maximum request body size, in bytes, for Web Client/API HTTP upload requests. 
+ // 0 means no limit + MaxUploadFileSize int64 `json:"max_upload_file_size" mapstructure:"max_upload_file_size"` + // CORS configuration + Cors CorsConfig `json:"cors" mapstructure:"cors"` + // Initial setup configuration + Setup SetupConfig `json:"setup" mapstructure:"setup"` + // If enabled, the link to the sponsors section will not appear on the setup screen page + HideSupportLink bool `json:"hide_support_link" mapstructure:"hide_support_link"` + acmeDomain string +} + +type apiResponse struct { + Error string `json:"error,omitempty"` + Message string `json:"message"` +} + +// ShouldBind returns true if there is at least a valid binding +func (c *Conf) ShouldBind() bool { + for _, binding := range c.Bindings { + if binding.IsValid() { + return true + } + } + + return false +} + +func (c *Conf) isWebAdminEnabled() bool { + for _, binding := range c.Bindings { + if binding.EnableWebAdmin { + return true + } + } + return false +} + +func (c *Conf) isWebClientEnabled() bool { + for _, binding := range c.Bindings { + if binding.EnableWebClient { + return true + } + } + return false +} + +func (c *Conf) checkRequiredDirs(staticFilesPath, templatesPath string) error { + if (c.isWebAdminEnabled() || c.isWebClientEnabled()) && (staticFilesPath == "" || templatesPath == "") { + return fmt.Errorf("required directory is invalid, static file path: %q template path: %q", + staticFilesPath, templatesPath) + } + return nil +} + +func (c *Conf) getRedacted() Conf { + redacted := "[redacted]" + conf := *c + if conf.SigningPassphrase != "" { + conf.SigningPassphrase = redacted + } + if conf.Setup.InstallationCode != "" { + conf.Setup.InstallationCode = redacted + } + conf.Bindings = nil + for _, binding := range c.Bindings { + if binding.OIDC.ClientID != "" { + binding.OIDC.ClientID = redacted + } + if binding.OIDC.ClientSecret != "" { + binding.OIDC.ClientSecret = redacted + } + conf.Bindings = append(conf.Bindings, binding) + } + return conf +} + +func (c *Conf) 
getKeyPairs(configDir string) []common.TLSKeyPair { + var keyPairs []common.TLSKeyPair + + for _, binding := range c.Bindings { + certificateFile := getConfigPath(binding.CertificateFile, configDir) + certificateKeyFile := getConfigPath(binding.CertificateKeyFile, configDir) + if certificateFile != "" && certificateKeyFile != "" { + keyPairs = append(keyPairs, common.TLSKeyPair{ + Cert: certificateFile, + Key: certificateKeyFile, + ID: binding.GetAddress(), + }) + } + } + var certificateFile, certificateKeyFile string + if c.acmeDomain != "" { + certificateFile, certificateKeyFile = util.GetACMECertificateKeyPair(c.acmeDomain) + } else { + certificateFile = getConfigPath(c.CertificateFile, configDir) + certificateKeyFile = getConfigPath(c.CertificateKeyFile, configDir) + } + if certificateFile != "" && certificateKeyFile != "" { + keyPairs = append(keyPairs, common.TLSKeyPair{ + Cert: certificateFile, + Key: certificateKeyFile, + ID: common.DefaultTLSKeyPaidID, + }) + } + return keyPairs +} + +func (c *Conf) setTokenValidationMode() { + tokenValidationMode = c.TokenValidation +} + +func (c *Conf) loadFromProvider() error { + configs, err := dataprovider.GetConfigs() + if err != nil { + return fmt.Errorf("unable to load config from provider: %w", err) + } + configs.SetNilsToEmpty() + dbBrandingConfig.Set(configs.Branding) + if configs.ACME.Domain == "" || !configs.ACME.HasProtocol(common.ProtocolHTTP) { + return nil + } + crt, key := util.GetACMECertificateKeyPair(configs.ACME.Domain) + if crt != "" && key != "" { + if _, err := os.Stat(crt); err != nil { + logger.Error(logSender, "", "unable to load acme cert file %q: %v", crt, err) + return nil + } + if _, err := os.Stat(key); err != nil { + logger.Error(logSender, "", "unable to load acme key file %q: %v", key, err) + return nil + } + for idx := range c.Bindings { + if c.Bindings[idx].Security.Enabled && c.Bindings[idx].Security.HTTPSRedirect { + continue + } + c.Bindings[idx].EnableHTTPS = true + } + 
c.acmeDomain = configs.ACME.Domain + logger.Info(logSender, "", "acme domain set to %q", c.acmeDomain) + return nil + } + return nil +} + +func (c *Conf) loadTemplates(templatesPath string) { + if c.isWebAdminEnabled() { + updateWebAdminURLs(c.WebRoot) + loadAdminTemplates(templatesPath) + } else { + logger.Info(logSender, "", "built-in web admin interface disabled") + } + if c.isWebClientEnabled() { + updateWebClientURLs(c.WebRoot) + loadClientTemplates(templatesPath) + } else { + logger.Info(logSender, "", "built-in web client interface disabled") + } +} + +// Initialize configures and starts the HTTP server +func (c *Conf) Initialize(configDir string, isShared int) error { + if err := c.loadFromProvider(); err != nil { + return err + } + logger.Info(logSender, "", "initializing HTTP server with config %+v", c.getRedacted()) + configurationDir = configDir + invalidatedJWTTokens = newTokenManager(isShared) + resetCodesMgr = newResetCodeManager(isShared) + oidcMgr = newOIDCManager(isShared) + oauth2Mgr = newOAuth2Manager(isShared) + webTaskMgr = newWebTaskManager(isShared) + staticFilesPath := util.FindSharedDataPath(c.StaticFilesPath, configDir) + templatesPath := util.FindSharedDataPath(c.TemplatesPath, configDir) + openAPIPath := util.FindSharedDataPath(c.OpenAPIPath, configDir) + if err := c.checkRequiredDirs(staticFilesPath, templatesPath); err != nil { + return err + } + c.loadTemplates(templatesPath) + keyPairs := c.getKeyPairs(configDir) + if len(keyPairs) > 0 { + mgr, err := common.NewCertManager(keyPairs, configDir, logSender) + if err != nil { + return err + } + mgr.SetCACertificates(c.CACertificates) + if err := mgr.LoadRootCAs(); err != nil { + return err + } + mgr.SetCARevocationLists(c.CARevocationLists) + if err := mgr.LoadCRLs(); err != nil { + return err + } + certMgr = mgr + } + + if c.SigningPassphraseFile != "" { + passphrase, err := util.ReadConfigFromFile(c.SigningPassphraseFile, configDir) + if err != nil { + return err + } + 
c.SigningPassphrase = passphrase + } + + hideSupportLink = c.HideSupportLink + + exitChannel := make(chan error, 1) + + for _, binding := range c.Bindings { + if !binding.IsValid() { + continue + } + if err := binding.check(); err != nil { + return err + } + + go func(b Binding) { + if err := b.OIDC.initialize(); err != nil { + exitChannel <- err + return + } + if err := b.checkLoginMethods(); err != nil { + exitChannel <- err + return + } + server := newHttpdServer(b, staticFilesPath, c.SigningPassphrase, c.Cors, openAPIPath) + server.setShared(isShared) + + exitChannel <- server.listenAndServe() + }(binding) + } + + maxUploadFileSize = c.MaxUploadFileSize + installationCode = c.Setup.InstallationCode + installationCodeHint = c.Setup.InstallationCodeHint + updateTokensDuration(c.JWTLifetime, c.CookieLifetime, c.ShareCookieLifetime) + startCleanupTicker(10 * time.Minute) + c.setTokenValidationMode() + return <-exitChannel +} + +func isWebRequest(r *http.Request) bool { + return strings.HasPrefix(r.RequestURI, webBasePath+"/") +} + +func isWebClientRequest(r *http.Request) bool { + return strings.HasPrefix(r.RequestURI, webBaseClientPath+"/") +} + +// ReloadCertificateMgr reloads the certificate manager +func ReloadCertificateMgr() error { + if certMgr != nil { + return certMgr.Reload() + } + return nil +} + +func getConfigPath(name, configDir string) string { + if !util.IsFileInputValid(name) { + return "" + } + if name != "" && !filepath.IsAbs(name) { + return filepath.Join(configDir, name) + } + return name +} + +func getServicesStatus() *ServicesStatus { + rtlEnabled, rtlProtocols := common.Config.GetRateLimitersStatus() + status := &ServicesStatus{ + SSH: sftpd.GetStatus(), + FTP: ftpd.GetStatus(), + WebDAV: webdavd.GetStatus(), + DataProvider: dataprovider.GetProviderStatus(), + Defender: defenderStatus{ + IsActive: common.Config.DefenderConfig.Enabled, + }, + MFA: mfa.GetStatus(), + AllowList: allowListStatus{ + IsActive: common.Config.IsAllowListEnabled(), + 
}, + RateLimiters: rateLimiters{ + IsActive: rtlEnabled, + Protocols: rtlProtocols, + }, + } + return status +} + +func fileServer(r chi.Router, path string, root http.FileSystem, disableDirectoryIndex bool) { + if path != "/" && path[len(path)-1] != '/' { + r.Get(path, http.RedirectHandler(path+"/", http.StatusMovedPermanently).ServeHTTP) + path += "/" + } + path += "*" + + r.Get(path, func(w http.ResponseWriter, r *http.Request) { + rctx := chi.RouteContext(r.Context()) + pathPrefix := strings.TrimSuffix(rctx.RoutePattern(), "/*") + if disableDirectoryIndex { + root = neuteredFileSystem{root} + } + fs := http.StripPrefix(pathPrefix, http.FileServer(root)) + fs.ServeHTTP(w, r) + }) +} + +func updateWebClientURLs(baseURL string) { + if !path.IsAbs(baseURL) { + baseURL = "/" + } + webRootPath = path.Join(baseURL, webRootPathDefault) + webBasePath = path.Join(baseURL, webBasePathDefault) + webBaseClientPath = path.Join(baseURL, webBasePathClientDefault) + webOIDCRedirectPath = path.Join(baseURL, webOIDCRedirectPathDefault) + webClientLoginPath = path.Join(baseURL, webClientLoginPathDefault) + webClientOIDCLoginPath = path.Join(baseURL, webClientOIDCLoginPathDefault) + webClientTwoFactorPath = path.Join(baseURL, webClientTwoFactorPathDefault) + webClientTwoFactorRecoveryPath = path.Join(baseURL, webClientTwoFactorRecoveryPathDefault) + webClientFilesPath = path.Join(baseURL, webClientFilesPathDefault) + webClientFilePath = path.Join(baseURL, webClientFilePathDefault) + webClientFileActionsPath = path.Join(baseURL, webClientFileActionsPathDefault) + webClientSharesPath = path.Join(baseURL, webClientSharesPathDefault) + webClientPubSharesPath = path.Join(baseURL, webClientPubSharesPathDefault) + webClientSharePath = path.Join(baseURL, webClientSharePathDefault) + webClientEditFilePath = path.Join(baseURL, webClientEditFilePathDefault) + webClientDirsPath = path.Join(baseURL, webClientDirsPathDefault) + webClientDownloadZipPath = path.Join(baseURL, 
webClientDownloadZipPathDefault) + webClientProfilePath = path.Join(baseURL, webClientProfilePathDefault) + webClientPingPath = path.Join(baseURL, webClientPingPathDefault) + webChangeClientPwdPath = path.Join(baseURL, webChangeClientPwdPathDefault) + webClientLogoutPath = path.Join(baseURL, webClientLogoutPathDefault) + webClientMFAPath = path.Join(baseURL, webClientMFAPathDefault) + webClientTOTPGeneratePath = path.Join(baseURL, webClientTOTPGeneratePathDefault) + webClientTOTPValidatePath = path.Join(baseURL, webClientTOTPValidatePathDefault) + webClientTOTPSavePath = path.Join(baseURL, webClientTOTPSavePathDefault) + webClientRecoveryCodesPath = path.Join(baseURL, webClientRecoveryCodesPathDefault) + webClientForgotPwdPath = path.Join(baseURL, webClientForgotPwdPathDefault) + webClientResetPwdPath = path.Join(baseURL, webClientResetPwdPathDefault) + webClientViewPDFPath = path.Join(baseURL, webClientViewPDFPathDefault) + webClientGetPDFPath = path.Join(baseURL, webClientGetPDFPathDefault) + webClientExistPath = path.Join(baseURL, webClientExistPathDefault) + webClientTasksPath = path.Join(baseURL, webClientTasksPathDefault) + webStaticFilesPath = path.Join(baseURL, webStaticFilesPathDefault) + webOpenAPIPath = path.Join(baseURL, webOpenAPIPathDefault) +} + +func updateWebAdminURLs(baseURL string) { + if !path.IsAbs(baseURL) { + baseURL = "/" + } + webRootPath = path.Join(baseURL, webRootPathDefault) + webBasePath = path.Join(baseURL, webBasePathDefault) + webBaseAdminPath = path.Join(baseURL, webBasePathAdminDefault) + webOIDCRedirectPath = path.Join(baseURL, webOIDCRedirectPathDefault) + webOAuth2RedirectPath = path.Join(baseURL, webOAuth2RedirectPathDefault) + webOAuth2TokenPath = path.Join(baseURL, webOAuth2TokenPathDefault) + webAdminSetupPath = path.Join(baseURL, webAdminSetupPathDefault) + webAdminLoginPath = path.Join(baseURL, webAdminLoginPathDefault) + webAdminOIDCLoginPath = path.Join(baseURL, webAdminOIDCLoginPathDefault) + webAdminTwoFactorPath = 
path.Join(baseURL, webAdminTwoFactorPathDefault) + webAdminTwoFactorRecoveryPath = path.Join(baseURL, webAdminTwoFactorRecoveryPathDefault) + webLogoutPath = path.Join(baseURL, webLogoutPathDefault) + webUsersPath = path.Join(baseURL, webUsersPathDefault) + webUserPath = path.Join(baseURL, webUserPathDefault) + webConnectionsPath = path.Join(baseURL, webConnectionsPathDefault) + webFoldersPath = path.Join(baseURL, webFoldersPathDefault) + webFolderPath = path.Join(baseURL, webFolderPathDefault) + webGroupsPath = path.Join(baseURL, webGroupsPathDefault) + webGroupPath = path.Join(baseURL, webGroupPathDefault) + webStatusPath = path.Join(baseURL, webStatusPathDefault) + webAdminsPath = path.Join(baseURL, webAdminsPathDefault) + webAdminPath = path.Join(baseURL, webAdminPathDefault) + webMaintenancePath = path.Join(baseURL, webMaintenancePathDefault) + webBackupPath = path.Join(baseURL, webBackupPathDefault) + webRestorePath = path.Join(baseURL, webRestorePathDefault) + webScanVFolderPath = path.Join(baseURL, webScanVFolderPathDefault) + webQuotaScanPath = path.Join(baseURL, webQuotaScanPathDefault) + webChangeAdminPwdPath = path.Join(baseURL, webChangeAdminPwdPathDefault) + webAdminForgotPwdPath = path.Join(baseURL, webAdminForgotPwdPathDefault) + webAdminResetPwdPath = path.Join(baseURL, webAdminResetPwdPathDefault) + webAdminProfilePath = path.Join(baseURL, webAdminProfilePathDefault) + webAdminMFAPath = path.Join(baseURL, webAdminMFAPathDefault) + webAdminEventRulesPath = path.Join(baseURL, webAdminEventRulesPathDefault) + webAdminEventRulePath = path.Join(baseURL, webAdminEventRulePathDefault) + webAdminEventActionsPath = path.Join(baseURL, webAdminEventActionsPathDefault) + webAdminEventActionPath = path.Join(baseURL, webAdminEventActionPathDefault) + webAdminRolesPath = path.Join(baseURL, webAdminRolesPathDefault) + webAdminRolePath = path.Join(baseURL, webAdminRolePathDefault) + webAdminTOTPGeneratePath = path.Join(baseURL, webAdminTOTPGeneratePathDefault) + 
webAdminTOTPValidatePath = path.Join(baseURL, webAdminTOTPValidatePathDefault) + webAdminTOTPSavePath = path.Join(baseURL, webAdminTOTPSavePathDefault) + webAdminRecoveryCodesPath = path.Join(baseURL, webAdminRecoveryCodesPathDefault) + webTemplateUser = path.Join(baseURL, webTemplateUserDefault) + webTemplateFolder = path.Join(baseURL, webTemplateFolderDefault) + webDefenderHostsPath = path.Join(baseURL, webDefenderHostsPathDefault) + webDefenderPath = path.Join(baseURL, webDefenderPathDefault) + webIPListPath = path.Join(baseURL, webIPListPathDefault) + webIPListsPath = path.Join(baseURL, webIPListsPathDefault) + webEventsPath = path.Join(baseURL, webEventsPathDefault) + webEventsFsSearchPath = path.Join(baseURL, webEventsFsSearchPathDefault) + webEventsProviderSearchPath = path.Join(baseURL, webEventsProviderSearchPathDefault) + webEventsLogSearchPath = path.Join(baseURL, webEventsLogSearchPathDefault) + webConfigsPath = path.Join(baseURL, webConfigsPathDefault) + webStaticFilesPath = path.Join(baseURL, webStaticFilesPathDefault) + webOpenAPIPath = path.Join(baseURL, webOpenAPIPathDefault) +} + +// GetHTTPRouter returns an HTTP handler suitable to use for test cases +func GetHTTPRouter(b Binding) (http.Handler, error) { + server := newHttpdServer(b, filepath.Join("..", "..", "static"), "", CorsConfig{}, filepath.Join("..", "..", "openapi")) + if err := server.initializeRouter(); err != nil { + return nil, err + } + return server.router, nil +} + +// the ticker cannot be started/stopped from multiple goroutines +func startCleanupTicker(duration time.Duration) { + stopCleanupTicker() + cleanupTicker = time.NewTicker(duration) + cleanupDone = make(chan bool) + + go func() { + counter := int64(0) + for { + select { + case <-cleanupDone: + return + case <-cleanupTicker.C: + counter++ + invalidatedJWTTokens.Cleanup() + resetCodesMgr.Cleanup() + webTaskMgr.Cleanup() + if counter%2 == 0 { + oidcMgr.cleanup() + oauth2Mgr.cleanup() + } + } + } + }() +} + +func 
stopCleanupTicker() { + if cleanupTicker != nil { + cleanupTicker.Stop() + cleanupDone <- true + cleanupTicker = nil + } +} + +func getSigningKey(signingPassphrase string) []byte { + var key []byte + if signingPassphrase != "" { + key = []byte(signingPassphrase) + } else { + key = util.GenerateRandomBytes(32) + } + sk := sha256.Sum256(key) + return sk[:] +} + +// SetInstallationCodeResolver sets a function to call to resolve the installation code +func SetInstallationCodeResolver(fn FnInstallationCodeResolver) { + fnInstallationCodeResolver = fn +} + +func resolveInstallationCode() string { + if fnInstallationCodeResolver != nil { + return fnInstallationCodeResolver(installationCode) + } + return installationCode +} + +type neuteredFileSystem struct { + fs http.FileSystem +} + +func (nfs neuteredFileSystem) Open(name string) (http.File, error) { + f, err := nfs.fs.Open(name) + if err != nil { + return nil, err + } + + s, err := f.Stat() + if err != nil { + return nil, err + } + + if s.IsDir() { + index := path.Join(name, "index.html") + if _, err := nfs.fs.Open(index); err != nil { + defer f.Close() + + return nil, err + } + } + + return f, nil +} diff --git a/internal/httpd/httpd_test.go b/internal/httpd/httpd_test.go new file mode 100644 index 00000000..be4f8d88 --- /dev/null +++ b/internal/httpd/httpd_test.go @@ -0,0 +1,27771 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd_test + +import ( + "bytes" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "image" + "image/color" + "image/png" + "io" + "io/fs" + "math" + "mime/multipart" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path" + "path/filepath" + "regexp" + "runtime" + "slices" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/go-chi/render" + _ "github.com/go-sql-driver/mysql" + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/lithammer/shortuuid/v4" + _ "github.com/mattn/go-sqlite3" + "github.com/mhale/smtpd" + "github.com/pkg/sftp" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + "github.com/rs/xid" + "github.com/rs/zerolog" + "github.com/sftpgo/sdk" + sdkkms "github.com/sftpgo/sdk/kms" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/ssh" + "golang.org/x/net/html" + + "github.com/drakkan/sftpgo/v2/internal/acme" + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/httpd" + "github.com/drakkan/sftpgo/v2/internal/httpdtest" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + defaultUsername = "test_user" + defaultPassword = "test_password" + testPubKey = "ssh-rsa 
AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" + testPubKey1 = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCd60+/j+y8f0tLftihWV1YN9RSahMI9btQMDIMqts/jeNbD8jgoogM3nhF7KxfcaMKURuD47KC4Ey6iAJUJ0sWkSNNxOcIYuvA+5MlspfZDsa8Ag76Fe1vyz72WeHMHMeh/hwFo2TeIeIXg480T1VI6mzfDrVp2GzUx0SS0dMsQBjftXkuVR8YOiOwMCAH2a//M1OrvV7d/NBk6kBN0WnuIBb2jKm15PAA7+jQQG7tzwk2HedNH3jeL5GH31xkSRwlBczRK0xsCQXehAlx6cT/e/s44iJcJTHfpPKoSk6UAhPJYe7Z1QnuoawY9P9jQaxpyeImBZxxUEowhjpj2avBxKdRGBVK8R7EL8tSOeLbhdyWe5Mwc1+foEbq9Zz5j5Kd+hn3Wm1UnsGCrXUUUoZp1jnlNl0NakCto+5KmqnT9cHxaY+ix2RLUWAZyVFlRq71OYux1UHJnEJPiEI1/tr4jFBSL46qhQZv/TfpkfVW8FLz0lErfqu0gQEZnNHr3Fc= nicola@p1" + defaultTokenAuthUser = "admin" + defaultTokenAuthPass = "password" + altAdminUsername = "newTestAdmin" + altAdminPassword = "password1" + csrfFormToken = "_form_token" + tokenPath = "/api/v2/token" + userTokenPath = "/api/v2/user/token" + userLogoutPath = "/api/v2/user/logout" + userPath = "/api/v2/users" + adminPath = "/api/v2/admins" + adminPwdPath = "/api/v2/admin/changepwd" + folderPath = "/api/v2/folders" + groupPath = "/api/v2/groups" + activeConnectionsPath = "/api/v2/connections" + serverStatusPath = "/api/v2/status" + quotasBasePath = "/api/v2/quotas" + quotaScanPath = "/api/v2/quotas/users/scans" + quotaScanVFolderPath = "/api/v2/quotas/folders/scans" + defenderHosts = "/api/v2/defender/hosts" + versionPath = "/api/v2/version" + logoutPath = "/api/v2/logout" + userPwdPath = "/api/v2/user/changepwd" + userDirsPath = "/api/v2/user/dirs" + 
userFilesPath = "/api/v2/user/files" + userFileActionsPath = "/api/v2/user/file-actions" + userStreamZipPath = "/api/v2/user/streamzip" + userUploadFilePath = "/api/v2/user/files/upload" + userFilesDirsMetadataPath = "/api/v2/user/files/metadata" + apiKeysPath = "/api/v2/apikeys" + adminTOTPConfigsPath = "/api/v2/admin/totp/configs" + adminTOTPGeneratePath = "/api/v2/admin/totp/generate" + adminTOTPValidatePath = "/api/v2/admin/totp/validate" + adminTOTPSavePath = "/api/v2/admin/totp/save" + admin2FARecoveryCodesPath = "/api/v2/admin/2fa/recoverycodes" + adminProfilePath = "/api/v2/admin/profile" + userTOTPConfigsPath = "/api/v2/user/totp/configs" + userTOTPGeneratePath = "/api/v2/user/totp/generate" + userTOTPValidatePath = "/api/v2/user/totp/validate" + userTOTPSavePath = "/api/v2/user/totp/save" + user2FARecoveryCodesPath = "/api/v2/user/2fa/recoverycodes" + userProfilePath = "/api/v2/user/profile" + userSharesPath = "/api/v2/user/shares" + fsEventsPath = "/api/v2/events/fs" + providerEventsPath = "/api/v2/events/provider" + logEventsPath = "/api/v2/events/logs" + sharesPath = "/api/v2/shares" + eventActionsPath = "/api/v2/eventactions" + eventRulesPath = "/api/v2/eventrules" + rolesPath = "/api/v2/roles" + ipListsPath = "/api/v2/iplists" + healthzPath = "/healthz" + webBasePath = "/web" + webBasePathAdmin = "/web/admin" + webAdminSetupPath = "/web/admin/setup" + webLoginPath = "/web/admin/login" + webLogoutPath = "/web/admin/logout" + webUsersPath = "/web/admin/users" + webUserPath = "/web/admin/user" + webGroupsPath = "/web/admin/groups" + webGroupPath = "/web/admin/group" + webFoldersPath = "/web/admin/folders" + webFolderPath = "/web/admin/folder" + webConnectionsPath = "/web/admin/connections" + webStatusPath = "/web/admin/status" + webAdminsPath = "/web/admin/managers" + webAdminPath = "/web/admin/manager" + webMaintenancePath = "/web/admin/maintenance" + webRestorePath = "/web/admin/restore" + webChangeAdminPwdPath = "/web/admin/changepwd" + 
webAdminProfilePath = "/web/admin/profile" + webTemplateUser = "/web/admin/template/user" + webTemplateFolder = "/web/admin/template/folder" + webDefenderPath = "/web/admin/defender" + webIPListsPath = "/web/admin/ip-lists" + webIPListPath = "/web/admin/ip-list" + webAdminTwoFactorPath = "/web/admin/twofactor" + webAdminTwoFactorRecoveryPath = "/web/admin/twofactor-recovery" + webAdminMFAPath = "/web/admin/mfa" + webAdminTOTPSavePath = "/web/admin/totp/save" + webAdminForgotPwdPath = "/web/admin/forgot-password" + webAdminResetPwdPath = "/web/admin/reset-password" + webAdminEventRulesPath = "/web/admin/eventrules" + webAdminEventRulePath = "/web/admin/eventrule" + webAdminEventActionsPath = "/web/admin/eventactions" + webAdminEventActionPath = "/web/admin/eventaction" + webAdminRolesPath = "/web/admin/roles" + webAdminRolePath = "/web/admin/role" + webEventsPath = "/web/admin/events" + webConfigsPath = "/web/admin/configs" + webOAuth2TokenPath = "/web/admin/oauth2/token" + webBasePathClient = "/web/client" + webClientLoginPath = "/web/client/login" + webClientFilesPath = "/web/client/files" + webClientEditFilePath = "/web/client/editfile" + webClientDirsPath = "/web/client/dirs" + webClientDownloadZipPath = "/web/client/downloadzip" + webChangeClientPwdPath = "/web/client/changepwd" + webClientProfilePath = "/web/client/profile" + webClientPingPath = "/web/client/ping" + webClientTwoFactorPath = "/web/client/twofactor" + webClientTwoFactorRecoveryPath = "/web/client/twofactor-recovery" + webClientLogoutPath = "/web/client/logout" + webClientMFAPath = "/web/client/mfa" + webClientTOTPSavePath = "/web/client/totp/save" + webClientSharesPath = "/web/client/shares" + webClientSharePath = "/web/client/share" + webClientPubSharesPath = "/web/client/pubshares" + webClientForgotPwdPath = "/web/client/forgot-password" + webClientResetPwdPath = "/web/client/reset-password" + webClientViewPDFPath = "/web/client/viewpdf" + webClientGetPDFPath = "/web/client/getpdf" + 
webClientExistPath = "/web/client/exist" + webClientTasksPath = "/web/client/tasks" + webClientFileMovePath = "/web/client/file-actions/move" + webClientFileCopyPath = "/web/client/file-actions/copy" + jsonAPISuffix = "/json" + httpBaseURL = "http://127.0.0.1:8081" + defaultRemoteAddr = "127.0.0.1:1234" + sftpServerAddr = "127.0.0.1:8022" + smtpServerAddr = "127.0.0.1:3525" + httpsCert = `-----BEGIN CERTIFICATE----- +MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw +RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu +dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw +OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD +VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA +IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA +NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM +3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME +GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG +SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY +/8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI +dV4vKmHUzwK/eIx+8Ay3neE= +-----END CERTIFICATE-----` + httpsKey = `-----BEGIN EC PARAMETERS----- +BgUrgQQAIg== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 +UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq +WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV +CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= +-----END EC PRIVATE KEY-----` + sftpPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACB+RB4yNTZz9mHOkawwUibNdemijVV3ErMeLxWUBlCN/gAAAJA7DjpfOw46 +XwAAAAtzc2gtZWQyNTUxOQAAACB+RB4yNTZz9mHOkawwUibNdemijVV3ErMeLxWUBlCN/g +AAAEA0E24gi8ab/XRSvJ85TGZJMe6HVmwxSG4ExPfTMwwe2n5EHjI1NnP2Yc6RrDBSJs11 +6aKNVXcSsx4vFZQGUI3+AAAACW5pY29sYUBwMQECAwQ= +-----END OPENSSH PRIVATE 
KEY-----` + sftpPkeyFingerprint = "SHA256:QVQ06XHZZbYZzqfrsZcf3Yozy2WTnqQPeLOkcJCdbP0" + // password protected private key + testPrivateKeyPwd = `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABAvfwQQcs ++PyMsCLTNFcKiQAAAAEAAAAAEAAAAzAAAAC3NzaC1lZDI1NTE5AAAAILqltfCL7IPuIQ2q ++8w23flfgskjIlKViEwMfjJR4mrbAAAAkHp5xgG8J1XW90M/fT59ZUQht8sZzzP17rEKlX +waYKvLzDxkPK6LFIYs55W1EX1eVt/2Maq+zQ7k2SOUmhPNknsUOlPV2gytX3uIYvXF7u2F +FTBIJuzZ+UQ14wFbraunliE9yye9DajVG1kz2cz2wVgXUbee+gp5NyFVvln+TcTxXwMsWD +qwlk5iw/jQekxThg== +-----END OPENSSH PRIVATE KEY----- +` + testPubKeyPwd = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILqltfCL7IPuIQ2q+8w23flfgskjIlKViEwMfjJR4mrb" + privateKeyPwd = "password" + rsa1024PrivKey = `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn +NhAAAAAwEAAQAAAIEAxgrZ84gJyU7Qz8JbYuYh0fgTN29h4qVkqDkEE0lWZe7L4QRcQHrB +vycJO5vjfitY5JTojV3nbDNHN6XGVX8QNurwXmxv0EmEbqPoNO/rTf1t7qqwMBBAfSJJ5H +TXsO37vqcWSOt1Ki5yjRm232UfPo3AYXaZdOKDWKpzI12FfqkAAAIAondFqKJ3RagAAAAH +c3NoLXJzYQAAAIEAxgrZ84gJyU7Qz8JbYuYh0fgTN29h4qVkqDkEE0lWZe7L4QRcQHrBvy +cJO5vjfitY5JTojV3nbDNHN6XGVX8QNurwXmxv0EmEbqPoNO/rTf1t7qqwMBBAfSJJ5HTX +sO37vqcWSOt1Ki5yjRm232UfPo3AYXaZdOKDWKpzI12FfqkAAAADAQABAAAAgC7V5COG+a +GFJTbtJQWnnTn17D2A9upN6RcrnL4e6vLiXY8So+qP3YAicDmLrWpqP/SXDsRX/+ID4oTT +jKstiJy5jTvXAozwBbFCvNDk1qifs8p/HKzel3t0172j6gLOa2h9+clJ4BYyCk6ue4f8fV +yKTIc9chdJSpeINNY60CJxAAAAQQDhYpGXljD2Xy/CzqRXyoF+iMtOImLlbgQYswTXegk3 +7JoCNvwqg8xP+JxGpvUGpX23VWh0nBhzcAKHGlssiYQuAAAAQQDwB6s7s1WIRZ2Jsz8f6l +7/ebpPrAMyKmWkXc7KyvR53zuMkMIdvujM5NkOWh1ON8jtNumArey2dWuGVh+pXbdVAAAA +QQDTOAaMcyTfXMH/oSMsp+5obvT/RuewaRLHdBiCy0y1Jw0ykOcOCkswr/btDL26hImaHF +SheorO+2We7dnFuUIFAAAACW5pY29sYUBwMQE= +-----END OPENSSH PRIVATE KEY-----` + rsa1024PubKey = "ssh-rsa 
AAAAB3NzaC1yc2EAAAADAQABAAAAgQDGCtnziAnJTtDPwlti5iHR+BM3b2HipWSoOQQTSVZl7svhBFxAesG/Jwk7m+N+K1jklOiNXedsM0c3pcZVfxA26vBebG/QSYRuo+g07+tN/W3uqrAwEEB9IknkdNew7fu+pxZI63UqLnKNGbbfZR8+jcBhdpl04oNYqnMjXYV+qQ==" + redactedSecret = "[**redacted**]" + osWindows = "windows" + oidcMockAddr = "127.0.0.1:11111" +) + +var ( + configDir = filepath.Join(".", "..", "..") + defaultPerms = []string{dataprovider.PermAny} + homeBasePath string + backupsPath string + testServer *httptest.Server + postConnectPath string + preActionPath string + lastResetCode string +) + +type fakeConnection struct { + *common.BaseConnection + command string +} + +func (c *fakeConnection) Disconnect() error { + common.Connections.Remove(c.GetID()) + return nil +} + +func (c *fakeConnection) GetClientVersion() string { + return "" +} + +func (c *fakeConnection) GetCommand() string { + return c.command +} + +func (c *fakeConnection) GetLocalAddress() string { + return "" +} + +func (c *fakeConnection) GetRemoteAddress() string { + return "" +} + +type generateTOTPRequest struct { + ConfigName string `json:"config_name"` +} + +type generateTOTPResponse struct { + ConfigName string `json:"config_name"` + Issuer string `json:"issuer"` + Secret string `json:"secret"` + QRCode []byte `json:"qr_code"` +} + +type validateTOTPRequest struct { + ConfigName string `json:"config_name"` + Passcode string `json:"passcode"` + Secret string `json:"secret"` +} + +type recoveryCode struct { + Code string `json:"code"` + Used bool `json:"used"` +} + +func TestMain(m *testing.M) { //nolint:gocyclo + homeBasePath = os.TempDir() + logfilePath := filepath.Join(configDir, "sftpgo_api_test.log") + logger.InitLogger(logfilePath, 5, 1, 28, false, false, zerolog.DebugLevel) + os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2") + os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") + os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1") + os.Setenv("SFTPGO_DATA_PROVIDER__NAMING_RULES", "0") + 
os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin") + os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password") + os.Setenv("SFTPGO_HTTPD__MAX_UPLOAD_FILE_SIZE", "1048576000") + err := config.LoadConfig(configDir, "") + if err != nil { + logger.WarnToConsole("error loading configuration: %v", err) + os.Exit(1) + } + wdPath, err := os.Getwd() + if err != nil { + logger.WarnToConsole("error getting exe path: %v", err) + os.Exit(1) + } + pluginsConfig := []plugin.Config{ + { + Type: "eventsearcher", + Cmd: filepath.Join(wdPath, "..", "..", "tests", "eventsearcher", "eventsearcher"), + AutoMTLS: true, + }, + } + if runtime.GOOS == osWindows { + pluginsConfig[0].Cmd += ".exe" + } + providerConf := config.GetProviderConf() + logger.InfoToConsole("Starting HTTPD tests, provider: %v", providerConf.Driver) + + backupsPath = filepath.Join(os.TempDir(), "test_backups") + providerConf.BackupsPath = backupsPath + err = os.MkdirAll(backupsPath, os.ModePerm) + if err != nil { + logger.ErrorToConsole("error creating backups path: %v", err) + os.Exit(1) + } + + kmsConfig := config.GetKMSConfig() + err = kmsConfig.Initialize() + if err != nil { + logger.ErrorToConsole("error initializing kms: %v", err) + os.Exit(1) + } + err = plugin.Initialize(pluginsConfig, "debug") + if err != nil { + logger.ErrorToConsole("error initializing plugin: %v", err) + os.Exit(1) + } + mfaConfig := config.GetMFAConfig() + err = mfaConfig.Initialize() + if err != nil { + logger.ErrorToConsole("error initializing MFA: %v", err) + os.Exit(1) + } + + err = dataprovider.Initialize(providerConf, configDir, true) + if err != nil { + logger.WarnToConsole("error initializing data provider: %v", err) + os.Exit(1) + } + + err = common.Initialize(config.GetCommonConfig(), 0) + if err != nil { + logger.WarnToConsole("error initializing common: %v", err) + os.Exit(1) + } + + postConnectPath = filepath.Join(homeBasePath, "postconnect.sh") + preActionPath = filepath.Join(homeBasePath, "preaction.sh") + + httpConfig := 
config.GetHTTPConfig() + httpConfig.RetryMax = 1 + httpConfig.Timeout = 5 + httpConfig.Initialize(configDir) //nolint:errcheck + + httpdConf := config.GetHTTPDConfig() + + httpdConf.Bindings[0].Port = 8081 + httpdConf.Bindings[0].Security = httpd.SecurityConf{ + Enabled: true, + HTTPSProxyHeaders: []httpd.HTTPSProxyHeader{ + { + Key: "X-Forwarded-Proto", + Value: "https", + }, + }, + CacheControl: "private", + } + httpdtest.SetBaseURL(httpBaseURL) + // required to test sftpfs + sftpdConf := config.GetSFTPDConfig() + sftpdConf.Bindings = []sftpd.Binding{ + { + Port: 8022, + }, + } + hostKeyPath := filepath.Join(os.TempDir(), "id_rsa") + sftpdConf.HostKeys = []string{hostKeyPath} + + go func() { + if err := httpdConf.Initialize(configDir, 0); err != nil { + logger.ErrorToConsole("could not start HTTP server: %v", err) + os.Exit(1) + } + }() + + go func() { + if err := sftpdConf.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start SFTP server: %v", err) + os.Exit(1) + } + }() + + startSMTPServer() + startOIDCMockServer() + + waitTCPListening(httpdConf.Bindings[0].GetAddress()) + waitTCPListening(sftpdConf.Bindings[0].GetAddress()) + httpd.ReloadCertificateMgr() //nolint:errcheck + // now start an https server + certPath := filepath.Join(os.TempDir(), "test.crt") + keyPath := filepath.Join(os.TempDir(), "test.key") + err = os.WriteFile(certPath, []byte(httpsCert), os.ModePerm) + if err != nil { + logger.ErrorToConsole("error writing HTTPS certificate: %v", err) + os.Exit(1) + } + err = os.WriteFile(keyPath, []byte(httpsKey), os.ModePerm) + if err != nil { + logger.ErrorToConsole("error writing HTTPS private key: %v", err) + os.Exit(1) + } + httpdConf.Bindings[0].Port = 8443 + httpdConf.Bindings[0].EnableHTTPS = true + httpdConf.Bindings[0].CertificateFile = certPath + httpdConf.Bindings[0].CertificateKeyFile = keyPath + httpdConf.Bindings = append(httpdConf.Bindings, httpd.Binding{}) + + go func() { + if err := httpdConf.Initialize(configDir, 
0); err != nil { + logger.ErrorToConsole("could not start HTTPS server: %v", err) + os.Exit(1) + } + }() + waitTCPListening(httpdConf.Bindings[0].GetAddress()) + httpd.ReloadCertificateMgr() //nolint:errcheck + + handler, err := httpd.GetHTTPRouter(httpdConf.Bindings[0]) + if err != nil { + logger.ErrorToConsole("unable to get http test handler: %v", err) + os.Exit(1) + } + testServer = httptest.NewServer(handler) + defer testServer.Close() + + exitCode := m.Run() + os.Remove(logfilePath) + os.RemoveAll(backupsPath) + os.Remove(certPath) + os.Remove(keyPath) + os.Remove(hostKeyPath) + os.Remove(hostKeyPath + ".pub") + os.Remove(postConnectPath) + os.Remove(preActionPath) + os.Exit(exitCode) +} + +func TestInitialization(t *testing.T) { + isShared := 0 + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + invalidFile := "invalid file" + passphraseFile := filepath.Join(os.TempDir(), util.GenerateUniqueID()) + err = os.WriteFile(passphraseFile, []byte("my secret"), 0600) + assert.NoError(t, err) + defer os.Remove(passphraseFile) + httpdConf := config.GetHTTPDConfig() + httpdConf.SigningPassphraseFile = invalidFile + err = httpdConf.Initialize(configDir, isShared) + assert.ErrorIs(t, err, fs.ErrNotExist) + httpdConf.SigningPassphraseFile = passphraseFile + defaultTemplatesPath := httpdConf.TemplatesPath + defaultStaticPath := httpdConf.StaticFilesPath + httpdConf.CertificateFile = invalidFile + httpdConf.CertificateKeyFile = invalidFile + err = httpdConf.Initialize(configDir, isShared) + assert.Error(t, err) + httpdConf.CertificateFile = "" + httpdConf.CertificateKeyFile = "" + httpdConf.TemplatesPath = "." 
+ err = httpdConf.Initialize(configDir, isShared) + assert.Error(t, err) + httpdConf = config.GetHTTPDConfig() + httpdConf.TemplatesPath = defaultTemplatesPath + httpdConf.CertificateFile = invalidFile + httpdConf.CertificateKeyFile = invalidFile + httpdConf.StaticFilesPath = "" + httpdConf.TemplatesPath = "" + err = httpdConf.Initialize(configDir, isShared) + assert.Error(t, err) + httpdConf.StaticFilesPath = defaultStaticPath + httpdConf.TemplatesPath = defaultTemplatesPath + httpdConf.CertificateFile = filepath.Join(os.TempDir(), "test.crt") + httpdConf.CertificateKeyFile = filepath.Join(os.TempDir(), "test.key") + httpdConf.CACertificates = append(httpdConf.CACertificates, invalidFile) + err = httpdConf.Initialize(configDir, isShared) + assert.Error(t, err) + httpdConf.CACertificates = nil + httpdConf.CARevocationLists = append(httpdConf.CARevocationLists, invalidFile) + err = httpdConf.Initialize(configDir, isShared) + assert.Error(t, err) + httpdConf.CARevocationLists = nil + httpdConf.SigningPassphraseFile = passphraseFile + httpdConf.Bindings[0].ProxyAllowed = []string{"invalid ip/network"} + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is not a valid IP range") + } + assert.Equal(t, "my secret", httpdConf.SigningPassphrase) + httpdConf.Bindings[0].ProxyAllowed = nil + httpdConf.Bindings[0].EnableWebAdmin = false + httpdConf.Bindings[0].EnableWebClient = false + httpdConf.Bindings[0].Port = 8081 + httpdConf.Bindings[0].EnableHTTPS = true + httpdConf.Bindings[0].ClientAuthType = 1 + httpdConf.TokenValidation = 1 + err = httpdConf.Initialize(configDir, 0) + assert.Error(t, err) + httpdConf.TokenValidation = 0 + err = httpdConf.Initialize(configDir, 0) + assert.Error(t, err) + + httpdConf.Bindings[0].OIDC = httpd.OIDC{ + ClientID: "123", + ClientSecret: "secret", + ConfigURL: "http://127.0.0.1:11111", + } + err = httpdConf.Initialize(configDir, 0) + if assert.Error(t, err) { + assert.Contains(t, 
err.Error(), "oidc") + } + httpdConf.Bindings[0].OIDC.UsernameField = "preferred_username" + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "oidc") + } + httpdConf.Bindings[0].OIDC = httpd.OIDC{} + httpdConf.Bindings[0].BaseURL = "ftp://127.0.0.1" + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "URL schema") + } + httpdConf.Bindings[0].BaseURL = "" + httpdConf.Bindings[0].EnableWebClient = true + httpdConf.Bindings[0].EnableWebAdmin = true + httpdConf.Bindings[0].EnableRESTAPI = true + httpdConf.Bindings[0].DisabledLoginMethods = 14 + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no login method available for WebAdmin UI") + } + httpdConf.Bindings[0].DisabledLoginMethods = 13 + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no login method available for WebAdmin UI") + } + httpdConf.Bindings[0].DisabledLoginMethods = 9 + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no login method available for WebClient UI") + } + httpdConf.Bindings[0].DisabledLoginMethods = 11 + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no login method available for WebClient UI") + } + httpdConf.Bindings[0].DisabledLoginMethods = 12 + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no login method available for WebAdmin UI") + } + httpdConf.Bindings[0].EnableWebAdmin = false + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no login method available for WebClient UI") + } + httpdConf.Bindings[0].EnableWebClient = false + httpdConf.Bindings[0].DisabledLoginMethods = 240 + err = 
httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "no login method available for REST API") + } + err = dataprovider.Close() + assert.NoError(t, err) + err = httpdConf.Initialize(configDir, isShared) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to load config from provider") + } + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestBasicUserHandling(t *testing.T) { + u := getTestUser() + u.Email = "user@user.com" + u.Filters.AdditionalEmails = []string{"email1@user.com", "email2@user.com"} + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + _, resp, err = httpdtest.AddUser(u, http.StatusConflict) + assert.NoError(t, err, string(resp)) + lastPwdChange := user.LastPasswordChange + assert.Greater(t, lastPwdChange, int64(0)) + user.MaxSessions = 10 + user.QuotaSize = 4096 + user.QuotaFiles = 2 + user.UploadBandwidth = 128 + user.DownloadBandwidth = 64 + user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now()) + user.AdditionalInfo = "some free text" + user.Filters.TLSUsername = sdk.TLSUsernameCN + user.Email = "user@example.net" + user.OIDCCustomFields = &map[string]any{ + "field1": "value1", + } + user.Filters.WebClient = append(user.Filters.WebClient, sdk.WebClientPubKeyChangeDisabled, + sdk.WebClientWriteDisabled) + originalUser := user + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Equal(t, originalUser.ID, user.ID) + assert.Equal(t, lastPwdChange, user.LastPasswordChange) + + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Nil(t, user.OIDCCustomFields) + assert.True(t, user.HasPassword) + + user.Email = "invalid@email" + user.FsConfig.OSConfig = sdk.OSFsConfig{ + 
ReadBufferSize: 1, + WriteBufferSize: 2, + } + _, body, err := httpdtest.UpdateUser(user, http.StatusBadRequest, "") + assert.NoError(t, err) + assert.Contains(t, string(body), "Validation error: email") + + user.Email = "" + user.Filters.AdditionalEmails = []string{"invalid@email"} + _, body, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "") + assert.NoError(t, err) + assert.Contains(t, string(body), "Validation error: email") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestBasicRoleHandling(t *testing.T) { + r := getTestRole() + role, resp, err := httpdtest.AddRole(r, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Greater(t, role.CreatedAt, int64(0)) + assert.Greater(t, role.UpdatedAt, int64(0)) + roleGet, _, err := httpdtest.GetRoleByName(role.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, role, roleGet) + + roles, _, err := httpdtest.GetRoles(0, 0, http.StatusOK) + assert.NoError(t, err) + if assert.GreaterOrEqual(t, len(roles), 1) { + found := false + for _, ro := range roles { + if ro.Name == r.Name { + assert.Equal(t, role, ro) + found = true + } + } + assert.True(t, found) + } + roles, _, err = httpdtest.GetRoles(0, int64(len(roles)), http.StatusOK) + assert.NoError(t, err) + assert.Len(t, roles, 0) + + role.Description = "updated desc" + _, _, err = httpdtest.UpdateRole(role, http.StatusOK) + assert.NoError(t, err) + roleGet, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, role.Description, roleGet.Description) + + _, _, err = httpdtest.GetRoleByName(role.Name+"_", http.StatusNotFound) + assert.NoError(t, err) + // adding the same role again should fail + _, _, err = httpdtest.AddRole(r, http.StatusConflict) + assert.NoError(t, err) + + _, err = httpdtest.RemoveRole(role, http.StatusOK) + assert.NoError(t, err) +} + +func TestRoleRelations(t *testing.T) { + r := getTestRole() + role, resp, err := 
httpdtest.AddRole(r, http.StatusCreated) + assert.NoError(t, err, string(resp)) + a := getTestAdmin() + a.Username = altAdminUsername + a.Password = altAdminPassword + a.Role = role.Name + _, resp, err = httpdtest.AddAdmin(a, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "a role admin cannot be a super admin") + + a.Permissions = []string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers, + dataprovider.PermAdminDeleteUsers, dataprovider.PermAdminViewUsers} + a.Role = "missing admin role" + _, _, err = httpdtest.AddAdmin(a, http.StatusConflict) + assert.NoError(t, err) + a.Role = role.Name + admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + admin.Role = "invalid role" + _, resp, err = httpdtest.UpdateAdmin(admin, http.StatusConflict) + assert.NoError(t, err, string(resp)) + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, role.Name, admin.Role) + + resp, err = httpdtest.RemoveRole(role, http.StatusOK) + assert.Error(t, err, "removing a referenced role should fail") + assert.Contains(t, string(resp), "is referenced") + + role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, role.Admins, 1) { + assert.Equal(t, admin.Username, role.Admins[0]) + } + + u1 := getTestUser() + u1.Username = defaultUsername + "1" + u1.Role = "missing role" + _, _, err = httpdtest.AddUser(u1, http.StatusConflict) + assert.NoError(t, err) + u1.Role = role.Name + user1, _, err := httpdtest.AddUser(u1, http.StatusCreated) + assert.NoError(t, err) + assert.Equal(t, role.Name, user1.Role) + user1.Role = "missing" + _, _, err = httpdtest.UpdateUser(user1, http.StatusConflict, "") + assert.NoError(t, err) + user1, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, role.Name, user1.Role) + + role, _, err = 
httpdtest.GetRoleByName(role.Name, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, role.Admins, 1) { + assert.Equal(t, admin.Username, role.Admins[0]) + } + if assert.Len(t, role.Users, 1) { + assert.Equal(t, user1.Username, role.Users[0]) + } + roles, _, err := httpdtest.GetRoles(0, 0, http.StatusOK) + assert.NoError(t, err) + for _, r := range roles { + if r.Name == role.Name { + if assert.Len(t, role.Admins, 1) { + assert.Equal(t, admin.Username, role.Admins[0]) + } + if assert.Len(t, role.Users, 1) { + assert.Equal(t, user1.Username, role.Users[0]) + } + } + } + + u2 := getTestUser() + user2, _, err := httpdtest.AddUser(u2, http.StatusCreated) + assert.NoError(t, err) + + // the global admin can list all users + users, _, err := httpdtest.GetUsers(0, 0, http.StatusOK) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(users), 2) + _, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) + assert.NoError(t, err) + _, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK) + assert.NoError(t, err) + // the role admin can only list users with its role + token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword) + assert.NoError(t, err) + httpdtest.SetJWTToken(token) + users, _, err = httpdtest.GetUsers(0, 0, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, users, 1) + _, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) + assert.NoError(t, err) + _, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusNotFound) + assert.NoError(t, err) + // the role admin can only update/delete users with its role + _, _, err = httpdtest.UpdateUser(user1, http.StatusOK, "") + assert.NoError(t, err) + _, _, err = httpdtest.UpdateUser(user2, http.StatusNotFound, "") + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user2, http.StatusNotFound) + assert.NoError(t, err) + // new users created by a role admin have the same role + u3 := getTestUser() + u3.Username = defaultUsername + "3" + 
_, _, err = httpdtest.AddUser(u3, http.StatusCreated) + if assert.Error(t, err) { + assert.Equal(t, err.Error(), "role mismatch") + } + + user3, _, err := httpdtest.GetUserByUsername(u3.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, role.Name, user3.Role) + + _, err = httpdtest.RemoveUser(user3, http.StatusOK) + assert.NoError(t, err) + + httpdtest.SetJWTToken("") + + role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, role.Admins, []string{altAdminUsername}) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveRole(role, http.StatusOK) + assert.NoError(t, err) + user1, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) + assert.NoError(t, err) + assert.Empty(t, user1.Role) + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user2, http.StatusOK) + assert.NoError(t, err) +} + +func TestRSAKeyInvalidSize(t *testing.T) { + u := getTestUser() + u.PublicKeys = append(u.PublicKeys, rsa1024PubKey) + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "invalid size") + u = getTestSFTPUser() + u.FsConfig.SFTPConfig.Password = nil + u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(rsa1024PrivKey) + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "rsa key with size 1024 not accepted") +} + +func TestTLSCert(t *testing.T) { + u := getTestUser() + u.Filters.TLSCerts = []string{"not a cert"} + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "invalid TLS certificate") + + u.Filters.TLSCerts = []string{httpsCert} + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + if 
assert.Len(t, user.Filters.TLSCerts, 1) { + assert.Equal(t, httpsCert, user.Filters.TLSCerts[0]) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestSortRelatedFolders(t *testing.T) { + folder1 := util.GenerateUniqueID() + folder2 := util.GenerateUniqueID() + folder3 := util.GenerateUniqueID() + + f1 := vfs.BaseVirtualFolder{ + Name: folder1, + MappedPath: filepath.Clean(os.TempDir()), + } + f2 := vfs.BaseVirtualFolder{ + Name: folder2, + MappedPath: filepath.Clean(os.TempDir()), + } + f3 := vfs.BaseVirtualFolder{ + Name: folder3, + MappedPath: filepath.Clean(os.TempDir()), + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + u.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: f1, + VirtualPath: "/" + folder1, + }, + { + BaseVirtualFolder: f2, + VirtualPath: "/" + folder2, + }, + { + BaseVirtualFolder: f3, + VirtualPath: "/" + folder3, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.VirtualFolders, 3) { + assert.Equal(t, folder1, user.VirtualFolders[0].Name) + assert.Equal(t, folder2, user.VirtualFolders[1].Name) + assert.Equal(t, folder3, user.VirtualFolders[2].Name) + } + // Update + user.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: f2, + VirtualPath: "/" + folder2, + }, + { + BaseVirtualFolder: f1, + VirtualPath: "/" + folder1, + }, + { + BaseVirtualFolder: f3, + VirtualPath: "/" + folder3, + }, + } + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if 
assert.Len(t, user.VirtualFolders, 3) { + assert.Equal(t, folder2, user.VirtualFolders[0].Name) + assert.Equal(t, folder1, user.VirtualFolders[1].Name) + assert.Equal(t, folder3, user.VirtualFolders[2].Name) + } + + g := getTestGroup() + g.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: f1, + VirtualPath: "/" + folder1, + }, + { + BaseVirtualFolder: f2, + VirtualPath: "/" + folder2, + }, + { + BaseVirtualFolder: f3, + VirtualPath: "/" + folder3, + }, + } + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, group.VirtualFolders, 3) { + assert.Equal(t, folder1, group.VirtualFolders[0].Name) + assert.Equal(t, folder2, group.VirtualFolders[1].Name) + assert.Equal(t, folder3, group.VirtualFolders[2].Name) + } + group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + group.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: f3, + VirtualPath: "/" + folder3, + }, + { + BaseVirtualFolder: f1, + VirtualPath: "/" + folder1, + }, + { + BaseVirtualFolder: f2, + VirtualPath: "/" + folder2, + }, + } + group, _, err = httpdtest.UpdateGroup(group, http.StatusOK) + assert.NoError(t, err) + group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, group.VirtualFolders, 3) { + assert.Equal(t, folder3, group.VirtualFolders[0].Name) + assert.Equal(t, folder1, group.VirtualFolders[1].Name) + assert.Equal(t, folder2, group.VirtualFolders[2].Name) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) + + _, err = httpdtest.RemoveFolder(f1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(f2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(f3, http.StatusOK) + 
assert.NoError(t, err) +} + +func TestSortRelatedGroups(t *testing.T) { + name1 := util.GenerateUniqueID() + name2 := util.GenerateUniqueID() + name3 := util.GenerateUniqueID() + + g1 := getTestGroup() + g1.Name = name1 + g2 := getTestGroup() + g2.Name = name2 + g3 := getTestGroup() + g3.Name = name3 + + group1, _, err := httpdtest.AddGroup(g1, http.StatusCreated) + assert.NoError(t, err) + group2, _, err := httpdtest.AddGroup(g2, http.StatusCreated) + assert.NoError(t, err) + group3, _, err := httpdtest.AddGroup(g3, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: name1, + }, + } + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.Groups = []sdk.GroupMapping{ + { + Name: name1, + Type: sdk.GroupTypePrimary, + }, + { + Name: name2, + Type: sdk.GroupTypeSecondary, + }, + { + Name: name3, + Type: sdk.GroupTypeMembership, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.Groups, 3) { + assert.Equal(t, name1, user.Groups[0].Name) + assert.Equal(t, name2, user.Groups[1].Name) + assert.Equal(t, name3, user.Groups[2].Name) + } + user.Groups = []sdk.GroupMapping{ + { + Name: name2, + Type: sdk.GroupTypeSecondary, + }, + { + Name: name3, + Type: sdk.GroupTypeMembership, + }, + { + Name: name1, + Type: sdk.GroupTypePrimary, + }, + } + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.Groups, 3) { + assert.Equal(t, name2, user.Groups[0].Name) + assert.Equal(t, name3, user.Groups[1].Name) + assert.Equal(t, name1, user.Groups[2].Name) + } + + a := getTestAdmin() + a.Username = altAdminUsername + a.Groups = []dataprovider.AdminGroupMapping{ + { + Name: name3, + 
Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary, + }, + }, + { + Name: name2, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary, + }, + }, + { + Name: name1, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsMembership, + }, + }, + } + admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, admin.Groups, 3) { + assert.Equal(t, name3, admin.Groups[0].Name) + assert.Equal(t, name2, admin.Groups[1].Name) + assert.Equal(t, name1, admin.Groups[2].Name) + } + admin.Groups = []dataprovider.AdminGroupMapping{ + { + Name: name1, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary, + }, + }, + { + Name: name3, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsMembership, + }, + }, + { + Name: name2, + Options: dataprovider.AdminGroupMappingOptions{ + AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary, + }, + }, + } + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, admin.Groups, 3) { + assert.Equal(t, name1, admin.Groups[0].Name) + assert.Equal(t, name3, admin.Groups[1].Name) + assert.Equal(t, name2, admin.Groups[2].Name) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group3, http.StatusOK) + assert.NoError(t, err) 
+} + +func TestBasicGroupHandling(t *testing.T) { + g := getTestGroup() + g.UserSettings.Filters.TLSCerts = []string{"invalid cert"} // ignored for groups + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + assert.Greater(t, group.CreatedAt, int64(0)) + assert.Greater(t, group.UpdatedAt, int64(0)) + assert.Len(t, group.UserSettings.Filters.TLSCerts, 0) + groupGet, _, err := httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, group, groupGet) + groups, _, err := httpdtest.GetGroups(0, 0, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, groups, 1) { + assert.Equal(t, group, groups[0]) + } + groups, _, err = httpdtest.GetGroups(0, 1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, groups, 0) + _, _, err = httpdtest.GetGroupByName(group.Name+"_", http.StatusNotFound) + assert.NoError(t, err) + // adding the same group again should fail + _, _, err = httpdtest.AddGroup(g, http.StatusConflict) + assert.NoError(t, err) + + group.UserSettings.HomeDir = filepath.Join(os.TempDir(), "%username%") + group.UserSettings.FsConfig.Provider = sdk.SFTPFilesystemProvider + group.UserSettings.FsConfig.SFTPConfig.Endpoint = sftpServerAddr + group.UserSettings.FsConfig.SFTPConfig.Username = defaultUsername + group.UserSettings.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) + group.UserSettings.Permissions = map[string][]string{ + "/": {dataprovider.PermAny}, + } + group.UserSettings.Filters.AllowedIP = []string{"10.0.0.0/8"} + group, _, err = httpdtest.UpdateGroup(group, http.StatusOK) + assert.NoError(t, err) + groupGet, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, groupGet.UserSettings.Permissions, 1) + assert.Len(t, groupGet.UserSettings.Filters.AllowedIP, 1) + + // update again and check that the password was preserved + dbGroup, err := dataprovider.GroupExists(group.Name) + assert.NoError(t, err) + 
group.UserSettings.FsConfig.SFTPConfig.Password = kms.NewSecret( + dbGroup.UserSettings.FsConfig.SFTPConfig.Password.GetStatus(), + dbGroup.UserSettings.FsConfig.SFTPConfig.Password.GetPayload(), "", "") + group.UserSettings.Permissions = nil + group.UserSettings.Filters.AllowedIP = nil + group, _, err = httpdtest.UpdateGroup(group, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group.UserSettings.Permissions, 0) + assert.Len(t, group.UserSettings.Filters.AllowedIP, 0) + dbGroup, err = dataprovider.GroupExists(group.Name) + assert.NoError(t, err) + err = dbGroup.UserSettings.FsConfig.SFTPConfig.Password.Decrypt() + assert.NoError(t, err) + assert.Equal(t, defaultPassword, dbGroup.UserSettings.FsConfig.SFTPConfig.Password.GetPayload()) + // check the group permissions + groupGet, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, groupGet.UserSettings.Permissions, 0) + + group.UserSettings.HomeDir = "relative path" + _, _, err = httpdtest.UpdateGroup(group, http.StatusBadRequest) + assert.NoError(t, err) + + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) + _, _, err = httpdtest.UpdateGroup(group, http.StatusNotFound) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusNotFound) + assert.NoError(t, err) +} + +func TestGroupRelations(t *testing.T) { + mappedPath1 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) + folderName1 := filepath.Base(mappedPath1) + mappedPath2 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) + folderName2 := filepath.Base(mappedPath2) + _, _, err := httpdtest.AddFolder(vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + FsConfig: vfs.Filesystem{ + OSConfig: sdk.OSFsConfig{ + ReadBufferSize: 3, + WriteBufferSize: 5, + }, + }, + }, http.StatusCreated) + assert.NoError(t, err) + g1 := getTestGroup() + g1.Name += "_1" + g1.VirtualFolders = append(g1.VirtualFolders, vfs.VirtualFolder{ + 
BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: "/vdir1", + }) + g2 := getTestGroup() + g2.Name += "_2" + g2.VirtualFolders = append(g2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: "/vdir2", + }) + g3 := getTestGroup() + g3.Name += "_3" + g3.VirtualFolders = append(g3.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: "/vdir3", + }) + _, _, err = httpdtest.AddGroup(g1, http.StatusCreated) + assert.Error(t, err, "adding a group with a missing folder must fail") + _, _, err = httpdtest.AddFolder(vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, http.StatusCreated) + assert.NoError(t, err) + + group1, resp, err := httpdtest.AddGroup(g1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Len(t, group1.VirtualFolders, 1) + group2, resp, err := httpdtest.AddGroup(g2, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Len(t, group2.VirtualFolders, 1) + group3, resp, err := httpdtest.AddGroup(g3, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Len(t, group3.VirtualFolders, 1) + + folder1, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder1.Groups, 3) + folder2, _, err := httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder2.Groups, 0) + + group1.VirtualFolders = append(group1.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: folder2, + VirtualPath: "/vfolder2", + }) + group1, _, err = httpdtest.UpdateGroup(group1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group1.VirtualFolders, 2) + + folder2, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder2.Groups, 1) + + group1.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: 
vfs.BaseVirtualFolder{ + Name: folder1.Name, + MappedPath: folder1.MappedPath, + }, + VirtualPath: "/vpathmod", + }, + } + group1, _, err = httpdtest.UpdateGroup(group1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group1.VirtualFolders, 1) + + folder2, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder2.Groups, 0) + + group1.VirtualFolders = append(group1.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: folder2, + VirtualPath: "/vfolder2mod", + }) + group1, _, err = httpdtest.UpdateGroup(group1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group1.VirtualFolders, 2) + + folder2, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder2.Groups, 1) + + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: group1.Name, + Type: sdk.GroupTypePrimary, + }, + { + Name: group2.Name, + Type: sdk.GroupTypeSecondary, + }, + { + Name: group3.Name, + Type: sdk.GroupTypeSecondary, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + if assert.Len(t, user.Groups, 3) { + for _, g := range user.Groups { + if g.Name == group1.Name { + assert.Equal(t, sdk.GroupTypePrimary, g.Type) + } else { + assert.Equal(t, sdk.GroupTypeSecondary, g.Type) + } + } + } + group1, _, err = httpdtest.GetGroupByName(group1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group1.Users, 1) + group2, _, err = httpdtest.GetGroupByName(group2.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group2.Users, 1) + group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group3.Users, 1) + + user.Groups = []sdk.GroupMapping{ + { + Name: group3.Name, + Type: sdk.GroupTypeSecondary, + }, + } + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Len(t, user.Groups, 1) + + group1, _, err = 
httpdtest.GetGroupByName(group1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group1.Users, 0) + group2, _, err = httpdtest.GetGroupByName(group2.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group2.Users, 0) + group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group3.Users, 1) + + user.Groups = []sdk.GroupMapping{ + { + Name: group1.Name, + Type: sdk.GroupTypePrimary, + }, + { + Name: group2.Name, + Type: sdk.GroupTypeSecondary, + }, + { + Name: group3.Name, + Type: sdk.GroupTypeSecondary, + }, + } + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Len(t, user.Groups, 3) + + group1, _, err = httpdtest.GetGroupByName(group1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group1.Users, 1) + group2, _, err = httpdtest.GetGroupByName(group2.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group2.Users, 1) + group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group3.Users, 1) + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(folder1, http.StatusOK) + assert.NoError(t, err) + group1, _, err = httpdtest.GetGroupByName(group1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group1.Users, 0) + assert.Len(t, group1.VirtualFolders, 1) + group2, _, err = httpdtest.GetGroupByName(group2.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group2.Users, 0) + assert.Len(t, group2.VirtualFolders, 0) + group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group3.Users, 0) + assert.Len(t, group3.VirtualFolders, 0) + _, err = httpdtest.RemoveGroup(group1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group2, http.StatusOK) + 
assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group3, http.StatusOK) + assert.NoError(t, err) + folder2, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder2.Groups, 0) + _, err = httpdtest.RemoveFolder(folder2, http.StatusOK) + assert.NoError(t, err) +} + +func TestGroupValidation(t *testing.T) { + group := getTestGroup() + group.VirtualFolders = []vfs.VirtualFolder{ + { + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: util.GenerateUniqueID(), + MappedPath: filepath.Join(os.TempDir(), util.GenerateUniqueID()), + }, + }, + } + _, resp, err := httpdtest.AddGroup(group, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "virtual path is mandatory") + group.VirtualFolders = nil + group.UserSettings.FsConfig.Provider = sdk.SFTPFilesystemProvider + _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "endpoint cannot be empty") + group.UserSettings.FsConfig.Provider = sdk.LocalFilesystemProvider + group.UserSettings.Permissions = map[string][]string{ + "a": nil, + } + _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "cannot set permissions for non absolute path") + group.UserSettings.Permissions = map[string][]string{ + "/": nil, + } + _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "no permissions granted") + group.UserSettings.Permissions = map[string][]string{ + "/..": nil, + } + _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "cannot set permissions for invalid subdirectory") + group.UserSettings.Permissions = map[string][]string{ + "/": {dataprovider.PermAny}, + } + group.UserSettings.HomeDir = "relative" + _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) 
+ assert.NoError(t, err) + assert.Contains(t, string(resp), "home_dir must be an absolute path") + group.UserSettings.HomeDir = "" + group.UserSettings.Filters.WebClient = []string{"invalid permission"} + _, resp, err = httpdtest.AddGroup(group, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid web client options") +} + +func TestGroupSettingsOverride(t *testing.T) { + mappedPath1 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) + folderName1 := filepath.Base(mappedPath1) + mappedPath2 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) + folderName2 := filepath.Base(mappedPath2) + mappedPath3 := filepath.Join(os.TempDir(), util.GenerateUniqueID()) + folderName3 := filepath.Base(mappedPath3) + g1 := getTestGroup() + g1.Name += "_1" + g1.VirtualFolders = append(g1.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: "/vdir1", + }) + g1.UserSettings.Permissions = map[string][]string{ + "/dir1": {dataprovider.PermUpload}, + "/dir2": {dataprovider.PermDownload, dataprovider.PermListItems}, + } + g1.UserSettings.FsConfig.OSConfig = sdk.OSFsConfig{ + ReadBufferSize: 6, + WriteBufferSize: 2, + } + g2 := getTestGroup() + g2.Name += "_2" + g2.UserSettings.Permissions = map[string][]string{ + "/dir1": {dataprovider.PermAny}, + "/dir3": {dataprovider.PermDownload, dataprovider.PermListItems, dataprovider.PermChtimes}, + } + g2.VirtualFolders = append(g2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: "/vdir2", + }) + g2.VirtualFolders = append(g2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: "/vdir3", + }) + g2.VirtualFolders = append(g2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName3, + }, + VirtualPath: "/vdir4", + }) + g2.UserSettings.Filters.AccessTime = 
[]sdk.TimePeriod{ + { + DayOfWeek: int(time.Now().UTC().Weekday()), + From: "00:00", + To: "23:59", + }, + } + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + FsConfig: vfs.Filesystem{ + OSConfig: sdk.OSFsConfig{ + ReadBufferSize: 3, + WriteBufferSize: 5, + }, + }, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + f3 := vfs.BaseVirtualFolder{ + Name: folderName3, + MappedPath: mappedPath3, + FsConfig: vfs.Filesystem{ + OSConfig: sdk.OSFsConfig{ + ReadBufferSize: 1, + WriteBufferSize: 2, + }, + }, + } + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + + group1, resp, err := httpdtest.AddGroup(g1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + group2, resp, err := httpdtest.AddGroup(g2, http.StatusCreated) + assert.NoError(t, err, string(resp)) + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: group1.Name, + Type: sdk.GroupTypePrimary, + }, + { + Name: group2.Name, + Type: sdk.GroupTypeSecondary, + }, + } + + r := getTestRole() + role, _, err := httpdtest.AddRole(r, http.StatusCreated) + assert.NoError(t, err) + u.Role = role.Name + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Len(t, user.VirtualFolders, 0) + assert.Len(t, user.Permissions, 1) + + user, err = dataprovider.CheckUserAndPass(defaultUsername, defaultPassword, "", common.ProtocolHTTP) + assert.NoError(t, err) + + var folderNames []string + if assert.Len(t, user.VirtualFolders, 4) { + for _, f := range user.VirtualFolders { + if !slices.Contains(folderNames, f.Name) { + folderNames = append(folderNames, f.Name) + } + switch f.Name { + case folderName1: + assert.Equal(t, mappedPath1, f.MappedPath) + assert.Equal(t, 3, 
f.FsConfig.OSConfig.ReadBufferSize) + assert.Equal(t, 5, f.FsConfig.OSConfig.WriteBufferSize) + assert.True(t, slices.Contains([]string{"/vdir1", "/vdir2"}, f.VirtualPath)) + case folderName2: + assert.Equal(t, mappedPath2, f.MappedPath) + assert.Equal(t, "/vdir3", f.VirtualPath) + assert.Equal(t, 0, f.FsConfig.OSConfig.ReadBufferSize) + assert.Equal(t, 0, f.FsConfig.OSConfig.WriteBufferSize) + case folderName3: + assert.Equal(t, mappedPath3, f.MappedPath) + assert.Equal(t, "/vdir4", f.VirtualPath) + assert.Equal(t, 1, f.FsConfig.OSConfig.ReadBufferSize) + assert.Equal(t, 2, f.FsConfig.OSConfig.WriteBufferSize) + } + } + } + assert.Len(t, folderNames, 3) + assert.Contains(t, folderNames, folderName1) + assert.Contains(t, folderNames, folderName2) + assert.Contains(t, folderNames, folderName3) + assert.Len(t, user.Permissions, 4) + assert.Equal(t, g1.UserSettings.Permissions["/dir1"], user.Permissions["/dir1"]) + assert.Equal(t, g1.UserSettings.Permissions["/dir2"], user.Permissions["/dir2"]) + assert.Equal(t, g2.UserSettings.Permissions["/dir3"], user.Permissions["/dir3"]) + assert.Equal(t, g1.UserSettings.FsConfig.OSConfig.ReadBufferSize, user.FsConfig.OSConfig.ReadBufferSize) + assert.Equal(t, g1.UserSettings.FsConfig.OSConfig.WriteBufferSize, user.FsConfig.OSConfig.WriteBufferSize) + assert.Len(t, user.Filters.AccessTime, 1) + + user, err = dataprovider.GetUserAfterIDPAuth(defaultUsername, "", common.ProtocolOIDC, nil) + assert.NoError(t, err) + assert.Len(t, user.VirtualFolders, 4) + assert.Len(t, user.Filters.AccessTime, 1) + + user1, user2, err := dataprovider.GetUserVariants(defaultUsername, "") + assert.NoError(t, err) + assert.Len(t, user1.VirtualFolders, 0) + assert.Len(t, user2.VirtualFolders, 4) + assert.Equal(t, int64(0), user1.ExpirationDate) + assert.Equal(t, int64(0), user2.ExpirationDate) + assert.Len(t, user1.Filters.AccessTime, 0) + assert.Len(t, user2.Filters.AccessTime, 1) + + group2.UserSettings.FsConfig = vfs.Filesystem{ + Provider: 
sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: defaultUsername, + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + } + group2.UserSettings.Permissions = map[string][]string{ + "/": {dataprovider.PermListItems, dataprovider.PermDownload}, + "/%username%": {dataprovider.PermListItems}, + } + group2.UserSettings.DownloadBandwidth = 128 + group2.UserSettings.UploadBandwidth = 256 + group2.UserSettings.Filters.PasswordStrength = 70 + group2.UserSettings.Filters.WebClient = []string{sdk.WebClientInfoChangeDisabled, sdk.WebClientMFADisabled} + _, _, err = httpdtest.UpdateGroup(group2, http.StatusOK) + assert.NoError(t, err) + user, err = dataprovider.CheckUserAndPass(defaultUsername, defaultPassword, "", common.ProtocolHTTP) + assert.NoError(t, err) + assert.Len(t, user.VirtualFolders, 4) + assert.Equal(t, sdk.LocalFilesystemProvider, user.FsConfig.Provider) + assert.Equal(t, int64(0), user.DownloadBandwidth) + assert.Equal(t, int64(0), user.UploadBandwidth) + assert.Equal(t, 0, user.Filters.PasswordStrength) + assert.Equal(t, []string{dataprovider.PermAny}, user.GetPermissionsForPath("/")) + assert.Equal(t, []string{dataprovider.PermListItems}, user.GetPermissionsForPath("/"+defaultUsername)) + assert.Len(t, user.Filters.WebClient, 2) + + group1.UserSettings.FsConfig = vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: altAdminUsername, + Prefix: "/dirs/%role%/%username%", + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + } + group1.UserSettings.MaxSessions = 2 + group1.UserSettings.QuotaFiles = 1000 + group1.UserSettings.UploadBandwidth = 512 + group1.UserSettings.DownloadBandwidth = 1024 + group1.UserSettings.TotalDataTransfer = 2048 + group1.UserSettings.ExpiresIn = 15 + group1.UserSettings.Filters.MaxUploadFileSize = 1024 * 1024 + 
group1.UserSettings.Filters.StartDirectory = "/startdir/%username%" + group1.UserSettings.Filters.PasswordStrength = 70 + group1.UserSettings.Filters.WebClient = []string{sdk.WebClientInfoChangeDisabled} + group1.UserSettings.Permissions = map[string][]string{ + "/": {dataprovider.PermListItems, dataprovider.PermUpload}, + "/sub/%username%": {dataprovider.PermRename}, + "/%role%/%username%": {dataprovider.PermDelete}, + } + group1.UserSettings.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/sub2/%role%/%username%test", + AllowedPatterns: []string{}, + DeniedPatterns: []string{"*.jpg", "*.zip"}, + }, + } + _, _, err = httpdtest.UpdateGroup(group1, http.StatusOK) + assert.NoError(t, err) + user, err = dataprovider.CheckUserAndPass(defaultUsername, defaultPassword, "", common.ProtocolHTTP) + assert.NoError(t, err) + assert.Len(t, user.VirtualFolders, 4) + assert.Equal(t, user.CreatedAt+int64(group1.UserSettings.ExpiresIn)*86400000, user.ExpirationDate) + assert.Equal(t, group1.UserSettings.Filters.PasswordStrength, user.Filters.PasswordStrength) + assert.Equal(t, sdk.SFTPFilesystemProvider, user.FsConfig.Provider) + assert.Equal(t, altAdminUsername, user.FsConfig.SFTPConfig.Username) + assert.Equal(t, "/dirs/"+role.Name+"/"+defaultUsername, user.FsConfig.SFTPConfig.Prefix) + assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermUpload}, user.GetPermissionsForPath("/")) + assert.Equal(t, []string{dataprovider.PermDelete}, user.GetPermissionsForPath(path.Join("/", role.Name, defaultUsername))) + assert.Equal(t, []string{dataprovider.PermRename}, user.GetPermissionsForPath(path.Join("/sub", defaultUsername))) + assert.Equal(t, group1.UserSettings.MaxSessions, user.MaxSessions) + assert.Equal(t, group1.UserSettings.QuotaFiles, user.QuotaFiles) + assert.Equal(t, group1.UserSettings.UploadBandwidth, user.UploadBandwidth) + assert.Equal(t, group1.UserSettings.TotalDataTransfer, user.TotalDataTransfer) + assert.Equal(t, 
group1.UserSettings.Filters.MaxUploadFileSize, user.Filters.MaxUploadFileSize) + assert.Equal(t, "/startdir/"+defaultUsername, user.Filters.StartDirectory) + if assert.Len(t, user.Filters.FilePatterns, 1) { + assert.Equal(t, "/sub2/"+role.Name+"/"+defaultUsername+"test", user.Filters.FilePatterns[0].Path) //nolint:goconst + } + if assert.Len(t, user.Filters.WebClient, 2) { + assert.Contains(t, user.Filters.WebClient, sdk.WebClientInfoChangeDisabled) + assert.Contains(t, user.Filters.WebClient, sdk.WebClientMFADisabled) + } + // Attempt to create a user with a weak password and group1 as the primary group: this should fail + u = getTestUser() + u.Username = rand.Text() + u.Password = defaultPassword + u.Groups = []sdk.GroupMapping{ + { + Name: group1.Name, + Type: sdk.GroupTypePrimary, + }, + } + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "insecure password") + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group2, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName3}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveRole(role, http.StatusOK) + assert.NoError(t, err) +} + +func TestConfigs(t *testing.T) { + err := dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) + configs, err := dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Equal(t, int64(0), configs.UpdatedAt) + assert.Nil(t, configs.SFTPD) + assert.Nil(t, configs.SMTP) + configs = 
dataprovider.Configs{ + SFTPD: &dataprovider.SFTPDConfigs{}, + SMTP: &dataprovider.SMTPConfigs{}, + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + configs, err = dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Greater(t, configs.UpdatedAt, int64(0)) + + configs = dataprovider.Configs{ + SFTPD: &dataprovider.SFTPDConfigs{ + Ciphers: []string{"unknown"}, + }, + SMTP: &dataprovider.SMTPConfigs{}, + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.ErrorIs(t, err, util.ErrValidation) + configs = dataprovider.Configs{ + SFTPD: &dataprovider.SFTPDConfigs{}, + SMTP: &dataprovider.SMTPConfigs{ + Host: "smtp.example.com", + Port: -1, + }, + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.ErrorIs(t, err, util.ErrValidation) + + configs = dataprovider.Configs{ + SMTP: &dataprovider.SMTPConfigs{ + Host: "mail.example.com", + Port: 587, + User: "test@example.com", + AuthType: 3, + Encryption: 2, + OAuth2: dataprovider.SMTPOAuth2{ + Provider: 1, + Tenant: "", + ClientID: "", + }, + }, + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + if assert.ErrorIs(t, err, util.ErrValidation) { + assert.Contains(t, err.Error(), "smtp oauth2: client id is required") + } + configs.SMTP.OAuth2 = dataprovider.SMTPOAuth2{ + Provider: 1, + ClientID: "client id", + ClientSecret: kms.NewPlainSecret("client secret"), + RefreshToken: kms.NewPlainSecret("refresh token"), + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + configs, err = dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Equal(t, 3, configs.SMTP.AuthType) + assert.Equal(t, 1, configs.SMTP.OAuth2.Provider) + + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) +} + +func TestBasicIPListEntriesHandling(t *testing.T) { + entry := dataprovider.IPListEntry{ + IPOrNet: "::ffff:12.34.56.78", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + 
Description: "test desc", + } + _, _, err := httpdtest.GetIPListEntry(entry.IPOrNet, -1, http.StatusBadRequest) + assert.NoError(t, err) + _, _, err = httpdtest.UpdateIPListEntry(entry, http.StatusNotFound) + assert.NoError(t, err) + _, _, err = httpdtest.AddIPListEntry(entry, http.StatusCreated) + assert.Error(t, err) + // IPv4 address in IPv6 will be converted to standard IPv4 + entry1, _, err := httpdtest.GetIPListEntry("12.34.56.78/32", dataprovider.IPListTypeAllowList, http.StatusOK) + assert.NoError(t, err) + + entry = dataprovider.IPListEntry{ + IPOrNet: "192.168.0.0/24", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + } + entry2, _, err := httpdtest.AddIPListEntry(entry, http.StatusCreated) + assert.NoError(t, err) + // adding the same entry again should fail + _, _, err = httpdtest.AddIPListEntry(entry, http.StatusConflict) + assert.NoError(t, err) + // adding an entry with an invalid IP should fail + entry.IPOrNet = "not valid" + _, _, err = httpdtest.AddIPListEntry(entry, http.StatusBadRequest) + assert.NoError(t, err) + // adding an entry with an incompatible mode should fail + entry.IPOrNet = entry2.IPOrNet + entry.Mode = -1 + _, _, err = httpdtest.AddIPListEntry(entry, http.StatusBadRequest) + assert.NoError(t, err) + entry.Type = -1 + _, _, err = httpdtest.UpdateIPListEntry(entry, http.StatusBadRequest) + assert.NoError(t, err) + entry = dataprovider.IPListEntry{ + IPOrNet: "2001:4860:4860::8888/120", + Type: dataprovider.IPListTypeRateLimiterSafeList, + Mode: dataprovider.ListModeDeny, + } + _, _, err = httpdtest.AddIPListEntry(entry, http.StatusBadRequest) + assert.NoError(t, err) + entry.Mode = dataprovider.ListModeAllow + _, _, err = httpdtest.AddIPListEntry(entry, http.StatusCreated) + assert.NoError(t, err) + entry.Protocols = 3 + entry3, _, err := httpdtest.UpdateIPListEntry(entry, http.StatusOK) + assert.NoError(t, err) + entry.Mode = dataprovider.ListModeDeny + _, _, err = httpdtest.UpdateIPListEntry(entry, 
http.StatusBadRequest) + assert.NoError(t, err) + + for _, tt := range []dataprovider.IPListType{dataprovider.IPListTypeAllowList, dataprovider.IPListTypeDefender, dataprovider.IPListTypeRateLimiterSafeList} { + entries, _, err := httpdtest.GetIPListEntries(tt, "", "", dataprovider.OrderASC, 0, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, entries, 1) { + switch tt { + case dataprovider.IPListTypeAllowList: + assert.Equal(t, entry1, entries[0]) + case dataprovider.IPListTypeDefender: + assert.Equal(t, entry2, entries[0]) + case dataprovider.IPListTypeRateLimiterSafeList: + assert.Equal(t, entry3, entries[0]) + } + } + } + + _, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "", "invalid order", 0, http.StatusBadRequest) + assert.NoError(t, err) + _, _, err = httpdtest.GetIPListEntries(-1, "", "", dataprovider.OrderASC, 0, http.StatusBadRequest) + assert.NoError(t, err) + + _, err = httpdtest.RemoveIPListEntry(entry1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveIPListEntry(entry1, http.StatusNotFound) + assert.NoError(t, err) + _, err = httpdtest.RemoveIPListEntry(entry2, http.StatusOK) + assert.NoError(t, err) + entry2.Type = -1 + _, err = httpdtest.RemoveIPListEntry(entry2, http.StatusBadRequest) + assert.NoError(t, err) + _, err = httpdtest.RemoveIPListEntry(entry3, http.StatusOK) + assert.NoError(t, err) +} + +func TestSearchIPListEntries(t *testing.T) { + entries := []dataprovider.IPListEntry{ + { + IPOrNet: "192.168.0.0/24", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 0, + }, + { + IPOrNet: "192.168.0.1/24", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 0, + }, + { + IPOrNet: "192.168.0.2/24", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 5, + }, + { + IPOrNet: "192.168.0.3/24", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + 
Protocols: 8, + }, + { + IPOrNet: "10.8.0.0/24", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 3, + }, + { + IPOrNet: "10.8.1.0/24", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 8, + }, + { + IPOrNet: "10.8.2.0/24", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 1, + }, + } + + for _, e := range entries { + _, _, err := httpdtest.AddIPListEntry(e, http.StatusCreated) + assert.NoError(t, err) + } + + results, _, err := httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "", dataprovider.OrderASC, 20, http.StatusOK) + assert.NoError(t, err) + if assert.Equal(t, len(entries), len(results)) { + assert.Equal(t, "10.8.0.0/24", results[0].IPOrNet) + } + results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "", dataprovider.OrderDESC, 20, http.StatusOK) + assert.NoError(t, err) + if assert.Equal(t, len(entries), len(results)) { + assert.Equal(t, "192.168.0.3/24", results[0].IPOrNet) + } + results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "192.168.0.1/24", dataprovider.OrderASC, 1, http.StatusOK) + assert.NoError(t, err) + if assert.Equal(t, 1, len(results), results) { + assert.Equal(t, "192.168.0.2/24", results[0].IPOrNet) + } + results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "", "10.8.2.0/24", dataprovider.OrderDESC, 1, http.StatusOK) + assert.NoError(t, err) + if assert.Equal(t, 1, len(results), results) { + assert.Equal(t, "10.8.1.0/24", results[0].IPOrNet) + } + results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "10.", "", dataprovider.OrderASC, 20, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 3, len(results)) + results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "192", "", dataprovider.OrderASC, 20, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 4, 
len(results)) + results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "1", "", dataprovider.OrderASC, 20, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 7, len(results)) + results, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeAllowList, "108", "", dataprovider.OrderASC, 20, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, len(results)) + + for _, e := range entries { + _, err := httpdtest.RemoveIPListEntry(e, http.StatusOK) + assert.NoError(t, err) + } +} + +func TestIPListEntriesValidation(t *testing.T) { + entry := dataprovider.IPListEntry{ + IPOrNet: "::ffff:34.56.78.90/120", + Type: -1, + Mode: dataprovider.ListModeDeny, + } + _, resp, err := httpdtest.AddIPListEntry(entry, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid list type") + entry.Type = dataprovider.IPListTypeRateLimiterSafeList + _, resp, err = httpdtest.AddIPListEntry(entry, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid list mode") + entry.Type = dataprovider.IPListTypeDefender + _, _, err = httpdtest.AddIPListEntry(entry, http.StatusCreated) + assert.Error(t, err) + entry.IPOrNet = "34.56.78.0/24" + _, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK) + assert.NoError(t, err) +} + +func TestBasicActionRulesHandling(t *testing.T) { + actionName := "test_action" + a := dataprovider.BaseEventAction{ + Name: actionName, + Description: "test description", + Type: dataprovider.ActionTypeBackup, + Options: dataprovider.BaseEventActionOptions{}, + } + action, _, err := httpdtest.AddEventAction(a, http.StatusCreated) + assert.NoError(t, err) + // adding the same action should fail + _, _, err = httpdtest.AddEventAction(a, http.StatusConflict) + assert.NoError(t, err) + actionGet, _, err := httpdtest.GetEventActionByName(actionName, http.StatusOK) + assert.NoError(t, err) + actions, _, err := httpdtest.GetEventActions(0, 0, http.StatusOK) + 
assert.NoError(t, err) + assert.Greater(t, len(actions), 0) + found := false + for _, ac := range actions { + if ac.Name == actionName { + assert.Equal(t, actionGet, ac) + found = true + } + } + assert.True(t, found) + a.Description = "new description" + a.Type = dataprovider.ActionTypeDataRetentionCheck + a.Options = dataprovider.BaseEventActionOptions{ + RetentionConfig: dataprovider.EventActionDataRetentionConfig{ + Folders: []dataprovider.FolderRetention{ + { + Path: "/", + Retention: 144, + }, + { + Path: "/p1", + Retention: 0, + }, + { + Path: "/p2", + Retention: 12, + }, + }, + }, + } + _, _, err = httpdtest.UpdateEventAction(a, http.StatusOK) + assert.NoError(t, err) + a.Type = dataprovider.ActionTypeCommand + a.Options = dataprovider.BaseEventActionOptions{ + CmdConfig: dataprovider.EventActionCommandConfig{ + Cmd: filepath.Join(os.TempDir(), "test_cmd"), + Timeout: 20, + EnvVars: []dataprovider.KeyValue{ + { + Key: "NAME", + Value: "VALUE", + }, + }, + }, + } + dataprovider.EnabledActionCommands = []string{a.Options.CmdConfig.Cmd} + defer func() { + dataprovider.EnabledActionCommands = nil + }() + _, _, err = httpdtest.UpdateEventAction(a, http.StatusOK) + assert.NoError(t, err) + // invalid type + a.Type = 1000 + _, _, err = httpdtest.UpdateEventAction(a, http.StatusBadRequest) + assert.NoError(t, err) + + a.Type = dataprovider.ActionTypeEmail + a.Options = dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"email@example.com"}, + Bcc: []string{"bcc@example.com"}, + Subject: "Event: {{.Event}}", + Body: "test mail body", + Attachments: []string{"/{{.VirtualPath}}"}, + }, + } + + _, _, err = httpdtest.UpdateEventAction(a, http.StatusOK) + assert.NoError(t, err) + + a.Type = dataprovider.ActionTypeUserInactivityCheck + a.Options = dataprovider.BaseEventActionOptions{ + UserInactivityConfig: dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + DeleteThreshold: 20, + }, + } + _, _, err 
= httpdtest.UpdateEventAction(a, http.StatusOK) + assert.NoError(t, err) + + a.Type = dataprovider.ActionTypeHTTP + a.Options = dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: "https://localhost:1234", + Username: defaultUsername, + Password: kms.NewPlainSecret(defaultPassword), + Headers: []dataprovider.KeyValue{ + { + Key: "Content-Type", + Value: "application/json", + }, + }, + Timeout: 10, + SkipTLSVerify: true, + Method: http.MethodPost, + QueryParameters: []dataprovider.KeyValue{ + { + Key: "a", + Value: "b", + }, + }, + Body: `{"event":"{{.Event}}","name":"{{.Name}}"}`, + }, + } + action, _, err = httpdtest.UpdateEventAction(a, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, action.Options.HTTPConfig.Password.GetStatus()) + assert.NotEmpty(t, action.Options.HTTPConfig.Password.GetPayload()) + assert.Empty(t, action.Options.HTTPConfig.Password.GetKey()) + assert.Empty(t, action.Options.HTTPConfig.Password.GetAdditionalData()) + // update again and check that the password was preserved + dbAction, err := dataprovider.EventActionExists(actionName) + assert.NoError(t, err) + action.Options.HTTPConfig.Password = kms.NewSecret( + dbAction.Options.HTTPConfig.Password.GetStatus(), + dbAction.Options.HTTPConfig.Password.GetPayload(), "", "") + action, _, err = httpdtest.UpdateEventAction(action, http.StatusOK) + assert.NoError(t, err) + dbAction, err = dataprovider.EventActionExists(actionName) + assert.NoError(t, err) + err = dbAction.Options.HTTPConfig.Password.Decrypt() + assert.NoError(t, err) + assert.Equal(t, defaultPassword, dbAction.Options.HTTPConfig.Password.GetPayload()) + + r := dataprovider.EventRule{ + Name: "test_rule_name", + Status: 1, + Description: "", + Trigger: dataprovider.EventTriggerFsEvent, + Conditions: dataprovider.EventConditions{ + FsEvents: []string{"upload"}, + Options: dataprovider.ConditionOptions{ + EventStatuses: []int{2, 3}, + MinFileSize: 
1024 * 1024, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: actionName, + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + IsFailureAction: false, + StopOnFailure: true, + ExecuteSync: true, + }, + }, + }, + } + rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) + assert.NoError(t, err) + // adding the same rule should fail + _, _, err = httpdtest.AddEventRule(r, http.StatusConflict) + assert.NoError(t, err) + + rule.Description = "new rule desc" + rule.Trigger = 1000 + _, _, err = httpdtest.UpdateEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + rule.Trigger = dataprovider.EventTriggerFsEvent + rule, _, err = httpdtest.UpdateEventRule(rule, http.StatusOK) + assert.NoError(t, err) + + ruleGet, _, err := httpdtest.GetEventRuleByName(rule.Name, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, ruleGet.Actions, 1) { + if assert.NotNil(t, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password) { + assert.Equal(t, sdkkms.SecretStatusSecretBox, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetStatus()) + assert.NotEmpty(t, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetPayload()) + assert.Empty(t, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetKey()) + assert.Empty(t, ruleGet.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetAdditionalData()) + } + } + rules, _, err := httpdtest.GetEventRules(0, 0, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, len(rules), 0) + found = false + for _, ru := range rules { + if ru.Name == rule.Name { + assert.Equal(t, ruleGet, ru) + found = true + } + } + assert.True(t, found) + + _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) + assert.NoError(t, err) + _, _, err = httpdtest.UpdateEventRule(rule, http.StatusNotFound) + assert.NoError(t, err) + _, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusNotFound) + assert.NoError(t, err) 
+ _, err = httpdtest.RemoveEventRule(rule, http.StatusNotFound) + assert.NoError(t, err) + + _, err = httpdtest.RemoveEventAction(action, http.StatusOK) + assert.NoError(t, err) + _, _, err = httpdtest.UpdateEventAction(action, http.StatusNotFound) + assert.NoError(t, err) + _, _, err = httpdtest.GetEventActionByName(actionName, http.StatusNotFound) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action, http.StatusNotFound) + assert.NoError(t, err) +} + +func TestActionRuleRelations(t *testing.T) { + a1 := dataprovider.BaseEventAction{ + Name: "action1", + Description: "test description", + Type: dataprovider.ActionTypeBackup, + Options: dataprovider.BaseEventActionOptions{}, + } + a2 := dataprovider.BaseEventAction{ + Name: "action2", + Type: dataprovider.ActionTypeTransferQuotaReset, + Options: dataprovider.BaseEventActionOptions{}, + } + a3 := dataprovider.BaseEventAction{ + Name: "action3", + Type: dataprovider.ActionTypeEmail, + Options: dataprovider.BaseEventActionOptions{ + EmailConfig: dataprovider.EventActionEmailConfig{ + Recipients: []string{"test@example.net"}, + ContentType: 1, + Subject: "test subject", + Body: "test body", + }, + }, + } + action1, _, err := httpdtest.AddEventAction(a1, http.StatusCreated) + assert.NoError(t, err) + action2, _, err := httpdtest.AddEventAction(a2, http.StatusCreated) + assert.NoError(t, err) + action3, _, err := httpdtest.AddEventAction(a3, http.StatusCreated) + assert.NoError(t, err) + + r1 := dataprovider.EventRule{ + Name: "rule1", + Description: "", + Trigger: dataprovider.EventTriggerProviderEvent, + Conditions: dataprovider.EventConditions{ + ProviderEvents: []string{"add"}, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 2, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action1.Name, + }, + Order: 1, + }, + }, + } + 
rule1, _, err := httpdtest.AddEventRule(r1, http.StatusCreated) + assert.NoError(t, err) + if assert.Len(t, rule1.Actions, 2) { + assert.Equal(t, action1.Name, rule1.Actions[0].Name) + assert.Equal(t, 1, rule1.Actions[0].Order) + assert.Equal(t, action3.Name, rule1.Actions[1].Name) + assert.Equal(t, 2, rule1.Actions[1].Order) + assert.True(t, rule1.Actions[1].Options.IsFailureAction) + } + + r2 := dataprovider.EventRule{ + Name: "rule2", + Description: "", + Trigger: dataprovider.EventTriggerSchedule, + Conditions: dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + Hours: "1", + DayOfWeek: "*", + DayOfMonth: "*", + Month: "*", + }, + }, + Options: dataprovider.ConditionOptions{ + RoleNames: []dataprovider.ConditionPattern{ + { + Pattern: "g*", + InverseMatch: true, + }, + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action3.Name, + }, + Order: 2, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 1, + }, + }, + } + rule2, _, err := httpdtest.AddEventRule(r2, http.StatusCreated) + assert.NoError(t, err) + if assert.Len(t, rule2.Actions, 2) { + assert.Equal(t, action2.Name, rule2.Actions[0].Name) + assert.Equal(t, 1, rule2.Actions[0].Order) + assert.Equal(t, action3.Name, rule2.Actions[1].Name) + assert.Equal(t, 2, rule2.Actions[1].Order) + assert.True(t, rule2.Actions[1].Options.IsFailureAction) + } + // check the references + action1, _, err = httpdtest.GetEventActionByName(action1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, action1.Rules, 1) + assert.True(t, slices.Contains(action1.Rules, rule1.Name)) + action2, _, err = httpdtest.GetEventActionByName(action2.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, action2.Rules, 1) + assert.True(t, slices.Contains(action2.Rules, rule2.Name)) + action3, _, err = 
httpdtest.GetEventActionByName(action3.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, action3.Rules, 2) + assert.True(t, slices.Contains(action3.Rules, rule1.Name)) + assert.True(t, slices.Contains(action3.Rules, rule2.Name)) + // referenced actions cannot be removed + _, err = httpdtest.RemoveEventAction(action1, http.StatusBadRequest) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusBadRequest) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action3, http.StatusBadRequest) + assert.NoError(t, err) + // remove action3 from rule2 + r2.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action2.Name, + }, + Order: 10, + }, + } + rule2.Status = 1 + rule2, _, err = httpdtest.UpdateEventRule(r2, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, rule2.Actions, 1) { + assert.Equal(t, action2.Name, rule2.Actions[0].Name) + assert.Equal(t, 10, rule2.Actions[0].Order) + } + // check the updated relation + action3, _, err = httpdtest.GetEventActionByName(action3.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, action3.Rules, 1) + assert.True(t, slices.Contains(action3.Rules, rule1.Name)) + + _, err = httpdtest.RemoveEventRule(rule1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule2, http.StatusOK) + assert.NoError(t, err) + // no relations anymore + action1, _, err = httpdtest.GetEventActionByName(action1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, action1.Rules, 0) + action2, _, err = httpdtest.GetEventActionByName(action2.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, action2.Rules, 0) + action3, _, err = httpdtest.GetEventActionByName(action3.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, action3.Rules, 0) + + _, err = httpdtest.RemoveEventAction(action1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action2, http.StatusOK) + 
assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action3, http.StatusOK) + assert.NoError(t, err) +} + +func TestOnDemandEventRules(t *testing.T) { + ruleName := "test_on_demand_rule" + a := dataprovider.BaseEventAction{ + Name: "a", + Type: dataprovider.ActionTypeBackup, + Options: dataprovider.BaseEventActionOptions{}, + } + action, _, err := httpdtest.AddEventAction(a, http.StatusCreated) + assert.NoError(t, err) + r := dataprovider.EventRule{ + Name: ruleName, + Status: 1, + Trigger: dataprovider.EventTriggerOnDemand, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: a.Name, + }, + }, + }, + } + rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) + assert.NoError(t, err) + _, err = httpdtest.RunOnDemandRule(ruleName, http.StatusAccepted) + assert.NoError(t, err) + rule.Status = 0 + _, _, err = httpdtest.UpdateEventRule(rule, http.StatusOK) + assert.NoError(t, err) + resp, err := httpdtest.RunOnDemandRule(ruleName, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "is inactive") + + _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action, http.StatusOK) + assert.NoError(t, err) + + _, err = httpdtest.RunOnDemandRule(ruleName, http.StatusNotFound) + assert.NoError(t, err) +} + +func TestIDPLoginEventRule(t *testing.T) { + ruleName := "test_IDP_login_rule" + a := dataprovider.BaseEventAction{ + Name: "a", + Type: dataprovider.ActionTypeIDPAccountCheck, + Options: dataprovider.BaseEventActionOptions{ + IDPConfig: dataprovider.EventActionIDPAccountCheck{ + Mode: 1, + TemplateUser: `{"username": "user"}`, + TemplateAdmin: `{"username": "admin"}`, + }, + }, + } + action, resp, err := httpdtest.AddEventAction(a, http.StatusCreated) + assert.NoError(t, err, string(resp)) + r := dataprovider.EventRule{ + Name: ruleName, + Status: 1, + Trigger: dataprovider.EventTriggerIDPLogin, + Conditions: 
dataprovider.EventConditions{ + IDPLoginEvent: 1, + Options: dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "username", + }, + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: a.Name, + }, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + }, + } + rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) + assert.NoError(t, err) + rule.Status = 0 + _, _, err = httpdtest.UpdateEventRule(rule, http.StatusOK) + assert.NoError(t, err) + + _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action, http.StatusOK) + assert.NoError(t, err) +} + +func TestEventActionValidation(t *testing.T) { + action := dataprovider.BaseEventAction{ + Name: "", + } + _, resp, err := httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "name is mandatory") + action = dataprovider.BaseEventAction{ + Name: "n", + Type: -1, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid action type") + action.Type = dataprovider.ActionTypeHTTP + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "HTTP endpoint is required") + action.Options.HTTPConfig.Endpoint = "abc" + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid HTTP endpoint schema") + action.Options.HTTPConfig.Endpoint = "http://localhost" + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid HTTP timeout") + action.Options.HTTPConfig.Timeout = 20 + action.Options.HTTPConfig.Headers = []dataprovider.KeyValue{ + { + Key: "", + Value: "", + 
}, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid HTTP headers") + action.Options.HTTPConfig.Headers = []dataprovider.KeyValue{ + { + Key: "Content-Type", + Value: "application/json", + }, + } + action.Options.HTTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusRedacted, "payload", "", "") + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "cannot save HTTP configuration with a redacted secret") + action.Options.HTTPConfig.Password = nil + action.Options.HTTPConfig.Method = http.MethodTrace + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "unsupported HTTP method") + action.Options.HTTPConfig.Method = http.MethodGet + action.Options.HTTPConfig.QueryParameters = []dataprovider.KeyValue{ + { + Key: "a", + Value: "", + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid HTTP query parameters") + action.Options.HTTPConfig.QueryParameters = nil + action.Options.HTTPConfig.Parts = []dataprovider.HTTPPart{ + { + Name: "", + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "HTTP part name is required") + action.Options.HTTPConfig.Parts = []dataprovider.HTTPPart{ + { + Name: "p1", + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "HTTP part body is required if no file path is provided") + action.Options.HTTPConfig.Parts = []dataprovider.HTTPPart{ + { + Name: "p1", + Filepath: "p", + }, + } + action.Options.HTTPConfig.Body = "b" + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + 
assert.Contains(t, string(resp), "multipart requests require no body") + action.Options.HTTPConfig.Body = "" + action.Options.HTTPConfig.Headers = []dataprovider.KeyValue{ + { + Key: "Content-Type", + Value: "application/json", + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "content type is automatically set for multipart requests") + + action.Type = dataprovider.ActionTypeCommand + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "command is required") + action.Options.CmdConfig.Cmd = "relative" + dataprovider.EnabledActionCommands = []string{action.Options.CmdConfig.Cmd} + defer func() { + dataprovider.EnabledActionCommands = nil + }() + + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid command, it must be an absolute path") + action.Options.CmdConfig.Cmd = filepath.Join(os.TempDir(), "cmd") + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "is not allowed") + + dataprovider.EnabledActionCommands = []string{action.Options.CmdConfig.Cmd} + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid command action timeout") + + action.Options.CmdConfig.Timeout = 30 + action.Options.CmdConfig.EnvVars = []dataprovider.KeyValue{ + { + Key: "k", + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid command env vars") + action.Options.CmdConfig.EnvVars = nil + action.Options.CmdConfig.Args = []string{"arg1", ""} + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid 
command args") + action.Options.CmdConfig.Args = nil + // restrict commands + if runtime.GOOS == osWindows { + dataprovider.EnabledActionCommands = []string{"C:\\cmd.exe"} + } else { + dataprovider.EnabledActionCommands = []string{"/bin/sh"} + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "is not allowed") + dataprovider.EnabledActionCommands = nil + + action.Type = dataprovider.ActionTypeEmail + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "at least one email recipient is required") + action.Options.EmailConfig.Recipients = []string{""} + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid email recipients") + action.Options.EmailConfig.Recipients = []string{"a@a.com"} + action.Options.EmailConfig.Bcc = []string{""} + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid email bcc") + action.Options.EmailConfig.Bcc = nil + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "email subject is required") + action.Options.EmailConfig.Subject = "subject" + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "email body is required") + + action.Type = dataprovider.ActionTypeDataRetentionCheck + action.Options.RetentionConfig = dataprovider.EventActionDataRetentionConfig{ + Folders: nil, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "nothing to delete") + action.Options.RetentionConfig = dataprovider.EventActionDataRetentionConfig{ + Folders: []dataprovider.FolderRetention{ + { + 
Path: "/", + Retention: 0, + }, + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "nothing to delete") + action.Options.RetentionConfig = dataprovider.EventActionDataRetentionConfig{ + Folders: []dataprovider.FolderRetention{ + { + Path: "../path", + Retention: 1, + }, + { + Path: "/path", + Retention: 10, + }, + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "duplicated folder path") + action.Options.RetentionConfig = dataprovider.EventActionDataRetentionConfig{ + Folders: []dataprovider.FolderRetention{ + { + Path: "p", + Retention: -1, + }, + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid folder retention") + action.Type = dataprovider.ActionTypeFilesystem + action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionRename, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "no path to rename specified") + action.Options.FsConfig.Renames = []dataprovider.RenameConfig{ + { + KeyValue: dataprovider.KeyValue{ + Key: "", + Value: "/adir", + }, + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid paths to rename") + action.Options.FsConfig.Renames = []dataprovider.RenameConfig{ + { + KeyValue: dataprovider.KeyValue{ + Key: "adir", + Value: "/adir", + }, + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "rename source and target cannot be equal") + action.Options.FsConfig.Renames = []dataprovider.RenameConfig{ + { + KeyValue: dataprovider.KeyValue{ + Key: "/", + Value: 
"/dir", + }, + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "renaming the root directory is not allowed") + action.Options.FsConfig.Type = dataprovider.FilesystemActionMkdirs + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "no directory to create specified") + action.Options.FsConfig.MkDirs = []string{"dir1", ""} + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid directory to create") + action.Options.FsConfig.Type = dataprovider.FilesystemActionDelete + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "no path to delete specified") + action.Options.FsConfig.Deletes = []string{"item1", ""} + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid path to delete") + action.Options.FsConfig.Type = dataprovider.FilesystemActionExist + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "no path to check for existence specified") + action.Options.FsConfig.Exist = []string{"item1", ""} + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid path to check for existence") + action.Options.FsConfig.Type = dataprovider.FilesystemActionCompress + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "archive name is mandatory") + action.Options.FsConfig.Compress.Name = "archive.zip" + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, 
string(resp), "no path to compress specified") + action.Options.FsConfig.Compress.Paths = []string{"item1", ""} + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid path to compress") + action.Type = dataprovider.ActionTypePasswordExpirationCheck + action.Options.PwdExpirationConfig.Threshold = 0 + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "threshold must be greater than 0") + action.Type = dataprovider.ActionTypeIDPAccountCheck + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "at least a template must be set") + action.Options.IDPConfig.TemplateAdmin = "{}" + action.Options.IDPConfig.Mode = 100 + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid account check mode") + action.Type = dataprovider.ActionTypeUserInactivityCheck + action.Options = dataprovider.BaseEventActionOptions{ + UserInactivityConfig: dataprovider.EventActionUserInactivity{ + DisableThreshold: 0, + DeleteThreshold: 0, + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "at least a threshold must be defined") + action.Options = dataprovider.BaseEventActionOptions{ + UserInactivityConfig: dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + DeleteThreshold: 10, + }, + } + _, resp, err = httpdtest.AddEventAction(action, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "must be greater than deactivation threshold") +} + +func TestEventRuleValidation(t *testing.T) { + rule := dataprovider.EventRule{ + Name: "", + } + _, resp, err := httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + 
assert.Contains(t, string(resp), "name is mandatory") + rule.Name = "r" + rule.Status = 100 + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid event rule status") + rule.Status = 1 + rule.Trigger = 1000 + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid event rule trigger") + rule.Trigger = dataprovider.EventTriggerFsEvent + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "at least one filesystem event is required") + rule.Conditions.FsEvents = []string{""} + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "unsupported fs event") + rule.Conditions.FsEvents = []string{"upload"} + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "at least one action is required") + rule.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action1", + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "", + }, + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "name not specified") + rule.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action", + }, + Order: 1, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action", + }, + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "duplicated action") + rule.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action11", + }, + Order: 1, + }, + { + BaseEventAction: 
dataprovider.BaseEventAction{ + Name: "action12", + }, + Order: 1, + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "duplicated order") + rule.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action111", + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + }, + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action112", + }, + Order: 2, + Options: dataprovider.EventActionOptions{ + IsFailureAction: true, + }, + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "at least a non-failure action is required") + rule.Conditions.FsEvents = []string{"upload", "download"} + rule.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action111", + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "sync execution is only supported for upload and pre-* events") + rule.Conditions.FsEvents = []string{"pre-upload", "download"} + rule.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action", + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: false, + }, + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "requires at least a sync action") + rule.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action", + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, 
err) + assert.Contains(t, string(resp), "sync execution is only supported for upload and pre-* events") + + rule.Conditions.FsEvents = []string{"download"} + rule.Conditions.Options.EventStatuses = []int{3, 2, 8} + rule.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action", + }, + Order: 1, + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid event_status") + + rule.Trigger = dataprovider.EventTriggerProviderEvent + rule.Actions = []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "action1234", + }, + Order: 1, + Options: dataprovider.EventActionOptions{ + IsFailureAction: false, + }, + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "at least one provider event is required") + rule.Conditions.ProviderEvents = []string{""} + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "unsupported provider event") + rule.Conditions.ProviderEvents = []string{"add"} + rule.Conditions.Options.RoleNames = []dataprovider.ConditionPattern{ + { + Pattern: "", + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "empty condition pattern not allowed") + rule.Conditions.Options.RoleNames = nil + rule.Trigger = dataprovider.EventTriggerSchedule + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "at least one schedule is required") + rule.Conditions.Schedules = []dataprovider.Schedule{ + {}, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid schedule") + rule.Conditions.Schedules = 
[]dataprovider.Schedule{ + { + Hours: "3", + DayOfWeek: "*", + DayOfMonth: "*", + Month: "*", + }, + } + _, resp, err = httpdtest.AddEventRule(rule, http.StatusInternalServerError) + assert.NoError(t, err, string(resp)) + rule.Trigger = dataprovider.EventTriggerIDPLogin + rule.Conditions.IDPLoginEvent = 100 + _, resp, err = httpdtest.AddEventRule(rule, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid Identity Provider login event") +} + +func TestUserBandwidthLimits(t *testing.T) { + u := getTestUser() + u.UploadBandwidth = 128 + u.DownloadBandwidth = 96 + u.Filters.BandwidthLimits = []sdk.BandwidthLimit{ + { + Sources: []string{"1"}, + }, + } + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "Validation error: could not parse bandwidth limit source") + u.Filters.BandwidthLimits = []sdk.BandwidthLimit{ + { + Sources: nil, + }, + } + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "Validation error: no bandwidth limit source specified") + u.Filters.BandwidthLimits = []sdk.BandwidthLimit{ + { + Sources: []string{"127.0.0.0/8", "::1/128"}, + UploadBandwidth: 256, + }, + { + Sources: []string{"10.0.0.0/8"}, + UploadBandwidth: 512, + DownloadBandwidth: 256, + }, + } + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Len(t, user.Filters.BandwidthLimits, 2) + assert.Equal(t, u.Filters.BandwidthLimits, user.Filters.BandwidthLimits) + + connID := xid.New().String() + localAddr := "127.0.0.1" + up, down := user.GetBandwidthForIP("127.0.1.1", connID) + assert.Equal(t, int64(256), up) + assert.Equal(t, int64(0), down) + conn := common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "127.0.1.1", user) + assert.Equal(t, int64(256), conn.User.UploadBandwidth) + assert.Equal(t, int64(0), 
conn.User.DownloadBandwidth) + up, down = user.GetBandwidthForIP("10.1.2.3", connID) + assert.Equal(t, int64(512), up) + assert.Equal(t, int64(256), down) + conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "10.2.1.4:1234", user) + assert.Equal(t, int64(512), conn.User.UploadBandwidth) + assert.Equal(t, int64(256), conn.User.DownloadBandwidth) + up, down = user.GetBandwidthForIP("192.168.1.2", connID) + assert.Equal(t, int64(128), up) + assert.Equal(t, int64(96), down) + conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "172.16.0.1", user) + assert.Equal(t, int64(128), conn.User.UploadBandwidth) + assert.Equal(t, int64(96), conn.User.DownloadBandwidth) + up, down = user.GetBandwidthForIP("invalid", connID) + assert.Equal(t, int64(128), up) + assert.Equal(t, int64(96), down) + conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "172.16.0", user) + assert.Equal(t, int64(128), conn.User.UploadBandwidth) + assert.Equal(t, int64(96), conn.User.DownloadBandwidth) + + user.Filters.BandwidthLimits = []sdk.BandwidthLimit{ + { + Sources: []string{"10.0.0.0/24"}, + UploadBandwidth: 256, + DownloadBandwidth: 512, + }, + } + user, resp, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(resp)) + if assert.Len(t, user.Filters.BandwidthLimits, 1) { + bwLimit := user.Filters.BandwidthLimits[0] + assert.Equal(t, []string{"10.0.0.0/24"}, bwLimit.Sources) + assert.Equal(t, int64(256), bwLimit.UploadBandwidth) + assert.Equal(t, int64(512), bwLimit.DownloadBandwidth) + } + up, down = user.GetBandwidthForIP("10.1.2.3", connID) + assert.Equal(t, int64(128), up) + assert.Equal(t, int64(96), down) + conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "172.16.0.2", user) + assert.Equal(t, int64(128), conn.User.UploadBandwidth) + assert.Equal(t, int64(96), conn.User.DownloadBandwidth) + up, down = user.GetBandwidthForIP("10.0.0.26", connID) + assert.Equal(t, int64(256), up) + 
assert.Equal(t, int64(512), down) + conn = common.NewBaseConnection(connID, common.ProtocolHTTP, localAddr, "10.0.0.28", user) + assert.Equal(t, int64(256), conn.User.UploadBandwidth) + assert.Equal(t, int64(512), conn.User.DownloadBandwidth) + + // this works if we remove the omitempty tag from BandwidthLimits + /*user.Filters.BandwidthLimits = nil + user, resp, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(resp)) + assert.Len(t, user.Filters.BandwidthLimits, 0)*/ + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestAccessTimeValidation(t *testing.T) { + u := getTestUser() + u.Filters.AccessTime = []sdk.TimePeriod{ + { + DayOfWeek: 8, + From: "10:00", + To: "18:00", + }, + } + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "invalid day of week") + u.Filters.AccessTime = []sdk.TimePeriod{ + { + DayOfWeek: 6, + From: "10:00", + To: "18", + }, + } + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "invalid time of day") + u.Filters.AccessTime = []sdk.TimePeriod{ + { + DayOfWeek: 6, + From: "11:00", + To: "10:58", + }, + } + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "The end time cannot be earlier than the start time") +} + +func TestUserTimestamps(t *testing.T) { + user, resp, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err, string(resp)) + createdAt := user.CreatedAt + updatedAt := user.UpdatedAt + assert.Equal(t, int64(0), user.LastLogin) + assert.Equal(t, int64(0), user.FirstDownload) + assert.Equal(t, int64(0), user.FirstUpload) + assert.Greater(t, createdAt, int64(0)) + assert.Greater(t, updatedAt, int64(0)) + mappedPath := 
filepath.Join(os.TempDir(), "mapped_dir")
	folderName := filepath.Base(mappedPath)
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name:       folderName,
			MappedPath: mappedPath,
		},
		VirtualPath: "/vdir",
	})
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_, _, err = httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	// small sleep so the next update gets a strictly greater timestamp
	time.Sleep(10 * time.Millisecond)
	user, resp, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err, string(resp))
	assert.Equal(t, int64(0), user.LastLogin)
	assert.Equal(t, int64(0), user.FirstDownload)
	assert.Equal(t, int64(0), user.FirstUpload)
	assert.Equal(t, createdAt, user.CreatedAt)
	assert.Greater(t, user.UpdatedAt, updatedAt)
	updatedAt = user.UpdatedAt
	// after a folder update or delete the user updated_at field should change
	folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, folder.Users, 1)
	time.Sleep(10 * time.Millisecond)
	_, _, err = httpdtest.UpdateFolder(folder, http.StatusOK)
	assert.NoError(t, err)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), user.LastLogin)
	assert.Equal(t, int64(0), user.FirstDownload)
	assert.Equal(t, int64(0), user.FirstUpload)
	assert.Equal(t, createdAt, user.CreatedAt)
	assert.Greater(t, user.UpdatedAt, updatedAt)
	updatedAt = user.UpdatedAt
	time.Sleep(10 * time.Millisecond)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), user.LastLogin)
	assert.Equal(t, int64(0), user.FirstDownload)
	assert.Equal(t, int64(0), user.FirstUpload)
	assert.Equal(t, createdAt, user.CreatedAt)
	assert.Greater(t, user.UpdatedAt, updatedAt)

	err =
os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminTimestamps verifies created_at/updated_at handling for admins:
// creation sets both, an update bumps only updated_at.
func TestAdminTimestamps(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	createdAt := admin.CreatedAt
	updatedAt := admin.UpdatedAt
	assert.Equal(t, int64(0), admin.LastLogin)
	assert.Greater(t, createdAt, int64(0))
	assert.Greater(t, updatedAt, int64(0))
	// sleep so the update timestamp is strictly greater
	time.Sleep(10 * time.Millisecond)
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), admin.LastLogin)
	assert.Equal(t, createdAt, admin.CreatedAt)
	assert.Greater(t, admin.UpdatedAt, updatedAt)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestHTTPUserAuthEmptyPassword ensures a user created with an empty
// password cannot authenticate with an empty password over HTTP.
func TestHTTPUserAuthEmptyPassword(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, "")
	c := httpclient.GetHTTPClient()
	resp, err := c.Do(req)
	c.CloseIdleConnections()
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, "")
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unexpected status code 401")
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestHTTPAnonymousUser checks the restrictions applied to anonymous users:
// limited permissions, denied protocols/login methods and no HTTP login.
func TestHTTPAnonymousUser(t *testing.T) {
	u := getTestUser()
	u.Filters.IsAnonymous = true
	// creation returns an error to the client but the user is persisted
	_, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.Error(t, err)
	user, _, err := httpdtest.GetUserByUsername(u.Username, http.StatusOK)
assert.NoError(t, err)
	assert.True(t, user.Filters.IsAnonymous)
	assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"])
	assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols)
	assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword,
		dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword,
		dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate,
		dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods)

	// HTTP login must be forbidden for anonymous users
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	c := httpclient.GetHTTPClient()
	resp, err := c.Do(req)
	c.CloseIdleConnections()
	assert.NoError(t, err)
	assert.Equal(t, http.StatusForbidden, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	_, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "unexpected status code 403")
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestHTTPUserAuthentication exercises user/admin token issuance over HTTP:
// valid credentials, wrong/missing credentials, and token scoping (a user
// token is not valid on admin endpoints and vice versa).
func TestHTTPUserAuthentication(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	c := httpclient.GetHTTPClient()
	resp, err := c.Do(req)
	c.CloseIdleConnections()
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	userToken :=
responseHolder["access_token"].(string)
	assert.NotEmpty(t, userToken)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// login with wrong credentials
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, "")
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	// no credentials at all
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	// wrong password
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, "wrong pwd")
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	respBody, err := io.ReadAll(resp.Body)
	assert.NoError(t, err)
	assert.Contains(t, string(respBody), "invalid credentials")
	err = resp.Body.Close()
	assert.NoError(t, err)

	// wrong username
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth("wrong username", defaultPassword)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	respBody, err = io.ReadAll(resp.Body)
	assert.NoError(t, err)
	assert.Contains(t, string(respBody), "invalid credentials")
	err = resp.Body.Close()
	assert.NoError(t, err)

	// get an admin token
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultTokenAuthUser, defaultTokenAuthPass)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder = make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	adminToken := responseHolder["access_token"].(string)
	assert.NotEmpty(t, adminToken)
	err = resp.Body.Close()
	assert.NoError(t, err)

	// the admin token works on admin endpoints
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, versionPath), nil)
	assert.NoError(t, err)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", adminToken))
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// using the user token should not work
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, versionPath), nil)
	assert.NoError(t, err)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", userToken))
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	// an admin token is not valid on user endpoints
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userLogoutPath), nil)
	assert.NoError(t, err)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", adminToken))
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userLogoutPath), nil)
	assert.NoError(t, err)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", userToken))
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestPermMFADisabled verifies the WebClientMFADisabled permission: TOTP
// cannot be configured while MFA is disallowed, and MFA cannot be
// disallowed once an active TOTP configuration exists.
func TestPermMFADisabled(t *testing.T) {
	u := getTestUser()
	u.Filters.WebClient = []string{sdk.WebClientMFADisabled}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// requiring 2FA while disallowing MFA is inconsistent
	user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH}
	_, resp, err := httpdtest.UpdateUser(user, http.StatusBadRequest, "")
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "you cannot require two-factor authentication and at the same time disallow it")
	user.Filters.TwoFactorAuthProtocols = nil

	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr) // MFA is disabled for this user

	user.Filters.WebClient = []string{sdk.WebClientWriteDisabled}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)

	// with MFA allowed again the TOTP config can be saved
	token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now we cannot disable MFA for this user
	user.Filters.WebClient = []string{sdk.WebClientMFADisabled}
	_, resp, err = httpdtest.UpdateUser(user, http.StatusBadRequest,
"")
	assert.NoError(t, err)
	assert.Contains(t, string(resp), "two-factor authentication cannot be disabled for a user with an active configuration")

	// deactivate the TOTP configuration
	saveReq := make(map[string]bool)
	saveReq["enabled"] = false
	asJSON, err = json.Marshal(saveReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// recovery codes stay forbidden while MFA is disallowed
	req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)

	req, err = http.NewRequest(http.MethodPost, user2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestUpdateUserPassword checks password updates through the provider:
// the require-password-change flag is cleared, the last password change
// timestamp moves, and group-level overrides are not persisted on the user.
func TestUpdateUserPassword(t *testing.T) {
	g := getTestGroup()
	g.UserSettings.Filters.PasswordStrength = 20
	g.UserSettings.MaxSessions = 10
	group, _, err := httpdtest.AddGroup(g, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Filters.RequirePasswordChange = true
	u.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	lastPwdChange := user.LastPasswordChange
	time.Sleep(100 * time.Millisecond)
	newPwd := "uaCooGh3pheiShooghah"
	err = dataprovider.UpdateUserPassword(user.Username, newPwd, "", "", "")
	assert.NoError(t, err)
	_, err = dataprovider.CheckUserAndPass(user.Username, newPwd, "", common.ProtocolHTTP)
	assert.NoError(t, err)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t,
user.Filters.RequirePasswordChange)
	assert.NotEqual(t, lastPwdChange, user.LastPasswordChange)
	// check that we don't save group overrides
	assert.Equal(t, 0, user.MaxSessions)
	assert.Equal(t, 0, user.Filters.PasswordStrength)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoginRedirectNext verifies the WebClient login "next" redirect:
// safe relative targets are honored, absolute/external or unsupported
// targets fall back to the files page.
func TestLoginRedirectNext(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	uri := webClientFilesPath + "?path=%2F" //nolint:goconst
	req, err := http.NewRequest(http.MethodGet, uri, nil)
	assert.NoError(t, err)
	req.RequestURI = uri
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	redirectURI := rr.Header().Get("Location")
	assert.Equal(t, webClientLoginPath+"?next="+url.QueryEscape(uri), redirectURI) //nolint:goconst
	// render the login page
	req, err = http.NewRequest(http.MethodGet, redirectURI, nil)
	assert.NoError(t, err)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", redirectURI))
	// now login the user and check the redirect
	loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	assert.NotEmpty(t, loginCookie)
	form := getLoginForm(defaultUsername, defaultPassword, csrfToken)
	req, err = http.NewRequest(http.MethodPost, redirectURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.RequestURI = redirectURI
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, uri, rr.Header().Get("Location"))
	// unsafe URI
	loginCookie, csrfToken, err =
getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	assert.NotEmpty(t, loginCookie)
	form = getLoginForm(defaultUsername, defaultPassword, csrfToken)
	// an external absolute URL must not be followed
	unsafeURI := webClientLoginPath + "?next=" + url.QueryEscape("http://example.net")
	req, err = http.NewRequest(http.MethodPost, unsafeURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.RequestURI = unsafeURI
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
	loginCookie, csrfToken, err = getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	assert.NotEmpty(t, loginCookie)
	form = getLoginForm(defaultUsername, defaultPassword, csrfToken)
	// a non-files internal target is not supported as next
	unsupportedURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientProfilePath)
	req, err = http.NewRequest(http.MethodPost, unsupportedURI, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.RequestURI = unsupportedURI
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestMustChangePasswordRequirement checks that a user flagged with
// RequirePasswordChange is blocked from the REST API and the WebClient
// until the password is changed, via both the API and the WebClient UI.
func TestMustChangePasswordRequirement(t *testing.T) {
	u := getTestUser()
	u.Filters.RequirePasswordChange = true
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.True(t, user.Filters.RequirePasswordChange)

	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err :=
getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)

	// API access is forbidden until the password is changed
	req, err := http.NewRequest(http.MethodGet, userFilesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Password change required. Please set a new password to continue to use your account")

	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdRequired)
	// change pwd
	pwd := make(map[string]string)
	pwd["current_password"] = defaultPassword
	pwd["new_password"] = altAdminPassword
	asJSON, err := json.Marshal(pwd)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// check that the change pwd bool is changed
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, user.Filters.RequirePasswordChange)
	// get a new token
	token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, altAdminPassword)
	assert.NoError(t, err)
	webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, altAdminPassword)
	assert.NoError(t, err)
	// the new token should work
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// check the same as above but changing password from the WebClient UI
	user.Filters.RequirePasswordChange = true
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)

	csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeClientPwdPath, webToken)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("current_password", altAdminPassword)
	form.Set("new_password1", defaultPassword)
	form.Set("new_password2", defaultPassword)
	req, err = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))

	token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	_, err =
httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestTwoFactorRequirements checks that a user required to use 2FA on
// HTTP/FTP is blocked until a TOTP configuration covering those protocols
// is saved, then can authenticate by supplying a passcode.
func TestTwoFactorRequirements(t *testing.T) {
	u := getTestUser()
	u.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP, common.ProtocolFTP}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)

	req, err := http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols")

	req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError2FARequired)

	// a TOTP config that does not cover all the required protocols is rejected
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolHTTP},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following protocols are required")

	userTOTPConfig.Protocols =
[]string{common.ProtocolHTTP, common.ProtocolFTP}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now get new tokens and check that the two factor requirements are now met
	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.Header.Set("X-SFTPGO-OTP", passcode)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	userToken := responseHolder["access_token"].(string)
	assert.NotEmpty(t, userToken)
	err = resp.Body.Close()
	assert.NoError(t, err)

	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userDirsPath), nil)
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestTwoFactorRequirementsGroupLevel is the same check as
// TestTwoFactorRequirements but with the 2FA requirement inherited
// from the user's primary group instead of being set on the user.
func TestTwoFactorRequirementsGroupLevel(t *testing.T) {
	g := getTestGroup()
	g.UserSettings.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP, common.ProtocolFTP}
	group, _, err := httpdtest.AddGroup(g, http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user, _, err := httpdtest.AddUser(u,
http.StatusCreated)
	assert.NoError(t, err)

	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)

	// access is forbidden until the group-level 2FA requirement is satisfied
	req, err := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError2FARequired)

	req, err = http.NewRequest(http.MethodGet, userDirsPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols")

	// a TOTP config missing a required protocol is rejected
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolHTTP},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following protocols are required")

	userTOTPConfig = dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolFTP, common.ProtocolHTTP},
	}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// now get new tokens and check that the two factor requirements are now met
	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.Header.Set("X-SFTPGO-OTP", passcode)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	userToken := responseHolder["access_token"].(string)
	assert.NotEmpty(t, userToken)
	err = resp.Body.Close()
	assert.NoError(t, err)

	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userDirsPath), nil)
	assert.NoError(t, err)
	setBearerForReq(req, userToken)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminMustChangePasswordRequirement checks that an admin flagged with
// RequirePasswordChange can only change their password until they do so,
// via both the REST API and the WebAdmin UI.
func TestAdminMustChangePasswordRequirement(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin.Filters.RequirePasswordChange = true
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)

	token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)

	// regular admin APIs are forbidden until the password is changed
	_, _, err = httpdtest.GetUsers(0, 0, http.StatusForbidden)
	assert.NoError(t, err)
	_, _, err =
httpdtest.GetStatus(http.StatusForbidden)
	assert.NoError(t, err)

	_, err = httpdtest.ChangeAdminPassword(altAdminPassword, defaultTokenAuthPass, http.StatusOK)
	assert.NoError(t, err)

	httpdtest.SetJWTToken("")

	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.RequirePasswordChange)

	// get a new token
	token, _, err = httpdtest.GetToken(altAdminUsername, defaultTokenAuthPass)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)

	_, _, err = httpdtest.GetUsers(0, 0, http.StatusOK)
	assert.NoError(t, err)

	// a self-update cannot re-enable the require flags
	desc := xid.New().String()
	admin.Filters.RequirePasswordChange = true
	admin.Filters.RequireTwoFactor = true
	admin.Description = desc
	_, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	if assert.Error(t, err) {
		assert.ErrorContains(t, err, "require password change mismatch")
	}
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.RequirePasswordChange)
	assert.False(t, admin.Filters.RequireTwoFactor)
	assert.Equal(t, desc, admin.Description)

	httpdtest.SetJWTToken("")

	admin.Filters.RequirePasswordChange = true
	_, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	// test the same for the WebAdmin
	webToken, err := getJWTWebTokenFromTestServer(altAdminUsername, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webUsersPath
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// The change password page should be accessible, we get the CSRF from it.
	csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeAdminPwdPath, webToken)
	assert.NoError(t, err)

	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("current_password", defaultTokenAuthPass)
	form.Set("new_password1", altAdminPassword)
	form.Set("new_password2", altAdminPassword)
	req, err = http.NewRequest(http.MethodPost, webChangeAdminPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)
	assert.Equal(t, webLoginPath, rr.Header().Get("Location"))

	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.RequirePasswordChange)

	webToken, err = getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webUsersPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminTwoFactorRequirements checks that an admin with RequireTwoFactor
// set is blocked from the API and WebAdmin until a TOTP configuration is
// saved, then can authenticate by supplying a passcode.
func TestAdminTwoFactorRequirements(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin.Filters.RequireTwoFactor = true
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)

	token, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, serverStatusPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Two-factor authentication requirements not met")

webToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	// WebAdmin access is denied too, with a localized 2FA-required error
	req, err = http.NewRequest(http.MethodGet, webFoldersPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webFoldersPath
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError2FARequiredGeneric)
	// add TOTP config
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], altAdminUsername)
	assert.NoError(t, err)
	adminTOTPConfig := dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
	}
	asJSON, err := json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)

	// with TOTP enabled, a passcode-authenticated login must now succeed
	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil)
	assert.NoError(t, err)
	req.Header.Set("X-SFTPGO-OTP", passcode)
	req.SetBasicAuth(altAdminUsername, altAdminPassword)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	token = responseHolder["access_token"].(string)
	assert.NotEmpty(t, token)
	err = resp.Body.Close()
	assert.NoError(t, err)

	// the previously forbidden endpoint is now accessible
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, serverStatusPath), nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// try to disable 2FA
	disableReq := map[string]any{
		"enabled": false,
	}
	asJSON, err = json.Marshal(disableReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, fmt.Sprintf("%v%v", httpBaseURL, adminTOTPSavePath), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	// self-disable is rejected while RequireTwoFactor is set
	assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
	bodyResp, err := io.ReadAll(resp.Body)
	assert.NoError(t, err)
	assert.Contains(t, string(bodyResp), "two-factor authentication must be enabled")
	err = resp.Body.Close()
	assert.NoError(t, err)
	// try to disable 2FA using the dedicated API
	req, err = http.NewRequest(http.MethodPut, fmt.Sprintf("%v%v", httpBaseURL, path.Join(adminPath, altAdminUsername, "2fa", "disable")), nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
	bodyResp, err = io.ReadAll(resp.Body)
	assert.NoError(t, err)
	assert.Contains(t, string(bodyResp), "two-factor authentication must be enabled")
	err = resp.Body.Close()
	assert.NoError(t, err)
	// disabling 2FA using another admin should work
	token, err = getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername, "2fa", "disable"), nil)
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// check
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.TOTPConfig.Enabled)

	_, err =
httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoginUserAPITOTP verifies user TOTP handling over the HTTP protocol:
// enabling a config, enforcement of required protocols, passcode login and
// rejection of passcode reuse.
func TestLoginUserAPITOTP(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolHTTP},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now require HTTP and SSH for TOTP
	user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP, common.ProtocolSSH}
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	// two factor auth cannot be disabled
	config := make(map[string]any)
	config["enabled"] = false
	asJSON, err = json.Marshal(config)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "two-factor authentication must be enabled")
	// all the required protocols must be enabled
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following protocols are required")
	// setting all the required protocols should work
	userTOTPConfig.Protocols = []string{common.ProtocolHTTP, common.ProtocolSSH}
	asJSON, err = json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// basic auth without a passcode is now rejected
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.Header.Set("X-SFTPGO-OTP", passcode)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	userToken := responseHolder["access_token"].(string)
	assert.NotEmpty(t, userToken)
	err = resp.Body.Close()
	assert.NoError(t, err)

	// the same passcode cannot be reused for a second login
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.Header.Set("X-SFTPGO-OTP", passcode)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err =
os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginAdminAPITOTP verifies admin TOTP login over the REST API:
// recovery code generation, passcode validation and the disable flow.
func TestLoginAdminAPITOTP(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)

	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], admin.Username)
	assert.NoError(t, err)
	altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	adminTOTPConfig := dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
	}
	asJSON, err := json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// enabling TOTP also generates recovery codes
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	assert.Len(t, admin.Filters.RecoveryCodes, 12)

	// login without a passcode must fail
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(altAdminUsername, altAdminPassword)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	// login with an invalid passcode must fail
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil)
	assert.NoError(t, err)
	req.Header.Set("X-SFTPGO-OTP", "passcode")
	req.SetBasicAuth(altAdminUsername, altAdminPassword)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil)
	assert.NoError(t, err)
	req.Header.Set("X-SFTPGO-OTP", passcode)
	req.SetBasicAuth(altAdminUsername, altAdminPassword)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	adminToken := responseHolder["access_token"].(string)
	assert.NotEmpty(t, adminToken)
	err = resp.Body.Close()
	assert.NoError(t, err)

	req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, versionPath), nil)
	assert.NoError(t, err)
	setBearerForReq(req, adminToken)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// get/set recovery codes
	req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// disable two-factor auth
	saveReq := make(map[string]bool)
	saveReq["enabled"] = false
	asJSON, err = json.Marshal(saveReq)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// disabling clears the recovery codes too
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.TOTPConfig.Enabled)
	assert.Len(t, admin.Filters.RecoveryCodes, 0)
	// get/set recovery codes will not work
	req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)

	req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestHTTPStreamZipError verifies that requesting a streamed zip of a
// missing file makes the server close the connection mid-stream.
func TestHTTPStreamZipError(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultUsername, defaultPassword)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	userToken := responseHolder["access_token"].(string)
	assert.NotEmpty(t, userToken)
	err = resp.Body.Close()
	assert.NoError(t, err)

	filesList := []string{"missing"}
	asJSON, err := json.Marshal(filesList)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, fmt.Sprintf("%v%v", httpBaseURL, userStreamZipPath), bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", userToken))
	resp, err = httpclient.GetHTTPClient().Do(req)
	if !assert.Error(t, err) { // the connection will be closed
		err = resp.Body.Close()
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestBasicAdminHandling exercises admin CRUD and the HideUserPageSections
// bit-flag preferences.
func TestBasicAdminHandling(t *testing.T) {
	// we have one admin by default
	admins, _, err :=
httpdtest.GetAdmins(0, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.GreaterOrEqual(t, len(admins), 1)
	admin := getTestAdmin()
	// the default admin already exists
	_, _, err = httpdtest.AddAdmin(admin, http.StatusConflict)
	assert.NoError(t, err)

	admin.Username = altAdminUsername
	// bits 1/4/8 hide groups, virtual folders and profile respectively
	// (verified by the assertions below)
	admin.Filters.Preferences.HideUserPageSections = 1 + 4 + 8
	admin.Filters.Preferences.DefaultUsersExpiration = 30
	admin, _, err = httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)

	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.Preferences.HideGroups())
	assert.False(t, admin.Filters.Preferences.HideFilesystem())
	assert.True(t, admin.Filters.Preferences.HideVirtualFolders())
	assert.True(t, admin.Filters.Preferences.HideProfile())
	assert.False(t, admin.Filters.Preferences.HideACLs())
	assert.False(t, admin.Filters.Preferences.HideDiskQuotaAndBandwidthLimits())
	assert.False(t, admin.Filters.Preferences.HideAdvancedSettings())

	admin.AdditionalInfo = "test info"
	// bits 16/32/64 hide ACLs, quota/bandwidth limits and advanced settings
	admin.Filters.Preferences.HideUserPageSections = 16 + 32 + 64
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, "test info", admin.AdditionalInfo)
	assert.False(t, admin.Filters.Preferences.HideGroups())
	assert.False(t, admin.Filters.Preferences.HideFilesystem())
	assert.False(t, admin.Filters.Preferences.HideVirtualFolders())
	assert.False(t, admin.Filters.Preferences.HideProfile())
	assert.True(t, admin.Filters.Preferences.HideACLs())
	assert.True(t, admin.Filters.Preferences.HideDiskQuotaAndBandwidthLimits())
	assert.True(t, admin.Filters.Preferences.HideAdvancedSettings())

	// pagination: limit 1, offset 0/1
	admins, _, err = httpdtest.GetAdmins(1, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, admins, 1)
	assert.NotEqual(t, admin.Username, admins[0].Username)

	admins, _, err = httpdtest.GetAdmins(1, 1, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, admins, 1)
	assert.Equal(t, admin.Username, admins[0].Username)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusNotFound)
	assert.NoError(t, err)

	admin, _, err = httpdtest.GetAdminByUsername(admin.Username+"123", http.StatusNotFound)
	assert.NoError(t, err)

	// the default admin cannot be removed
	admin.Username = defaultTokenAuthUser
	_, err = httpdtest.RemoveAdmin(admin, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAdminGroups verifies admin/group mappings: creation, cascade update
// when groups are removed, and rejection of mappings to missing groups.
func TestAdminGroups(t *testing.T) {
	group1 := getTestGroup()
	group1.Name += "_1"
	group1, _, err := httpdtest.AddGroup(group1, http.StatusCreated)
	assert.NoError(t, err)
	group2 := getTestGroup()
	group2.Name += "_2"
	group2, _, err = httpdtest.AddGroup(group2, http.StatusCreated)
	assert.NoError(t, err)
	group3 := getTestGroup()
	group3.Name += "_3"
	group3, _, err = httpdtest.AddGroup(group3, http.StatusCreated)
	assert.NoError(t, err)

	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Groups = []dataprovider.AdminGroupMapping{
		{
			Name: group1.Name,
			Options: dataprovider.AdminGroupMappingOptions{
				AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary,
			},
		},
		{
			Name: group2.Name,
			Options: dataprovider.AdminGroupMappingOptions{
				AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary,
			},
		},
		{
			Name: group3.Name,
			Options: dataprovider.AdminGroupMappingOptions{
				AddToUsersAs: dataprovider.GroupAddToUsersAsMembership,
			},
		},
	}
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	assert.Len(t, admin.Groups, 3)

	// each group now references the admin back
	groups, _, err := httpdtest.GetGroups(0, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, groups, 3)
	for _, g := range groups {
		if assert.Len(t, g.Admins, 1) {
			assert.Equal(t, admin.Username, g.Admins[0])
		}
	}

	admin, _, err = httpdtest.UpdateAdmin(a, http.StatusOK)
	assert.NoError(t, err)

	// removing groups must prune the admin's mappings
	_, err = httpdtest.RemoveGroup(group1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group2, http.StatusOK)
	assert.NoError(t, err)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, admin.Groups, 1)

	// try to add a missing group
	admin.Groups = []dataprovider.AdminGroupMapping{
		{
			Name: group1.Name,
			Options: dataprovider.AdminGroupMappingOptions{
				AddToUsersAs: dataprovider.GroupAddToUsersAsPrimary,
			},
		},
		{
			Name: group2.Name,
			Options: dataprovider.AdminGroupMappingOptions{
				AddToUsersAs: dataprovider.GroupAddToUsersAsSecondary,
			},
		},
	}

	group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group3.Admins, 1)

	// a failed update must not alter the existing group3 mapping
	_, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.Error(t, err)
	group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group3.Admins, 1)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	group3, _, err = httpdtest.GetGroupByName(group3.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group3.Admins, 0)

	_, err = httpdtest.RemoveGroup(group3, http.StatusOK)
	assert.NoError(t, err)
}

// TestChangeAdminPassword covers validation of the change-password API and
// restores the default admin password afterwards via the provider.
func TestChangeAdminPassword(t *testing.T) {
	_, err := httpdtest.ChangeAdminPassword("wrong", defaultTokenAuthPass, http.StatusBadRequest)
	assert.NoError(t, err)
	// new password equal to the current one is rejected
	_, err = httpdtest.ChangeAdminPassword(defaultTokenAuthPass, defaultTokenAuthPass, http.StatusBadRequest)
	assert.NoError(t, err)
	_, err = httpdtest.ChangeAdminPassword(defaultTokenAuthPass, defaultTokenAuthPass+"1", http.StatusOK)
	assert.NoError(t, err)
	// the old password is no longer valid
	_, err = httpdtest.ChangeAdminPassword(defaultTokenAuthPass+"1", defaultTokenAuthPass, http.StatusUnauthorized)
	assert.NoError(t, err)
	admin, err := dataprovider.AdminExists(defaultTokenAuthUser)
	assert.NoError(t, err)
	admin.Password = defaultTokenAuthPass
	err = dataprovider.UpdateAdmin(&admin, "",
"", "")
	assert.NoError(t, err)
}

// TestPasswordValidations verifies that the provider rejects passwords
// below the configured minimum entropy for both admins and users.
// It reinitializes the data provider, so it is skipped with the memory one.
func TestPasswordValidations(t *testing.T) {
	if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName {
		t.Skip("this test is not supported with the memory provider")
	}
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	providerConf := config.GetProviderConf()
	assert.NoError(t, err)
	providerConf.PasswordValidation.Admins.MinEntropy = 50
	providerConf.PasswordValidation.Users.MinEntropy = 70
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword

	_, resp, err := httpdtest.AddAdmin(a, http.StatusBadRequest)
	assert.NoError(t, err, string(resp))
	assert.Contains(t, string(resp), "insecure password")

	_, resp, err = httpdtest.AddUser(getTestUser(), http.StatusBadRequest)
	assert.NoError(t, err, string(resp))
	assert.Contains(t, string(resp), "insecure password")

	// restore the default provider configuration
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestAdminPasswordHashing verifies that switching the hashing algo to
// argon2id affects newly stored passwords ($argon2id$ prefix) while
// existing bcrypt hashes ($2a$) keep working.
func TestAdminPasswordHashing(t *testing.T) {
	if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName {
		t.Skip("this test is not supported with the memory provider")
	}
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	providerConf := config.GetProviderConf()
	assert.NoError(t, err)
	providerConf.PasswordHashing.Algo = dataprovider.HashingAlgoArgon2ID
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)

	currentAdmin, err := dataprovider.AdminExists(defaultTokenAuthUser)
	assert.NoError(t, err)
	assert.True(t, strings.HasPrefix(currentAdmin.Password, "$2a$"))

	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword

	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)

	newAdmin, err := dataprovider.AdminExists(altAdminUsername)
	assert.NoError(t, err)
	assert.True(t, strings.HasPrefix(newAdmin.Password, "$argon2id$"))

	// both the new argon2id admin and the default bcrypt one can login
	token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)
	_, _, err = httpdtest.GetStatus(http.StatusOK)
	assert.NoError(t, err)

	httpdtest.SetJWTToken("")
	_, _, err = httpdtest.GetStatus(http.StatusOK)
	assert.NoError(t, err)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)

	// restore the default provider configuration
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	assert.NoError(t, err)
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestDefaultUsersExpiration verifies that users created by an admin with
// DefaultUsersExpiration set get an expiration date assigned.
func TestDefaultUsersExpiration(t *testing.T) {
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword
	a.Filters.Preferences.DefaultUsersExpiration = 30
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)

	token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)

	// NOTE(review): the user is created (201) but httpdtest reports a
	// mismatch because the stored user gains an expiration date
	_, _, err = httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.Error(t, err)

	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, user.ExpirationDate, int64(0))

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)

	// an explicit expiration date is preserved as-is
	u := getTestUser()
	u.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute))

	_, _, err = httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	user, _, err =
httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, u.ExpirationDate, user.ExpirationDate)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)

	httpdtest.SetJWTToken("")
	_, _, err = httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	// render the user template page
	webToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)

	req, err := http.NewRequest(http.MethodGet, webTemplateUser, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	req, err = http.NewRequest(http.MethodGet, webTemplateUser+fmt.Sprintf("?from=%s", user.Username), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)

	// the token of the removed admin is no longer usable
	httpdtest.SetJWTToken(token)
	_, _, err = httpdtest.AddUser(u, http.StatusNotFound)
	assert.NoError(t, err)

	httpdtest.SetJWTToken("")
}

// TestAdminInvalidCredentials verifies that wrong password/username both
// yield 401 with the provider's invalid-credentials error in the JSON body.
func TestAdminInvalidCredentials(t *testing.T) {
	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil)
	assert.NoError(t, err)
	req.SetBasicAuth(defaultTokenAuthUser, defaultTokenAuthPass)
	resp, err := httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// wrong password
	req.SetBasicAuth(defaultTokenAuthUser, "wrong pwd")
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	responseHolder := make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	err = resp.Body.Close()
	assert.NoError(t, err)
	assert.Equal(t, dataprovider.ErrInvalidCredentials.Error(), responseHolder["error"].(string))
	// wrong username
	req.SetBasicAuth("wrong username", defaultTokenAuthPass)
	resp, err = httpclient.GetHTTPClient().Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
	responseHolder = make(map[string]any)
	err = render.DecodeJSON(resp.Body, &responseHolder)
	assert.NoError(t, err)
	err = resp.Body.Close()
	assert.NoError(t, err)
	assert.Equal(t, dataprovider.ErrInvalidCredentials.Error(), responseHolder["error"].(string))
}

// TestAdminLastLogin verifies that LastLogin starts at zero and is updated
// after a successful token request.
func TestAdminLastLogin(t *testing.T) {
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword

	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), admin.LastLogin)

	_, _, err = httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)

	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, admin.LastLogin, int64(0))

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminAllowList verifies that setting Filters.AllowList to a network
// not matching the client address blocks further logins.
func TestAdminAllowList(t *testing.T) {
	a := getTestAdmin()
	a.Username = altAdminUsername
	a.Password = altAdminPassword

	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)

	token, _, err := httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	httpdtest.SetJWTToken(token)
	_, _, err = httpdtest.GetStatus(http.StatusOK)
	assert.NoError(t, err)

	httpdtest.SetJWTToken("")

	admin.Password = altAdminPassword
	admin.Filters.AllowList = []string{"10.6.6.0/32"}
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)

	// the test client does not connect from 10.6.6.0, so login fails
	_, _, err = httpdtest.GetToken(altAdminUsername, altAdminPassword)
	assert.EqualError(t, err, "wrong status code: got 401 want 200")

	_, err =
httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserStatus verifies that only status values 0 and 1 are accepted on
// user create/update.
func TestUserStatus(t *testing.T) {
	u := getTestUser()
	u.Status = 3
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Status = 0
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user.Status = 2
	_, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "")
	assert.NoError(t, err)
	user.Status = 1
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUidGidLimits verifies that UID/GID up to math.MaxInt32 round-trip.
func TestUidGidLimits(t *testing.T) {
	u := getTestUser()
	u.UID = math.MaxInt32
	u.GID = math.MaxInt32
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.Equal(t, math.MaxInt32, user.GetUID())
	assert.Equal(t, math.MaxInt32, user.GetGID())

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestAddUserNoCredentials verifies a user without password and public keys
// can be created but cannot login with an empty password.
func TestAddUserNoCredentials(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	u.PublicKeys = []string{}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// this user cannot login with an empty password but it still can use an SSH cert
	_, err = getJWTAPITokenFromTestServer(defaultTokenAuthUser, "")
	assert.Error(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestAddUserNoUsername verifies that an empty username is rejected.
func TestAddUserNoUsername(t *testing.T) {
	u := getTestUser()
	u.Username = ""
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserNoHomeDir verifies that an empty home directory is rejected.
func TestAddUserNoHomeDir(t *testing.T) {
	u := getTestUser()
	u.HomeDir = ""
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserInvalidHomeDir verifies that a relative home dir is rejected.
func TestAddUserInvalidHomeDir(t *testing.T) {
	u := getTestUser()
	u.HomeDir = "relative_path" //nolint:goconst
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserNoPerms verifies that users without permissions are rejected.
func TestAddUserNoPerms(t *testing.T) {
	u := getTestUser()
	u.Permissions = make(map[string][]string)
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Permissions["/"] = []string{}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserInvalidEmail verifies email format validation.
func TestAddUserInvalidEmail(t *testing.T) {
	u := getTestUser()
	u.Email = "invalid_email"
	_, body, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	assert.Contains(t, string(body), "Validation error: email")
}

// TestAddUserInvalidPerms verifies validation of permission values and paths.
func TestAddUserInvalidPerms(t *testing.T) {
	u := getTestUser()
	u.Permissions["/"] = []string{"invalidPerm"}
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// permissions for root dir are mandatory
	u.Permissions["/"] = []string{}
	u.Permissions["/somedir"] = []string{dataprovider.PermAny}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// non-clean paths are rejected
	u.Permissions["/"] = []string{dataprovider.PermAny}
	u.Permissions["/subdir/.."] = []string{dataprovider.PermAny}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserInvalidFilters verifies validation of user filters: IP lists,
// login methods, protocols and file patterns.
func TestAddUserInvalidFilters(t *testing.T) {
	u := getTestUser()
	u.Filters.AllowedIP = []string{"192.168.1.0/24", "192.168.2.0"}
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.AllowedIP = []string{}
	u.Filters.DeniedIP = []string{"192.168.3.0/16", "invalid"}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedIP = []string{}
	u.Filters.DeniedLoginMethods = []string{"invalid"}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// denying every login method is not allowed
	u.Filters.DeniedLoginMethods = dataprovider.ValidLoginMethods
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificateAndPwd}
u.Filters.DeniedProtocols = dataprovider.ValidProtocols
	// denying every protocol is not allowed
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedProtocols = []string{common.ProtocolFTP}
	// file pattern paths must be absolute
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "relative",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// a filter with neither allowed nor denied patterns is invalid
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// duplicate paths are rejected
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.zip"},
			DeniedPatterns:  []string{},
		},
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.rar"},
			DeniedPatterns:  []string{"*.jpg"},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "relative",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/",
			AllowedPatterns: []string{},
			DeniedPatterns:  []string{},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.zip"},
		},
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.rar"},
			DeniedPatterns:  []string{"*.jpg"},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// an invalid glob pattern is rejected
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"a\\"},
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	// an out-of-range deny policy is rejected
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:            "/subdir",
			AllowedPatterns: []string{"*.*"},
			DenyPolicy:      100,
		},
	}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedProtocols = []string{"invalid"}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedProtocols = dataprovider.ValidProtocols
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.DeniedProtocols = nil
	u.Filters.TLSUsername = "not a supported attribute"
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.Filters.TLSUsername = ""
	u.Filters.WebClient = []string{"not a valid web client options"}
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
}

// TestAddUserInvalidFsConfig verifies validation of per-provider filesystem
// configurations (S3, GCS, ...): missing buckets, redacted secrets, key
// prefixes and part size/concurrency limits.
func TestAddUserInvalidFsConfig(t *testing.T) {
	u := getTestUser()
	u.FsConfig.Provider = sdk.S3FilesystemProvider
	u.FsConfig.S3Config.Bucket = ""
	_, _, err := httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.Bucket = "testbucket"
	u.FsConfig.S3Config.Region = "eu-west-1"     //nolint:goconst
	u.FsConfig.S3Config.AccessKey = "access-key" //nolint:goconst
	// a redacted secret cannot be saved
	u.FsConfig.S3Config.AccessSecret = kms.NewSecret(sdkkms.SecretStatusRedacted, "access-secret", "", "")
	u.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000/path?a=b"
	u.FsConfig.S3Config.StorageClass = "Standard" //nolint:goconst
	// key prefixes must not start with a slash
	u.FsConfig.S3Config.KeyPrefix = "/adir/subdir/"
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.AccessSecret.SetStatus(sdkkms.SecretStatusPlain)
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.KeyPrefix = ""
	// upload part size must be within the allowed range
	u.FsConfig.S3Config.UploadPartSize = 3
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.UploadPartSize = 5001
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.UploadPartSize = 0
	u.FsConfig.S3Config.UploadConcurrency = -1
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.S3Config.UploadConcurrency = 0
	u.FsConfig.S3Config.DownloadPartSize = -1
	_, resp, err := httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "download_part_size cannot be")
	}
	u.FsConfig.S3Config.DownloadPartSize = 5001
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "download_part_size cannot be")
	}
	u.FsConfig.S3Config.DownloadPartSize = 0
	u.FsConfig.S3Config.DownloadConcurrency = 100
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid download concurrency")
	}
	u.FsConfig.S3Config.DownloadConcurrency = -1
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "invalid download concurrency")
	}
	u.FsConfig.S3Config.DownloadConcurrency = 0
	u.FsConfig.S3Config.Endpoint = ""
	u.FsConfig.S3Config.Region = ""
	_, resp, err = httpdtest.AddUser(u, http.StatusBadRequest)
	if assert.NoError(t, err) {
		assert.Contains(t, string(resp), "region cannot be empty")
	}
	u = getTestUser()
	u.FsConfig.Provider = sdk.GCSFilesystemProvider
	u.FsConfig.GCSConfig.Bucket = ""
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.GCSConfig.Bucket = "abucket"
	u.FsConfig.GCSConfig.StorageClass = "Standard"
	u.FsConfig.GCSConfig.KeyPrefix = "/somedir/subdir/"
	u.FsConfig.GCSConfig.Credentials = kms.NewSecret(sdkkms.SecretStatusRedacted, "test", "", "") //nolint:goconst
	_, _, err = httpdtest.AddUser(u, http.StatusBadRequest)
	assert.NoError(t, err)
	u.FsConfig.GCSConfig.Credentials.SetStatus(sdkkms.SecretStatusPlain)
	_, _, err =
httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.GCSConfig.KeyPrefix = "somedir/subdir/" //nolint:goconst + u.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret() + u.FsConfig.GCSConfig.AutomaticCredentials = 0 + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.GCSConfig.Credentials = kms.NewSecret(sdkkms.SecretStatusSecretBox, "invalid", "", "") + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + + u = getTestUser() + u.FsConfig.Provider = sdk.AzureBlobFilesystemProvider + u.FsConfig.AzBlobConfig.SASURL = kms.NewPlainSecret("http://foo\x7f.com/") + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.AzBlobConfig.SASURL = kms.NewSecret(sdkkms.SecretStatusRedacted, "key", "", "") + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.AzBlobConfig.SASURL = kms.NewEmptySecret() + u.FsConfig.AzBlobConfig.AccountName = "name" + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.AzBlobConfig.AccountKey = kms.NewSecret(sdkkms.SecretStatusRedacted, "key", "", "") + u.FsConfig.AzBlobConfig.KeyPrefix = "/amedir/subdir/" + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.AzBlobConfig.AccountKey.SetStatus(sdkkms.SecretStatusPlain) + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.AzBlobConfig.KeyPrefix = "amedir/subdir/" + u.FsConfig.AzBlobConfig.UploadPartSize = -1 + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.AzBlobConfig.UploadPartSize = 101 + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + + u = getTestUser() + u.FsConfig.Provider = sdk.CryptedFilesystemProvider + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + 
u.FsConfig.CryptConfig.Passphrase = kms.NewSecret(sdkkms.SecretStatusRedacted, "akey", "", "") + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u = getTestUser() + u.FsConfig.Provider = sdk.SFTPFilesystemProvider + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.SFTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusRedacted, "randompkey", "", "") + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.SFTPConfig.Password = kms.NewEmptySecret() + u.FsConfig.SFTPConfig.PrivateKey = kms.NewSecret(sdkkms.SecretStatusRedacted, "keyforpkey", "", "") + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret("pk") + u.FsConfig.SFTPConfig.Endpoint = "127.1.1.1:22" + u.FsConfig.SFTPConfig.Username = defaultUsername + u.FsConfig.SFTPConfig.BufferSize = -1 + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "invalid buffer_size") + } + u.FsConfig.SFTPConfig.BufferSize = 1000 + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "invalid buffer_size") + } + + u = getTestUser() + u.FsConfig.Provider = sdk.HTTPFilesystemProvider + u.FsConfig.HTTPConfig.Endpoint = "http://foo\x7f.com/" + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "invalid endpoint") + } + u.FsConfig.HTTPConfig.Endpoint = "http://127.0.0.1:9999/api/v1" + u.FsConfig.HTTPConfig.Password = kms.NewSecret(sdkkms.SecretStatusSecretBox, "", "", "") + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "invalid encrypted password") + } + u.FsConfig.HTTPConfig.Password = nil + u.FsConfig.HTTPConfig.APIKey = 
kms.NewSecret(sdkkms.SecretStatusRedacted, redactedSecret, "", "") + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "cannot save a user with a redacted secret") + } + u.FsConfig.HTTPConfig.APIKey = nil + u.FsConfig.HTTPConfig.Endpoint = "/api/v1" + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "invalid endpoint schema") + } + u.FsConfig.HTTPConfig.Endpoint = "http://unix?api_prefix=v1" + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "invalid unix domain socket path") + } + u.FsConfig.HTTPConfig.Endpoint = "http://unix?socket_path=test.sock" + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + if assert.NoError(t, err) { + assert.Contains(t, string(resp), "invalid unix domain socket path") + } +} + +func TestUserRedactedPassword(t *testing.T) { + u := getTestUser() + u.FsConfig.Provider = sdk.S3FilesystemProvider + u.FsConfig.S3Config.Bucket = "b" + u.FsConfig.S3Config.Region = "eu-west-1" + u.FsConfig.S3Config.AccessKey = "access-key" + u.FsConfig.S3Config.RoleARN = "myRoleARN" + u.FsConfig.S3Config.AccessSecret = kms.NewSecret(sdkkms.SecretStatusRedacted, "access-secret", "", "") + u.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000/path?k=m" + u.FsConfig.S3Config.StorageClass = "Standard" + u.FsConfig.S3Config.ACL = "bucket-owner-full-control" + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "cannot save a user with a redacted secret") + err = dataprovider.AddUser(&u, "", "", "") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "cannot save a user with a redacted secret") + } + u.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("secret") + u.FsConfig.S3Config.SSECustomerKey = kms.NewSecret(sdkkms.SecretStatusRedacted, 
"mysecretkey", "", "") + _, resp, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err, string(resp)) + assert.Contains(t, string(resp), "cannot save a user with a redacted secret") + + u.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("key") + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + folderName := "folderName" + vfolder := vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), "crypted"), + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewSecret(sdkkms.SecretStatusRedacted, "crypted-secret", "", ""), + }, + }, + }, + VirtualPath: "/avpath", + } + + user.Password = defaultPassword + user.VirtualFolders = append(user.VirtualFolders, vfolder) + err = dataprovider.UpdateUser(&user, "", "", "") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "cannot save a user with a redacted secret") + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserType(t *testing.T) { + u := getTestUser() + u.Filters.UserType = string(sdk.UserTypeLDAP) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + assert.Equal(t, string(sdk.UserTypeLDAP), user.Filters.UserType) + user.Filters.UserType = string(sdk.UserTypeOS) + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Equal(t, string(sdk.UserTypeOS), user.Filters.UserType) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestRetentionAPI(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + t.Cleanup(func() { + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + }) + + checks, _, err := 
httpdtest.GetRetentionChecks(http.StatusOK) + assert.NoError(t, err) + assert.Len(t, checks, 0) + + localFilePath := filepath.Join(user.HomeDir, "testdir", "testfile") + err = os.MkdirAll(filepath.Dir(localFilePath), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(localFilePath, []byte("test data"), os.ModePerm) + assert.NoError(t, err) + + folderRetention := []dataprovider.FolderRetention{ + { + Path: "/", + Retention: 24, + DeleteEmptyDirs: true, + }, + } + + check := common.RetentionCheck{ + Folders: folderRetention, + } + c := common.RetentionChecks.Add(check, &user) + require.NotNil(t, c) + + err = c.Start() + require.NoError(t, err) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + assert.FileExists(t, localFilePath) + + err = os.Chtimes(localFilePath, time.Now().Add(-48*time.Hour), time.Now().Add(-48*time.Hour)) + assert.NoError(t, err) + + err = c.Start() + require.NoError(t, err) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + assert.NoFileExists(t, localFilePath) + assert.NoDirExists(t, filepath.Dir(localFilePath)) + + c = common.RetentionChecks.Add(check, &user) + assert.NotNil(t, c) + + assert.Nil(t, common.RetentionChecks.Add(check, &user)) // a check for this user is already in progress + + checks, _, err = httpdtest.GetRetentionChecks(http.StatusOK) + assert.NoError(t, err) + assert.Len(t, checks, 1) + + err = c.Start() + assert.NoError(t, err) + + assert.Eventually(t, func() bool { + return len(common.RetentionChecks.Get("")) == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + checks, _, err = httpdtest.GetRetentionChecks(http.StatusOK) + assert.NoError(t, err) + assert.Len(t, checks, 0) +} + +func TestAddUserInvalidVirtualFolders(t *testing.T) { + u := getTestUser() + folderName := "fname" + u.VirtualFolders = append(u.VirtualFolders, 
vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), + Name: folderName, + }, + VirtualPath: "/vdir", + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), + Name: folderName + "1", + }, + VirtualPath: "/vdir", // invalid, already defined + }) + _, _, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.VirtualFolders = nil + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), + Name: folderName, + }, + VirtualPath: "/vdir1", + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), + Name: folderName, // invalid, unique constraint (user.id, folder.id) violated + }, + VirtualPath: "/vdir2", + }) + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.VirtualFolders = nil + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), + Name: folderName + "1", + }, + VirtualPath: "/vdir1/", + QuotaSize: -1, + QuotaFiles: 1, // invalid, we cannot have -1 and > 0 + }) + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.VirtualFolders = nil + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), + Name: folderName + "1", + }, + VirtualPath: "/vdir1/", + QuotaSize: 1, + QuotaFiles: -1, + }) + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.VirtualFolders = nil + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + 
BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), + Name: folderName + "1", + }, + VirtualPath: "/vdir1/", + QuotaSize: -2, // invalid + QuotaFiles: 0, + }) + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.VirtualFolders = nil + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped_dir1"), + Name: folderName + "1", + }, + VirtualPath: "/vdir1/", + QuotaSize: 0, + QuotaFiles: -2, // invalid + }) + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.VirtualFolders = nil + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped_dir"), + }, + VirtualPath: "/vdir1", + }) + // folder name is mandatory + _, _, err = httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) +} + +func TestUserPublicKey(t *testing.T) { + u := getTestUser() + u.Password = "" + invalidPubKey := "invalid" + u.PublicKeys = []string{invalidPubKey} + _, _, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + u.PublicKeys = []string{testPubKey} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + dbUser, err := dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + assert.Empty(t, dbUser.Password) + assert.False(t, dbUser.IsPasswordHashed()) + + user.PublicKeys = []string{testPubKey, invalidPubKey} + _, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "") + assert.NoError(t, err) + user.PublicKeys = []string{testPubKey, testPubKey, testPubKey} + user.Password = defaultPassword + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + dbUser, err = dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + assert.NotEmpty(t, dbUser.Password) + 
assert.True(t, dbUser.IsPasswordHashed()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + // DSA keys are not accepted + u = getTestUser() + u.Password = "" + u.PublicKeys = []string{"ssh-dss AAAAB3NzaC1kc3MAAACBAK+BKLZs1Vd0cWYOquKfp++0ml9hkzB7UDRozT3nhRcyHcuwASsXiVTqsg96oGjBcUUy076CXlsfJEXE2P0dF6tt1wvABPMwKpOn+kIrfJ0j93X2c2KIZNlD4YuNUJjLHu1DvgQHw8NMps6l5D0M5NFCRdD3NYhI5zFVJJ4CzikrAAAAFQCRBagw7gEbs0gd8So7OLMcSVzs/wAAAIBjuo7U9q8npchQ3otgCvj0xIwsQ+Fi9bH0SBceqbCcVzFYY6JXSQ0XmwHs+0AuvRCPIGaBdfcm+w+9YOxREtdEVjcmkYlfJpTaVljjWcWFWTQddbiamZhQ/xLU9CNLK4oYLwIGLZjCcG7nRDdLtLQdBFuzP/faEi3TD2BK114QmAAAAIEAj1n34pH2WKwbSZhzmz/OG0VzqJICFWboiM44LZl2AqcRBvEEycdHlGe2IKaj5lEtLgBKJt9NSFhBIzWh7gcEzSMlkiDecdYSFlDc4snmTiXaoiIehV59nTY6gc8GLWCzuem+WdHxvJ4yOSWF9k+a+Y+/v/35shNLkfokViOlN7k="} + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "DSA key format is insecure and it is not allowed") +} + +func TestUpdateUserEmptyPassword(t *testing.T) { + u := getTestUser() + u.PublicKeys = []string{testPubKey} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + // the password is not empty + dbUser, err := dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + assert.NotEmpty(t, dbUser.Password) + assert.True(t, dbUser.IsPasswordHashed()) + // now update the user and set an empty password + data, err := json.Marshal(dbUser) + assert.NoError(t, err) + var customUser map[string]any + err = json.Unmarshal(data, &customUser) + assert.NoError(t, err) + customUser["password"] = "" + asJSON, err := json.Marshal(customUser) + assert.NoError(t, err) + userNoPwd, _, err := httpdtest.UpdateUserWithJSON(user, http.StatusOK, "", asJSON) + assert.NoError(t, err) + assert.Equal(t, user.Password, userNoPwd.Password) // the password is hidden + // check the password within the data provider + dbUser, err = dataprovider.UserExists(u.Username, "") + assert.NoError(t, 
err) + assert.Empty(t, dbUser.Password) + assert.False(t, dbUser.IsPasswordHashed()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUpdateUserNoPassword(t *testing.T) { + u := getTestUser() + u.PublicKeys = []string{testPubKey} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + // the password is not empty + dbUser, err := dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + assert.NotEmpty(t, dbUser.Password) + assert.True(t, dbUser.IsPasswordHashed()) + // now update the user and remove the password field, old password should be preserved + user.Password = "" // password has the omitempty tag + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + // the password is preserved + dbUser, err = dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + assert.NotEmpty(t, dbUser.Password) + assert.True(t, dbUser.IsPasswordHashed()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUpdateUser(t *testing.T) { + u := getTestUser() + u.UsedQuotaFiles = 1 + u.UsedQuotaSize = 2 + u.Filters.TLSUsername = sdk.TLSUsernameCN + u.Filters.Hooks.CheckPasswordDisabled = true + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + user.HomeDir = filepath.Join(homeBasePath, "testmod") + user.UID = 33 + user.GID = 101 + user.MaxSessions = 10 + user.QuotaSize = 4096 + user.QuotaFiles = 2 + user.Permissions["/"] = []string{dataprovider.PermCreateDirs, dataprovider.PermDelete, dataprovider.PermDownload} + user.Permissions["/subdir"] = []string{dataprovider.PermListItems, dataprovider.PermUpload} + user.Filters.AllowedIP = []string{"192.168.1.0/24", "192.168.2.0/24"} + user.Filters.DeniedIP = []string{"192.168.3.0/24", "192.168.4.0/24"} + user.Filters.DeniedLoginMethods = 
[]string{dataprovider.LoginMethodPassword} + user.Filters.DeniedProtocols = []string{common.ProtocolWebDAV} + user.Filters.TLSUsername = sdk.TLSUsernameNone + user.Filters.Hooks.ExternalAuthDisabled = true + user.Filters.Hooks.PreLoginDisabled = true + user.Filters.Hooks.CheckPasswordDisabled = false + user.Filters.DisableFsChecks = true + user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ + Path: "/subdir", + AllowedPatterns: []string{"*.zip", "*.rar"}, + DeniedPatterns: []string{"*.jpg", "*.png"}, + DenyPolicy: sdk.DenyPolicyHide, + }) + user.Filters.MaxUploadFileSize = 4096 + user.UploadBandwidth = 1024 + user.DownloadBandwidth = 512 + user.VirtualFolders = nil + mappedPath1 := filepath.Join(os.TempDir(), "mapped_dir1") + mappedPath2 := filepath.Join(os.TempDir(), "mapped_dir2") + folderName1 := filepath.Base(mappedPath1) + folderName2 := filepath.Base(mappedPath2) + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: "/vdir1", + }) + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: "/vdir12/subdir", + QuotaSize: 123, + QuotaFiles: 2, + }) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + _, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "invalid") + assert.NoError(t, err) + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "0") + assert.NoError(t, err) + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "1") + assert.NoError(t, err) 
+ user.Permissions["/subdir"] = []string{} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Len(t, user.Permissions["/subdir"], 0) + assert.Len(t, user.VirtualFolders, 2) + for _, folder := range user.VirtualFolders { + assert.Greater(t, folder.ID, int64(0)) + if folder.VirtualPath == "/vdir12/subdir" { + assert.Equal(t, int64(123), folder.QuotaSize) + assert.Equal(t, 2, folder.QuotaFiles) + } + } + folder, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 1) + assert.Contains(t, folder.Users, user.Username) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + // removing the user must remove folder mapping + folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 0) + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + folder, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 0) + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) +} + +func TestUpdateUserTransferQuotaUsage(t *testing.T) { + u := getTestUser() + usedDownloadDataTransfer := int64(2 * 1024 * 1024) + usedUploadDataTransfer := int64(1024 * 1024) + u.UsedDownloadDataTransfer = usedDownloadDataTransfer + u.UsedUploadDataTransfer = usedUploadDataTransfer + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), user.UsedUploadDataTransfer) + assert.Equal(t, int64(0), user.UsedDownloadDataTransfer) + _, err = httpdtest.UpdateTransferQuotaUsage(u, "invalid_mode", http.StatusBadRequest) + assert.NoError(t, err) + _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusOK) + assert.NoError(t, err) + user, _, err = 
httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, usedUploadDataTransfer, user.UsedUploadDataTransfer) + assert.Equal(t, usedDownloadDataTransfer, user.UsedDownloadDataTransfer) + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusBadRequest) + assert.NoError(t, err, "user has no transfer quota restrictions add mode should fail") + user.TotalDataTransfer = 100 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2*usedUploadDataTransfer, user.UsedUploadDataTransfer) + assert.Equal(t, 2*usedDownloadDataTransfer, user.UsedDownloadDataTransfer) + u.UsedDownloadDataTransfer = -1 + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusBadRequest) + assert.NoError(t, err) + u.UsedDownloadDataTransfer = usedDownloadDataTransfer + u.Username += "1" + _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusNotFound) + assert.NoError(t, err) + u.Username = defaultUsername + _, err = httpdtest.UpdateTransferQuotaUsage(u, "", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, usedUploadDataTransfer, user.UsedUploadDataTransfer) + assert.Equal(t, usedDownloadDataTransfer, user.UsedDownloadDataTransfer) + u.UsedDownloadDataTransfer = 0 + u.UsedUploadDataTransfer = 1 + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, usedUploadDataTransfer+1, user.UsedUploadDataTransfer) + assert.Equal(t, usedDownloadDataTransfer, user.UsedDownloadDataTransfer) + u.UsedDownloadDataTransfer = 1 + 
u.UsedUploadDataTransfer = 0 + _, err = httpdtest.UpdateTransferQuotaUsage(u, "add", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, usedUploadDataTransfer+1, user.UsedUploadDataTransfer) + assert.Equal(t, usedDownloadDataTransfer+1, user.UsedDownloadDataTransfer) + + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "transfer-usage"), + bytes.NewBuffer([]byte(`not a json`))) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUpdateUserQuotaUsage(t *testing.T) { + u := getTestUser() + usedQuotaFiles := 1 + usedQuotaSize := int64(65535) + u.UsedQuotaFiles = usedQuotaFiles + u.UsedQuotaSize = usedQuotaSize + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + _, err = httpdtest.UpdateQuotaUsage(u, "invalid_mode", http.StatusBadRequest) + assert.NoError(t, err) + _, err = httpdtest.UpdateQuotaUsage(u, "", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, usedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, usedQuotaSize, user.UsedQuotaSize) + _, err = httpdtest.UpdateQuotaUsage(u, "add", http.StatusBadRequest) + assert.NoError(t, err, "user has no quota restrictions add mode should fail") + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 
usedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, usedQuotaSize, user.UsedQuotaSize) + user.QuotaFiles = 100 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + _, err = httpdtest.UpdateQuotaUsage(u, "add", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2*usedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, 2*usedQuotaSize, user.UsedQuotaSize) + u.UsedQuotaFiles = -1 + _, err = httpdtest.UpdateQuotaUsage(u, "", http.StatusBadRequest) + assert.NoError(t, err) + u.UsedQuotaFiles = usedQuotaFiles + u.Username = u.Username + "1" + _, err = httpdtest.UpdateQuotaUsage(u, "", http.StatusNotFound) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserFolderMapping(t *testing.T) { + mappedPath1 := filepath.Join(os.TempDir(), "mapped_dir1") + mappedPath2 := filepath.Join(os.TempDir(), "mapped_dir2") + folderName1 := filepath.Base(mappedPath1) + folderName2 := filepath.Base(mappedPath2) + u1 := getTestUser() + u1.VirtualFolders = append(u1.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: "/vdir", + QuotaSize: -1, + QuotaFiles: -1, + }) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + UsedQuotaFiles: 2, + UsedQuotaSize: 123, + LastQuotaUpdate: 456, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + user1, _, err := httpdtest.AddUser(u1, http.StatusCreated) + assert.NoError(t, err) + // virtual folder must be auto created + folder, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 1) + 
assert.Contains(t, folder.Users, user1.Username) + assert.Equal(t, 2, folder.UsedQuotaFiles) + assert.Equal(t, int64(123), folder.UsedQuotaSize) + assert.Equal(t, int64(456), folder.LastQuotaUpdate) + assert.Equal(t, 2, user1.VirtualFolders[0].UsedQuotaFiles) + assert.Equal(t, int64(123), user1.VirtualFolders[0].UsedQuotaSize) + assert.Equal(t, int64(456), user1.VirtualFolders[0].LastQuotaUpdate) + + u2 := getTestUser() + u2.Username = defaultUsername + "2" + u2.VirtualFolders = append(u2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: "/vdir1", + QuotaSize: 0, + QuotaFiles: 0, + }) + u2.VirtualFolders = append(u2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + }, + VirtualPath: "/vdir2", + QuotaSize: -1, + QuotaFiles: -1, + }) + user2, _, err := httpdtest.AddUser(u2, http.StatusCreated) + assert.NoError(t, err) + folder, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 1) + assert.Contains(t, folder.Users, user2.Username) + folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 2) + assert.Contains(t, folder.Users, user1.Username) + assert.Contains(t, folder.Users, user2.Username) + // now update user2 removing mappedPath1 + user2.VirtualFolders = nil + user2.VirtualFolders = append(user2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + UsedQuotaFiles: 2, + UsedQuotaSize: 123, + }, + VirtualPath: "/vdir", + QuotaSize: 0, + QuotaFiles: 0, + }) + user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") + assert.NoError(t, err) + folder, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 1) + 
assert.Contains(t, folder.Users, user2.Username) + assert.Equal(t, 0, folder.UsedQuotaFiles) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 1) + assert.Contains(t, folder.Users, user1.Username) + // add mappedPath1 again to user2 + user2.VirtualFolders = append(user2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + }, + VirtualPath: "/vdir1", + }) + user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") + assert.NoError(t, err) + folder, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 1) + assert.Contains(t, folder.Users, user2.Username) + // removing virtual folders should clear relations on both side + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + user2, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user2.VirtualFolders, 1) { + folder := user2.VirtualFolders[0] + assert.Equal(t, mappedPath1, folder.MappedPath) + assert.Equal(t, folderName1, folder.Name) + } + user1, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user2.VirtualFolders, 1) { + folder := user2.VirtualFolders[0] + assert.Equal(t, mappedPath1, folder.MappedPath) + } + + folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 2) + // removing a user should clear virtual folder mapping + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + folder, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 1) + assert.Contains(t, folder.Users, user2.Username) + // removing 
a folder should clear mapping on the user side too + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + user2, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user2.VirtualFolders, 0) + _, err = httpdtest.RemoveUser(user2, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserS3Config(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.S3FilesystemProvider + user.FsConfig.S3Config.Bucket = "test" //nolint:goconst + user.FsConfig.S3Config.AccessKey = "Server-Access-Key" + user.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("Server-Access-Secret") + user.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("SSE-encryption-key") + user.FsConfig.S3Config.RoleARN = "myRoleARN" + user.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000" + user.FsConfig.S3Config.UploadPartSize = 8 + user.FsConfig.S3Config.DownloadPartMaxTime = 60 + user.FsConfig.S3Config.UploadPartMaxTime = 40 + user.FsConfig.S3Config.ForcePathStyle = true + user.FsConfig.S3Config.SkipTLSVerify = true + user.FsConfig.S3Config.DownloadPartSize = 6 + folderName := "vfolderName" + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: "/folderPath", + }) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), "folderName"), + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret("Crypted-Secret"), + }, + }, + } + _, _, err = httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + user, body, err := httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(body)) + assert.Equal(t, sdkkms.SecretStatusSecretBox, 
user.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.NotEmpty(t, user.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetKey()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.S3Config.SSECustomerKey.GetStatus()) + assert.NotEmpty(t, user.FsConfig.S3Config.SSECustomerKey.GetPayload()) + assert.Empty(t, user.FsConfig.S3Config.SSECustomerKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.S3Config.SSECustomerKey.GetKey()) + assert.Equal(t, 60, user.FsConfig.S3Config.DownloadPartMaxTime) + assert.Equal(t, 40, user.FsConfig.S3Config.UploadPartMaxTime) + assert.True(t, user.FsConfig.S3Config.SkipTLSVerify) + if assert.Len(t, user.VirtualFolders, 1) { + folder := user.VirtualFolders[0] + assert.Equal(t, sdkkms.SecretStatusSecretBox, folder.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, folder.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetKey()) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, folder.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, folder.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, folder.FsConfig.CryptConfig.Passphrase.GetKey()) + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + user.VirtualFolders = nil + user.FsConfig.S3Config.SSECustomerKey = kms.NewEmptySecret() + secret := kms.NewSecret(sdkkms.SecretStatusSecretBox, "Server-Access-Secret", "", "") + 
user.FsConfig.S3Config.AccessSecret = secret + _, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.Error(t, err) + user.FsConfig.S3Config.AccessSecret.SetStatus(sdkkms.SecretStatusPlain) + user, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + initialSecretPayload := user.FsConfig.S3Config.AccessSecret.GetPayload() + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.NotEmpty(t, initialSecretPayload) + assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetKey()) + user.FsConfig.Provider = sdk.S3FilesystemProvider + user.FsConfig.S3Config.Bucket = "test-bucket" + user.FsConfig.S3Config.Region = "us-east-1" //nolint:goconst + user.FsConfig.S3Config.AccessKey = "Server-Access-Key1" + user.FsConfig.S3Config.Endpoint = "http://localhost:9000" + user.FsConfig.S3Config.KeyPrefix = "somedir/subdir" //nolint:goconst + user.FsConfig.S3Config.UploadConcurrency = 5 + user.FsConfig.S3Config.DownloadConcurrency = 4 + user, bb, err := httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(bb)) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.Equal(t, initialSecretPayload, user.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + assert.Empty(t, user.FsConfig.S3Config.AccessSecret.GetKey()) + // test user without access key and access secret (shared config state) + user.FsConfig.Provider = sdk.S3FilesystemProvider + user.FsConfig.S3Config.Bucket = "testbucket" + user.FsConfig.S3Config.Region = "us-east-1" + user.FsConfig.S3Config.AccessKey = "" + user.FsConfig.S3Config.AccessSecret = kms.NewEmptySecret() + user.FsConfig.S3Config.Endpoint = "" + user.FsConfig.S3Config.KeyPrefix = "somedir/subdir" + user.FsConfig.S3Config.UploadPartSize = 6 + 
user.FsConfig.S3Config.UploadConcurrency = 4 + user, body, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(body)) + assert.Nil(t, user.FsConfig.S3Config.AccessSecret) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + // shared credential test for add instead of update + user, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err) + assert.Nil(t, user.FsConfig.S3Config.AccessSecret) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestHTTPFsConfig(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.HTTPFilesystemProvider + user.FsConfig.HTTPConfig = vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: "http://127.0.0.1/httpfs", + Username: defaultUsername, + }, + Password: kms.NewPlainSecret(defaultPassword), + APIKey: kms.NewPlainSecret(defaultTokenAuthUser), + } + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + initialPwdPayload := user.FsConfig.HTTPConfig.Password.GetPayload() + initialAPIKeyPayload := user.FsConfig.HTTPConfig.APIKey.GetPayload() + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.Password.GetStatus()) + assert.NotEmpty(t, initialPwdPayload) + assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetAdditionalData()) + assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetKey()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.APIKey.GetStatus()) + assert.NotEmpty(t, initialAPIKeyPayload) + assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetKey()) + user.FsConfig.HTTPConfig.Password.SetStatus(sdkkms.SecretStatusSecretBox) + 
user.FsConfig.HTTPConfig.Password.SetAdditionalData(util.GenerateUniqueID()) + user.FsConfig.HTTPConfig.Password.SetKey(util.GenerateUniqueID()) + user.FsConfig.HTTPConfig.APIKey.SetStatus(sdkkms.SecretStatusSecretBox) + user.FsConfig.HTTPConfig.APIKey.SetAdditionalData(util.GenerateUniqueID()) + user.FsConfig.HTTPConfig.APIKey.SetKey(util.GenerateUniqueID()) + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.Password.GetStatus()) + assert.Equal(t, initialPwdPayload, user.FsConfig.HTTPConfig.Password.GetPayload()) + assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetAdditionalData()) + assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetKey()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.APIKey.GetStatus()) + assert.Equal(t, initialAPIKeyPayload, user.FsConfig.HTTPConfig.APIKey.GetPayload()) + assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetKey()) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + // also test AddUser + u := getTestUser() + u.FsConfig.Provider = sdk.HTTPFilesystemProvider + u.FsConfig.HTTPConfig = vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: "http://127.0.0.1/httpfs", + Username: defaultUsername, + }, + Password: kms.NewPlainSecret(defaultPassword), + APIKey: kms.NewPlainSecret(defaultTokenAuthUser), + } + user, _, err = httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.Password.GetStatus()) + assert.NotEmpty(t, user.FsConfig.HTTPConfig.Password.GetPayload()) + assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetAdditionalData()) + assert.Empty(t, user.FsConfig.HTTPConfig.Password.GetKey()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.HTTPConfig.APIKey.GetStatus()) + 
assert.NotEmpty(t, user.FsConfig.HTTPConfig.APIKey.GetPayload()) + assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.HTTPConfig.APIKey.GetKey()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserAzureBlobConfig(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider + user.FsConfig.AzBlobConfig.Container = "test" + user.FsConfig.AzBlobConfig.AccountName = "Server-Account-Name" + user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("Server-Account-Key") + user.FsConfig.AzBlobConfig.Endpoint = "http://127.0.0.1:9000" + user.FsConfig.AzBlobConfig.UploadPartSize = 8 + user.FsConfig.AzBlobConfig.DownloadPartSize = 6 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + initialPayload := user.FsConfig.AzBlobConfig.AccountKey.GetPayload() + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.AccountKey.GetStatus()) + assert.NotEmpty(t, initialPayload) + assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey()) + user.FsConfig.AzBlobConfig.AccountKey.SetStatus(sdkkms.SecretStatusSecretBox) + user.FsConfig.AzBlobConfig.AccountKey.SetAdditionalData("data") + user.FsConfig.AzBlobConfig.AccountKey.SetKey("fake key") + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.AccountKey.GetStatus()) + assert.Equal(t, initialPayload, user.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = 
defaultPassword + user.ID = 0 + user.CreatedAt = 0 + secret := kms.NewSecret(sdkkms.SecretStatusSecretBox, "Server-Account-Key", "", "") + user.FsConfig.AzBlobConfig.AccountKey = secret + _, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.Error(t, err) + user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("Server-Account-Key-Test") + user, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err) + initialPayload = user.FsConfig.AzBlobConfig.AccountKey.GetPayload() + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.AccountKey.GetStatus()) + assert.NotEmpty(t, initialPayload) + assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey()) + user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider + user.FsConfig.AzBlobConfig.Container = "test-container" + user.FsConfig.AzBlobConfig.Endpoint = "http://localhost:9001" + user.FsConfig.AzBlobConfig.KeyPrefix = "somedir/subdir" + user.FsConfig.AzBlobConfig.UploadConcurrency = 5 + user.FsConfig.AzBlobConfig.DownloadConcurrency = 4 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.AccountKey.GetStatus()) + assert.NotEmpty(t, initialPayload) + assert.Equal(t, initialPayload, user.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.AzBlobConfig.AccountKey.GetKey()) + // test user without access key and access secret (SAS) + user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider + user.FsConfig.AzBlobConfig.SASURL = kms.NewPlainSecret("https://myaccount.blob.core.windows.net/pictures/profile.jpg?sv=2012-02-12&st=2009-02-09&se=2009-02-10&sr=c&sp=r&si=YWJjZGVmZw%3d%3d&sig=dD80ihBh5jfNpymO5Hg1IdiJIEvHcJpCMiCMnN%2fRnbI%3d") + user.FsConfig.AzBlobConfig.KeyPrefix = 
"somedir/subdir" + user.FsConfig.AzBlobConfig.AccountName = "" + user.FsConfig.AzBlobConfig.AccountKey = kms.NewEmptySecret() + user.FsConfig.AzBlobConfig.UploadPartSize = 6 + user.FsConfig.AzBlobConfig.UploadConcurrency = 4 + user.FsConfig.AzBlobConfig.DownloadPartSize = 3 + user.FsConfig.AzBlobConfig.DownloadConcurrency = 5 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Nil(t, user.FsConfig.AzBlobConfig.AccountKey) + assert.NotNil(t, user.FsConfig.AzBlobConfig.SASURL) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + // sas test for add instead of update + user.FsConfig.AzBlobConfig = vfs.AzBlobFsConfig{ + BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{ + Container: user.FsConfig.AzBlobConfig.Container, + }, + SASURL: kms.NewPlainSecret("http://127.0.0.1/fake/sass/url"), + } + user, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err) + assert.Nil(t, user.FsConfig.AzBlobConfig.AccountKey) + initialPayload = user.FsConfig.AzBlobConfig.SASURL.GetPayload() + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.SASURL.GetStatus()) + assert.NotEmpty(t, initialPayload) + assert.Empty(t, user.FsConfig.AzBlobConfig.SASURL.GetAdditionalData()) + assert.Empty(t, user.FsConfig.AzBlobConfig.SASURL.GetKey()) + user.FsConfig.AzBlobConfig.SASURL.SetStatus(sdkkms.SecretStatusSecretBox) + user.FsConfig.AzBlobConfig.SASURL.SetAdditionalData("data") + user.FsConfig.AzBlobConfig.SASURL.SetKey("fake key") + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.AzBlobConfig.SASURL.GetStatus()) + assert.Equal(t, initialPayload, user.FsConfig.AzBlobConfig.SASURL.GetPayload()) + assert.Empty(t, user.FsConfig.AzBlobConfig.SASURL.GetAdditionalData()) + assert.Empty(t, user.FsConfig.AzBlobConfig.SASURL.GetKey()) 
+ + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserCryptFs(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.CryptedFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("crypt passphrase") + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + initialPayload := user.FsConfig.CryptConfig.Passphrase.GetPayload() + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, initialPayload) + assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetKey()) + user.FsConfig.CryptConfig.Passphrase.SetStatus(sdkkms.SecretStatusSecretBox) + user.FsConfig.CryptConfig.Passphrase.SetAdditionalData("data") + user.FsConfig.CryptConfig.Passphrase.SetKey("fake pass key") + user, bb, err := httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(bb)) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.Equal(t, initialPayload, user.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetKey()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + secret := kms.NewSecret(sdkkms.SecretStatusSecretBox, "invalid encrypted payload", "", "") + user.FsConfig.CryptConfig.Passphrase = secret + _, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.Error(t, err) + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("passphrase test") + user, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err) + initialPayload = 
user.FsConfig.CryptConfig.Passphrase.GetPayload() + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, initialPayload) + assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetKey()) + user.FsConfig.Provider = sdk.CryptedFilesystemProvider + user.FsConfig.CryptConfig.Passphrase.SetKey("pass") + user, bb, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(bb)) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, initialPayload) + assert.Equal(t, initialPayload, user.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, user.FsConfig.CryptConfig.Passphrase.GetKey()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserSFTPFs(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.SFTPFilesystemProvider + user.FsConfig.SFTPConfig.Endpoint = "[::1]:22:22" // invalid endpoint + user.FsConfig.SFTPConfig.Username = "sftp_user" + user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp_pwd") + user.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(sftpPrivateKey) + user.FsConfig.SFTPConfig.Fingerprints = []string{sftpPkeyFingerprint} + user.FsConfig.SFTPConfig.BufferSize = 2 + user.FsConfig.SFTPConfig.EqualityCheckMode = 1 + _, resp, err := httpdtest.UpdateUser(user, http.StatusBadRequest, "") + assert.NoError(t, err) + assert.Contains(t, string(resp), "invalid endpoint") + + user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1" + _, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "") + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + 
assert.Equal(t, "127.0.0.1:22", user.FsConfig.SFTPConfig.Endpoint) + + user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1:2022" + user.FsConfig.SFTPConfig.DisableCouncurrentReads = true + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Equal(t, "/", user.FsConfig.SFTPConfig.Prefix) + assert.True(t, user.FsConfig.SFTPConfig.DisableCouncurrentReads) + assert.Equal(t, int64(2), user.FsConfig.SFTPConfig.BufferSize) + initialPwdPayload := user.FsConfig.SFTPConfig.Password.GetPayload() + initialPkeyPayload := user.FsConfig.SFTPConfig.PrivateKey.GetPayload() + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.Password.GetStatus()) + assert.NotEmpty(t, initialPwdPayload) + assert.Empty(t, user.FsConfig.SFTPConfig.Password.GetAdditionalData()) + assert.Empty(t, user.FsConfig.SFTPConfig.Password.GetKey()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.PrivateKey.GetStatus()) + assert.NotEmpty(t, initialPkeyPayload) + assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetKey()) + user.FsConfig.SFTPConfig.Password.SetStatus(sdkkms.SecretStatusSecretBox) + user.FsConfig.SFTPConfig.Password.SetAdditionalData("adata") + user.FsConfig.SFTPConfig.Password.SetKey("fake pwd key") + user.FsConfig.SFTPConfig.PrivateKey.SetStatus(sdkkms.SecretStatusSecretBox) + user.FsConfig.SFTPConfig.PrivateKey.SetAdditionalData("adata") + user.FsConfig.SFTPConfig.PrivateKey.SetKey("fake key") + user.FsConfig.SFTPConfig.DisableCouncurrentReads = false + user, bb, err := httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(bb)) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.Password.GetStatus()) + assert.Equal(t, initialPwdPayload, user.FsConfig.SFTPConfig.Password.GetPayload()) + assert.Empty(t, user.FsConfig.SFTPConfig.Password.GetAdditionalData()) + assert.Empty(t, 
user.FsConfig.SFTPConfig.Password.GetKey()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.PrivateKey.GetStatus()) + assert.Equal(t, initialPkeyPayload, user.FsConfig.SFTPConfig.PrivateKey.GetPayload()) + assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetKey()) + assert.False(t, user.FsConfig.SFTPConfig.DisableCouncurrentReads) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + secret := kms.NewSecret(sdkkms.SecretStatusSecretBox, "invalid encrypted payload", "", "") + user.FsConfig.SFTPConfig.Password = secret + _, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.Error(t, err) + user.FsConfig.SFTPConfig.Password = kms.NewEmptySecret() + user.FsConfig.SFTPConfig.PrivateKey = secret + _, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.Error(t, err) + + user.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(sftpPrivateKey) + user, _, err = httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err) + initialPkeyPayload = user.FsConfig.SFTPConfig.PrivateKey.GetPayload() + assert.Nil(t, user.FsConfig.SFTPConfig.Password) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.PrivateKey.GetStatus()) + assert.NotEmpty(t, initialPkeyPayload) + assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetKey()) + user.FsConfig.Provider = sdk.SFTPFilesystemProvider + user.FsConfig.SFTPConfig.PrivateKey.SetKey("k") + user, bb, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(bb)) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.SFTPConfig.PrivateKey.GetStatus()) + assert.NotEmpty(t, initialPkeyPayload) + assert.Equal(t, initialPkeyPayload, user.FsConfig.SFTPConfig.PrivateKey.GetPayload()) + 
assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) + assert.Empty(t, user.FsConfig.SFTPConfig.PrivateKey.GetKey()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserHiddenFields(t *testing.T) { + // sensitive data must be hidden but not deleted from the dataprovider + usernames := []string{"user1", "user2", "user3", "user4", "user5", "user6"} + u1 := getTestUser() + u1.Username = usernames[0] + u1.FsConfig.Provider = sdk.S3FilesystemProvider + u1.FsConfig.S3Config.Bucket = "test" + u1.FsConfig.S3Config.Region = "us-east-1" + u1.FsConfig.S3Config.AccessKey = "S3-Access-Key" + u1.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("S3-Access-Secret") + u1.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("SSE-secret-key") + user1, _, err := httpdtest.AddUser(u1, http.StatusCreated) + assert.NoError(t, err) + + u2 := getTestUser() + u2.Username = usernames[1] + u2.FsConfig.Provider = sdk.GCSFilesystemProvider + u2.FsConfig.GCSConfig.Bucket = "test" + u2.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("fake credentials") + u2.FsConfig.GCSConfig.ACL = "bucketOwnerRead" + u2.FsConfig.GCSConfig.UploadPartSize = 5 + u2.FsConfig.GCSConfig.UploadPartMaxTime = 20 + user2, _, err := httpdtest.AddUser(u2, http.StatusCreated) + assert.NoError(t, err) + + u3 := getTestUser() + u3.Username = usernames[2] + u3.FsConfig.Provider = sdk.AzureBlobFilesystemProvider + u3.FsConfig.AzBlobConfig.Container = "test" + u3.FsConfig.AzBlobConfig.AccountName = "Server-Account-Name" + u3.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("Server-Account-Key") + user3, _, err := httpdtest.AddUser(u3, http.StatusCreated) + assert.NoError(t, err) + + u4 := getTestUser() + u4.Username = usernames[3] + u4.FsConfig.Provider = sdk.CryptedFilesystemProvider + u4.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("test passphrase") + user4, _, err := httpdtest.AddUser(u4, http.StatusCreated) + assert.NoError(t, err) + + u5 
:= getTestUser() + u5.Username = usernames[4] + u5.FsConfig.Provider = sdk.SFTPFilesystemProvider + u5.FsConfig.SFTPConfig.Endpoint = "127.0.0.1:2022" + u5.FsConfig.SFTPConfig.Username = "sftp_user" + u5.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("apassword") + u5.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(sftpPrivateKey) + u5.FsConfig.SFTPConfig.Fingerprints = []string{sftpPkeyFingerprint} + u5.FsConfig.SFTPConfig.Prefix = "/prefix" + user5, _, err := httpdtest.AddUser(u5, http.StatusCreated) + assert.NoError(t, err) + + u6 := getTestUser() + u6.Username = usernames[5] + u6.FsConfig.Provider = sdk.HTTPFilesystemProvider + u6.FsConfig.HTTPConfig = vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: "http://127.0.0.1/api/v1", + Username: defaultUsername, + }, + Password: kms.NewPlainSecret(defaultPassword), + APIKey: kms.NewPlainSecret(defaultTokenAuthUser), + } + user6, _, err := httpdtest.AddUser(u6, http.StatusCreated) + assert.NoError(t, err) + + users, _, err := httpdtest.GetUsers(0, 0, http.StatusOK) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(users), 6) + for _, username := range usernames { + user, _, err := httpdtest.GetUserByUsername(username, http.StatusOK) + assert.NoError(t, err) + assert.Empty(t, user.Password) + assert.True(t, user.HasPassword) + } + user1, _, err = httpdtest.GetUserByUsername(user1.Username, http.StatusOK) + assert.NoError(t, err) + assert.Empty(t, user1.Password) + assert.Empty(t, user1.FsConfig.S3Config.AccessSecret.GetKey()) + assert.Empty(t, user1.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Empty(t, user1.FsConfig.S3Config.SSECustomerKey.GetKey()) + assert.Empty(t, user1.FsConfig.S3Config.SSECustomerKey.GetAdditionalData()) + assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetStatus()) + assert.NotEmpty(t, 
user1.FsConfig.S3Config.SSECustomerKey.GetPayload()) + + user2, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK) + assert.NoError(t, err) + assert.Empty(t, user2.Password) + assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetKey()) + assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetAdditionalData()) + assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetStatus()) + assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetPayload()) + + user3, _, err = httpdtest.GetUserByUsername(user3.Username, http.StatusOK) + assert.NoError(t, err) + assert.Empty(t, user3.Password) + assert.Empty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetKey()) + assert.Empty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) + assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetStatus()) + assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + + user4, _, err = httpdtest.GetUserByUsername(user4.Username, http.StatusOK) + assert.NoError(t, err) + assert.Empty(t, user4.Password) + assert.Empty(t, user4.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.Empty(t, user4.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetPayload()) + + user5, _, err = httpdtest.GetUserByUsername(user5.Username, http.StatusOK) + assert.NoError(t, err) + assert.Empty(t, user5.Password) + assert.Empty(t, user5.FsConfig.SFTPConfig.Password.GetKey()) + assert.Empty(t, user5.FsConfig.SFTPConfig.Password.GetAdditionalData()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetStatus()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetPayload()) + assert.Empty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetKey()) + assert.Empty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetStatus()) + assert.NotEmpty(t, 
user5.FsConfig.SFTPConfig.PrivateKey.GetPayload()) + assert.Equal(t, "/prefix", user5.FsConfig.SFTPConfig.Prefix) + + user6, _, err = httpdtest.GetUserByUsername(user6.Username, http.StatusOK) + assert.NoError(t, err) + assert.Empty(t, user6.Password) + assert.Empty(t, user6.FsConfig.HTTPConfig.Password.GetKey()) + assert.Empty(t, user6.FsConfig.HTTPConfig.Password.GetAdditionalData()) + assert.NotEmpty(t, user6.FsConfig.HTTPConfig.APIKey.GetStatus()) + assert.NotEmpty(t, user6.FsConfig.HTTPConfig.APIKey.GetPayload()) + + // finally check that we have all the data inside the data provider + user1, err = dataprovider.UserExists(user1.Username, "") + assert.NoError(t, err) + assert.NotEmpty(t, user1.Password) + assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetKey()) + assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.NotEmpty(t, user1.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetKey()) + assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetAdditionalData()) + assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetStatus()) + assert.NotEmpty(t, user1.FsConfig.S3Config.SSECustomerKey.GetPayload()) + err = user1.FsConfig.S3Config.AccessSecret.Decrypt() + assert.NoError(t, err) + err = user1.FsConfig.S3Config.SSECustomerKey.Decrypt() + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusPlain, user1.FsConfig.S3Config.AccessSecret.GetStatus()) + assert.Equal(t, u1.FsConfig.S3Config.AccessSecret.GetPayload(), user1.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.Empty(t, user1.FsConfig.S3Config.AccessSecret.GetKey()) + assert.Empty(t, user1.FsConfig.S3Config.AccessSecret.GetAdditionalData()) + assert.Equal(t, sdkkms.SecretStatusPlain, user1.FsConfig.S3Config.SSECustomerKey.GetStatus()) + assert.Equal(t, u1.FsConfig.S3Config.SSECustomerKey.GetPayload(), 
user1.FsConfig.S3Config.SSECustomerKey.GetPayload()) + assert.Empty(t, user1.FsConfig.S3Config.SSECustomerKey.GetKey()) + assert.Empty(t, user1.FsConfig.S3Config.SSECustomerKey.GetAdditionalData()) + + user2, err = dataprovider.UserExists(user2.Username, "") + assert.NoError(t, err) + assert.NotEmpty(t, user2.Password) + assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetKey()) + assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetAdditionalData()) + assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetStatus()) + assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetPayload()) + err = user2.FsConfig.GCSConfig.Credentials.Decrypt() + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusPlain, user2.FsConfig.GCSConfig.Credentials.GetStatus()) + assert.Equal(t, u2.FsConfig.GCSConfig.Credentials.GetPayload(), user2.FsConfig.GCSConfig.Credentials.GetPayload()) + assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetKey()) + assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetAdditionalData()) + + user3, err = dataprovider.UserExists(user3.Username, "") + assert.NoError(t, err) + assert.NotEmpty(t, user3.Password) + assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetKey()) + assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) + assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetStatus()) + assert.NotEmpty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + err = user3.FsConfig.AzBlobConfig.AccountKey.Decrypt() + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusPlain, user3.FsConfig.AzBlobConfig.AccountKey.GetStatus()) + assert.Equal(t, u3.FsConfig.AzBlobConfig.AccountKey.GetPayload(), user3.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + assert.Empty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetKey()) + assert.Empty(t, user3.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) + + user4, err = dataprovider.UserExists(user4.Username, "") + assert.NoError(t, err) + 
assert.NotEmpty(t, user4.Password) + assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, user4.FsConfig.CryptConfig.Passphrase.GetPayload()) + err = user4.FsConfig.CryptConfig.Passphrase.Decrypt() + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusPlain, user4.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.Equal(t, u4.FsConfig.CryptConfig.Passphrase.GetPayload(), user4.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, user4.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.Empty(t, user4.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + + user5, err = dataprovider.UserExists(user5.Username, "") + assert.NoError(t, err) + assert.NotEmpty(t, user5.Password) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetKey()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetAdditionalData()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetStatus()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.Password.GetPayload()) + err = user5.FsConfig.SFTPConfig.Password.Decrypt() + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusPlain, user5.FsConfig.SFTPConfig.Password.GetStatus()) + assert.Equal(t, u5.FsConfig.SFTPConfig.Password.GetPayload(), user5.FsConfig.SFTPConfig.Password.GetPayload()) + assert.Empty(t, user5.FsConfig.SFTPConfig.Password.GetKey()) + assert.Empty(t, user5.FsConfig.SFTPConfig.Password.GetAdditionalData()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetKey()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetStatus()) + assert.NotEmpty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetPayload()) + err = user5.FsConfig.SFTPConfig.PrivateKey.Decrypt() + assert.NoError(t, err) + assert.Equal(t, 
sdkkms.SecretStatusPlain, user5.FsConfig.SFTPConfig.PrivateKey.GetStatus()) + assert.Equal(t, u5.FsConfig.SFTPConfig.PrivateKey.GetPayload(), user5.FsConfig.SFTPConfig.PrivateKey.GetPayload()) + assert.Empty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetKey()) + assert.Empty(t, user5.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) + + user6, err = dataprovider.UserExists(user6.Username, "") + assert.NoError(t, err) + assert.NotEmpty(t, user6.Password) + assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetKey()) + assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetAdditionalData()) + assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetStatus()) + assert.NotEmpty(t, user6.FsConfig.HTTPConfig.Password.GetPayload()) + err = user6.FsConfig.HTTPConfig.Password.Decrypt() + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusPlain, user6.FsConfig.HTTPConfig.Password.GetStatus()) + assert.Equal(t, u6.FsConfig.HTTPConfig.Password.GetPayload(), user6.FsConfig.HTTPConfig.Password.GetPayload()) + assert.Empty(t, user6.FsConfig.HTTPConfig.Password.GetKey()) + assert.Empty(t, user6.FsConfig.HTTPConfig.Password.GetAdditionalData()) + + // update the GCS user and check that the credentials are preserved + user2.FsConfig.GCSConfig.Credentials = kms.NewEmptySecret() + user2.FsConfig.GCSConfig.ACL = "private" + _, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") + assert.NoError(t, err) + + user2, _, err = httpdtest.GetUserByUsername(user2.Username, http.StatusOK) + assert.NoError(t, err) + assert.Empty(t, user2.Password) + assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetKey()) + assert.Empty(t, user2.FsConfig.GCSConfig.Credentials.GetAdditionalData()) + assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetStatus()) + assert.NotEmpty(t, user2.FsConfig.GCSConfig.Credentials.GetPayload()) + + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user2, http.StatusOK) + assert.NoError(t, 
err)
	_, err = httpdtest.RemoveUser(user3, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user4, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user5, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user6, http.StatusOK)
	assert.NoError(t, err)
}

// TestSecretObject verifies the encrypt/decrypt round trip for a plain KMS secret.
func TestSecretObject(t *testing.T) {
	s := kms.NewPlainSecret("test data")
	s.SetAdditionalData("username")
	require.True(t, s.IsValid())
	err := s.Encrypt()
	require.NoError(t, err)
	require.Equal(t, sdkkms.SecretStatusSecretBox, s.GetStatus())
	require.NotEmpty(t, s.GetPayload())
	require.NotEmpty(t, s.GetKey())
	require.True(t, s.IsValid())
	err = s.Decrypt()
	require.NoError(t, err)
	require.Equal(t, sdkkms.SecretStatusPlain, s.GetStatus())
	require.Equal(t, "test data", s.GetPayload())
	require.Empty(t, s.GetKey())
}

// TestSecretObjectCompatibility checks that a secret encrypted without a
// master key can still be decrypted after a master key is configured.
// This is manually tested against vault too.
func TestSecretObjectCompatibility(t *testing.T) {
	testPayload := "test payload"
	s := kms.NewPlainSecret(testPayload)
	require.True(t, s.IsValid())
	err := s.Encrypt()
	require.NoError(t, err)
	localAsJSON, err := json.Marshal(s)
	assert.NoError(t, err)

	for _, secretStatus := range []string{sdkkms.SecretStatusSecretBox} {
		kmsConfig := config.GetKMSConfig()
		assert.Empty(t, kmsConfig.Secrets.MasterKeyPath)
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			os.Setenv("VAULT_SERVER_URL", "http://127.0.0.1:8200")
			os.Setenv("VAULT_SERVER_TOKEN", "s.9lYGq83MbgG5KR5kfebXVyhJ")
			kmsConfig.Secrets.URL = "hashivault://mykey"
		}
		err := kmsConfig.Initialize()
		assert.NoError(t, err)
		// encrypt without a master key
		secret := kms.NewPlainSecret(testPayload)
		secret.SetAdditionalData("add data")
		err = secret.Encrypt()
		assert.NoError(t, err)
		assert.Equal(t, 0, secret.GetMode())
		secretClone := secret.Clone()
		err = secretClone.Decrypt()
		assert.NoError(t, err)
		assert.Equal(t, testPayload, secretClone.GetPayload())
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			// decrypt the local secret now that the provider is vault
			secretLocal := kms.NewEmptySecret()
			err = json.Unmarshal(localAsJSON, secretLocal)
			assert.NoError(t, err)
			assert.Equal(t, sdkkms.SecretStatusSecretBox, secretLocal.GetStatus())
			assert.Equal(t, 0, secretLocal.GetMode())
			err = secretLocal.Decrypt()
			assert.NoError(t, err)
			assert.Equal(t, testPayload, secretLocal.GetPayload())
			assert.Equal(t, sdkkms.SecretStatusPlain, secretLocal.GetStatus())
			err = secretLocal.Encrypt()
			assert.NoError(t, err)
			assert.Equal(t, sdkkms.SecretStatusSecretBox, secretLocal.GetStatus())
			assert.Equal(t, 0, secretLocal.GetMode())
		}

		asJSON, err := json.Marshal(secret)
		assert.NoError(t, err)

		masterKeyPath := filepath.Join(os.TempDir(), "mkey")
		err = os.WriteFile(masterKeyPath, []byte("test key"), os.ModePerm)
		assert.NoError(t, err)
		// dedicated name so we don't shadow the config package used above
		localCfg := kms.Configuration{
			Secrets: kms.Secrets{
				MasterKeyPath: masterKeyPath,
			},
		}
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			localCfg.Secrets.URL = "hashivault://mykey"
		}
		err = localCfg.Initialize()
		assert.NoError(t, err)

		// now build the secret from JSON
		secret = kms.NewEmptySecret()
		err = json.Unmarshal(asJSON, secret)
		assert.NoError(t, err)
		assert.Equal(t, 0, secret.GetMode())
		err = secret.Decrypt()
		assert.NoError(t, err)
		assert.Equal(t, testPayload, secret.GetPayload())
		err = secret.Encrypt()
		assert.NoError(t, err)
		assert.Equal(t, 1, secret.GetMode())
		err = secret.Decrypt()
		assert.NoError(t, err)
		assert.Equal(t, testPayload, secret.GetPayload())
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			// decrypt the local secret encrypted without a master key now that
			// the provider is vault and a master key is set.
			// The provider will not change, the master key will be used
			secretLocal := kms.NewEmptySecret()
			err = json.Unmarshal(localAsJSON, secretLocal)
			assert.NoError(t, err)
			assert.Equal(t, sdkkms.SecretStatusSecretBox, secretLocal.GetStatus())
			assert.Equal(t, 0, secretLocal.GetMode())
			err = secretLocal.Decrypt()
			assert.NoError(t, err)
			assert.Equal(t, testPayload, secretLocal.GetPayload())
			assert.Equal(t, sdkkms.SecretStatusPlain, secretLocal.GetStatus())
			err = secretLocal.Encrypt()
			assert.NoError(t, err)
			assert.Equal(t, sdkkms.SecretStatusSecretBox, secretLocal.GetStatus())
			assert.Equal(t, 1, secretLocal.GetMode())
		}

		err = kmsConfig.Initialize()
		assert.NoError(t, err)
		err = os.Remove(masterKeyPath)
		assert.NoError(t, err)
		if secretStatus == sdkkms.SecretStatusVaultTransit {
			os.Unsetenv("VAULT_SERVER_URL")
			os.Unsetenv("VAULT_SERVER_TOKEN")
		}
	}
}

// TestUpdateUserNoCredentials: empty password/public keys are omitted from the
// JSON payload, so they stay unchanged and no validation error is raised.
func TestUpdateUserNoCredentials(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.Password = ""
	user.PublicKeys = []string{}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUpdateUserEmptyHomeDir: an empty home dir must be rejected on update.
func TestUpdateUserEmptyHomeDir(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.HomeDir = ""
	_, _, err = httpdtest.UpdateUser(user, http.StatusBadRequest, "")
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUpdateUserInvalidHomeDir: a relative home dir must be rejected on update.
func TestUpdateUserInvalidHomeDir(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	user.HomeDir = "relative_path"
	_, _, err = httpdtest.UpdateUser(user,
http.StatusBadRequest, "")
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUpdateNonExistentUser: updating an unknown user must return 404.
func TestUpdateNonExistentUser(t *testing.T) {
	_, _, err := httpdtest.UpdateUser(getTestUser(), http.StatusNotFound, "")
	assert.NoError(t, err)
}

// TestGetNonExistentUser: fetching an unknown user must return 404.
func TestGetNonExistentUser(t *testing.T) {
	_, _, err := httpdtest.GetUserByUsername("na", http.StatusNotFound)
	assert.NoError(t, err)
}

// TestDeleteNonExistentUser: deleting an unknown user must return 404.
func TestDeleteNonExistentUser(t *testing.T) {
	_, err := httpdtest.RemoveUser(getTestUser(), http.StatusNotFound)
	assert.NoError(t, err)
}

// TestAddDuplicateUser: adding the same username twice must return 409.
func TestAddDuplicateUser(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	_, _, err = httpdtest.AddUser(getTestUser(), http.StatusConflict)
	assert.NoError(t, err)
	// expecting 201 for the duplicate makes the helper itself report an error
	_, _, err = httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.Error(t, err, "adding a duplicate user must fail")
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestGetUsers exercises user listing with limit/offset pagination.
func TestGetUsers(t *testing.T) {
	user1, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser()
	u.Username = defaultUsername + "1"
	user2, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	users, _, err := httpdtest.GetUsers(0, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.GreaterOrEqual(t, len(users), 2)
	for _, usr := range users {
		if u.Username == usr.Username {
			assert.True(t, usr.HasPassword)
		}
	}
	// limit 1, offset 0 and offset 1 must both return a single user
	users, _, err = httpdtest.GetUsers(1, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, 1, len(users))
	users, _, err = httpdtest.GetUsers(1, 1, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, 1, len(users))
	_, _, err = httpdtest.GetUsers(1, 1, http.StatusInternalServerError)
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user2, http.StatusOK)
assert.NoError(t, err)
}

// TestGetQuotaScans checks the user and folder quota scan listing endpoints.
func TestGetQuotaScans(t *testing.T) {
	_, _, err := httpdtest.GetQuotaScans(http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetQuotaScans(http.StatusInternalServerError)
	assert.Error(t, err)
	_, _, err = httpdtest.GetFoldersQuotaScans(http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetFoldersQuotaScans(http.StatusInternalServerError)
	assert.Error(t, err)
}

// TestStartQuotaScan starts a user and a folder quota scan and waits for the
// folder scan to complete.
func TestStartQuotaScan(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	_, err = httpdtest.StartQuotaScan(user, http.StatusAccepted)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	folder := vfs.BaseVirtualFolder{
		Name:        "vfolder",
		MappedPath:  filepath.Join(os.TempDir(), "folder"),
		Description: "virtual folder",
	}
	_, _, err = httpdtest.AddFolder(folder, http.StatusCreated)
	assert.NoError(t, err)
	_, err = httpdtest.StartFolderQuotaScan(folder, http.StatusAccepted)
	assert.NoError(t, err)
	// poll until the folder scan is no longer listed as active
	for {
		activeScans, _, err := httpdtest.GetFoldersQuotaScans(http.StatusOK)
		if !assert.NoError(t, err, "Error getting active scans") {
			break
		}
		if len(activeScans) == 0 {
			break
		}
		time.Sleep(100 * time.Millisecond)
	}
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
}

// TestUpdateFolderQuotaUsage exercises the reset/add modes of the folder
// quota usage update endpoint, including invalid mode and unknown folder.
func TestUpdateFolderQuotaUsage(t *testing.T) {
	f := vfs.BaseVirtualFolder{
		Name:       "vdir",
		MappedPath: filepath.Join(os.TempDir(), "folder"),
	}
	usedQuotaFiles := 1
	usedQuotaSize := int64(65535)
	f.UsedQuotaFiles = usedQuotaFiles
	f.UsedQuotaSize = usedQuotaSize
	folder, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	if assert.NoError(t, err) {
		assert.Equal(t, usedQuotaFiles, folder.UsedQuotaFiles)
		assert.Equal(t, usedQuotaSize, folder.UsedQuotaSize)
	}
	_, err = httpdtest.UpdateFolderQuotaUsage(folder, "invalid mode", http.StatusBadRequest)
	assert.NoError(t, err)
	// "reset" replaces the stored usage
	_, err = httpdtest.UpdateFolderQuotaUsage(f, "reset", http.StatusOK)
	assert.NoError(t, err)
	folder, _, err = httpdtest.GetFolderByName(f.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, usedQuotaFiles, folder.UsedQuotaFiles)
	assert.Equal(t, usedQuotaSize, folder.UsedQuotaSize)
	// "add" increments the stored usage
	_, err = httpdtest.UpdateFolderQuotaUsage(f, "add", http.StatusOK)
	assert.NoError(t, err)
	folder, _, err = httpdtest.GetFolderByName(f.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, 2*usedQuotaFiles, folder.UsedQuotaFiles)
	assert.Equal(t, 2*usedQuotaSize, folder.UsedQuotaSize)
	// negative usage is rejected
	f.UsedQuotaSize = -1
	_, err = httpdtest.UpdateFolderQuotaUsage(f, "", http.StatusBadRequest)
	assert.NoError(t, err)
	f.UsedQuotaSize = usedQuotaSize
	f.Name = f.Name + "1"
	_, err = httpdtest.UpdateFolderQuotaUsage(f, "", http.StatusNotFound)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
}

// TestGetVersion checks the version endpoint.
func TestGetVersion(t *testing.T) {
	_, _, err := httpdtest.GetVersion(http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetVersion(http.StatusInternalServerError)
	assert.Error(t, err, "get version request must succeed, we requested to check a wrong status code")
}

// TestGetStatus checks the provider status endpoint.
func TestGetStatus(t *testing.T) {
	_, _, err := httpdtest.GetStatus(http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetStatus(http.StatusBadRequest)
	assert.Error(t, err, "get provider status request must succeed, we requested to check a wrong status code")
}

// TestGetConnections checks the active connections endpoint.
func TestGetConnections(t *testing.T) {
	_, _, err := httpdtest.GetConnections(http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetConnections(http.StatusInternalServerError)
	assert.Error(t, err, "get sftp connections request must succeed, we requested to check a wrong status code")
}

// TestCloseActiveConnection: closing an unknown connection returns 404; a
// tracked one is removed from the active stats.
func TestCloseActiveConnection(t *testing.T) {
	_, err := httpdtest.CloseConnection("non_existent_id", http.StatusNotFound)
	assert.NoError(t, err)
	user :=
getTestUser()
	baseConn := common.NewBaseConnection("connID", common.ProtocolSFTP, "", "", user)
	fakeConn := &fakeConnection{
		BaseConnection: baseConn,
	}
	err = common.Connections.Add(fakeConn)
	assert.NoError(t, err)
	_, err = httpdtest.CloseConnection(baseConn.GetID(), http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestCloseConnectionAfterUserUpdateDelete verifies that active connections
// are kept on update with disconnect "0", closed with disconnect "1" and
// always closed when the user is deleted.
func TestCloseConnectionAfterUserUpdateDelete(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	connFTP := common.NewBaseConnection("connID", common.ProtocolFTP, "", "", user)
	fakeConnFTP := &fakeConnection{
		BaseConnection: connFTP,
	}
	err = common.Connections.Add(fakeConnFTP)
	assert.NoError(t, err)
	connSFTP := common.NewBaseConnection("connID1", common.ProtocolSFTP, "", "", user)
	fakeConnSFTP := &fakeConnection{
		BaseConnection: connSFTP,
	}
	err = common.Connections.Add(fakeConnSFTP)
	assert.NoError(t, err)
	// disconnect "0": connections stay open
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "0")
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 2)
	// disconnect "1": connections are closed
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "1")
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 0)

	err = common.Connections.Add(fakeConnFTP)
	assert.NoError(t, err)
	err = common.Connections.Add(fakeConnSFTP)
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 2)
	// deleting the user always closes its connections
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
}

// TestAdminGenerateRecoveryCodesSaveError checks that generating recovery
// codes fails with a clear error when the admin name violates naming rules.
func TestAdminGenerateRecoveryCodesSaveError(t *testing.T) {
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.NamingRules = 7
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t,
err)

	a := getTestAdmin()
	a.Username = "adMiN@example.com "
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], admin.Username)
	assert.NoError(t, err)
	admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
	}
	admin.Password = defaultTokenAuthPass
	err = dataprovider.UpdateAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	admin, _, err = httpdtest.GetAdminByUsername(a.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)

	passcode, err := generateTOTPPasscode(key.Secret())
	assert.NoError(t, err)
	adminAPIToken, err := getJWTAPITokenFromTestServerWithPasscode(a.Username, defaultTokenAuthPass, passcode)
	assert.NoError(t, err)
	assert.NotEmpty(t, adminAPIToken)

	// re-initialize the provider with the default naming rules
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName {
		return
	}
	req, err := http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, adminAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following characters are allowed")

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminCredentialsWithSpaces verifies how leading/trailing spaces in
// admin passwords are handled by the API and the WebAdmin UI.
func TestAdminCredentialsWithSpaces(t *testing.T) {
	a := getTestAdmin()
	a.Username = xid.New().String()
	a.Password = " " + xid.New().String() + " "
	admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated)
	assert.NoError(t, err)
	// For admins the password is always trimmed.
	_, err = getJWTAPITokenFromTestServer(a.Username, a.Password)
	assert.Error(t, err)
	_, err = getJWTAPITokenFromTestServer(a.Username, strings.TrimSpace(a.Password))
	assert.NoError(t, err)
	// The password sent from the WebAdmin UI is automatically trimmed
	_, err = getJWTWebToken(a.Username, a.Password)
	assert.NoError(t, err)
	_, err = getJWTWebToken(a.Username, strings.TrimSpace(a.Password))
	assert.NoError(t, err)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserCredentialsWithSpaces verifies that user passwords, unlike admin
// ones, are never trimmed: only the exact password must authenticate.
func TestUserCredentialsWithSpaces(t *testing.T) {
	u := getTestUser()
	u.Password = " " + xid.New().String() + " "
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// For users the password is not trimmed
	_, err = getJWTAPIUserTokenFromTestServer(u.Username, u.Password)
	assert.NoError(t, err)
	_, err = getJWTAPIUserTokenFromTestServer(u.Username, strings.TrimSpace(u.Password))
	assert.Error(t, err)

	_, err = getJWTWebClientTokenFromTestServer(u.Username, u.Password)
	assert.NoError(t, err)
	_, err = getJWTWebClientTokenFromTestServer(u.Username, strings.TrimSpace(u.Password))
	assert.Error(t, err)

	user.Password = u.Password
	conn, sftpClient, err := getSftpClient(user)
	if assert.NoError(t, err) {
		conn.Close()
		sftpClient.Close()
	}
	user.Password = strings.TrimSpace(u.Password)
	_, _, err = getSftpClient(user)
	assert.Error(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestNamingRules checks username/name normalization (case folding and
// trimming) across users, admins, roles and folders.
func TestNamingRules(t *testing.T) {
	smtpCfg := smtp.Config{
		Host:          "127.0.0.1",
		Port:          3525,
		From:          "notification@example.com",
		TemplatesPath: "templates",
	}
	err := smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.NamingRules = 7
	err =
dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + u := getTestUser() + u.Username = " uSeR@user.me " + u.Email = dataprovider.ConvertName(u.Username) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + assert.Equal(t, "user@user.me", user.Username) + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolSSH}, + } + user.Password = u.Password + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + user.Username = u.Username + user.AdditionalInfo = "info" + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(u.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.TOTPConfig.Enabled) + + r := getTestRole() + r.Name = "role@mycompany" + role, _, err := httpdtest.AddRole(r, http.StatusCreated) + assert.NoError(t, err) + + a := getTestAdmin() + a.Username = "admiN@example.com " + admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + assert.Equal(t, "admin@example.com", admin.Username) + admin.Email = dataprovider.ConvertName(a.Username) + admin.Username = a.Username + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + admin, _, err = httpdtest.GetAdminByUsername(a.Username, http.StatusOK) + assert.NoError(t, err) + + f := vfs.BaseVirtualFolder{ + Name: "文件夹AB", + MappedPath: filepath.Clean(os.TempDir()), + } + folder, resp, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Equal(t, "文件夹ab", folder.Name) + folder.Name = f.Name + folder.Description = folder.Name + _, resp, err = httpdtest.UpdateFolder(folder, 
http.StatusOK) + assert.NoError(t, err, string(resp)) + folder, resp, err = httpdtest.GetFolderByName(f.Name, http.StatusOK) + assert.NoError(t, err, string(resp)) + assert.Equal(t, "文件夹AB", folder.Description) + _, err = httpdtest.RemoveFolder(f, http.StatusOK) + assert.NoError(t, err) + token, err := getJWTWebClientTokenFromTestServer(u.Username, defaultPassword) + assert.NoError(t, err) + assert.NotEmpty(t, token) + adminAPIToken, err := getJWTAPITokenFromTestServer(a.Username, defaultTokenAuthPass) + assert.NoError(t, err) + assert.NotEmpty(t, adminAPIToken) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + providerConf.BackupsPath = backupsPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName { + return + } + + token, err = getJWTWebClientTokenFromTestServer(user.Username, defaultPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + req, err := http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidUser) + // test user reset password. 
Setting the new password will fail because the username is not valid + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = make(url.Values) + form.Set("username", user.Username) + form.Set(csrfFormToken, csrfToken) + lastResetCode = "" + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("code", lastResetCode) + form.Set("password", defaultPassword) + form.Set("confirm_password", defaultPassword) + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidUser) + + adminAPIToken, err = getJWTAPITokenFromTestServer(admin.Username, defaultTokenAuthPass) + assert.NoError(t, err) + userAPIToken, err := getJWTAPIUserTokenFromTestServer(user.Username, defaultPassword) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userPath+"/"+user.Username+"/2fa/disable", nil) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, adminAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "the following characters are allowed") + + req, err = http.NewRequest(http.MethodPost, user2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, 
userAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "the following characters are allowed") + + apiKeyAuthReq := make(map[string]bool) + apiKeyAuthReq["allow_api_key_auth"] = true + asJSON, err := json.Marshal(apiKeyAuthReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, userAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "the following characters are allowed") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + token, err = getJWTWebTokenFromTestServer(admin.Username, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminProfilePath, token) + assert.NoError(t, err) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidUser) + + req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name), bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidName) + + apiKeyAuthReq = make(map[string]bool) + apiKeyAuthReq["allow_api_key_auth"] = true + asJSON, err = json.Marshal(apiKeyAuthReq) + assert.NoError(t, err) + req, 
err = http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, adminAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "the following characters are allowed")
	// test admin reset password
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set("username", admin.Username)
	form.Set(csrfFormToken, csrfToken)
	lastResetCode = ""
	req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.GreaterOrEqual(t, len(lastResetCode), 20)
	loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr)
	assert.NoError(t, err)
	form = make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("code", lastResetCode)
	form.Set("password", defaultPassword)
	form.Set("confirm_password", defaultPassword)
	req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)

	smtpCfg = smtp.Config{}
	err = smtpCfg.Initialize(configDir, true)
	require.NoError(t, err)
}

// TestUserPassword verifies the HasPassword flag lifecycle: unset on create
// without a password, preserved on update when the field is omitted, and
// cleared when an explicit empty password is sent.
func TestUserPassword(t *testing.T) {
	u := getTestUser()
	u.Password = ""
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	assert.False(t, user.HasPassword)
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.False(t, user.HasPassword)

	user.Password = defaultPassword
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	assert.True(t, user.HasPassword)

	token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)

	rawUser := map[string]any{
		"username": user.Username,
		"home_dir": filepath.Join(homeBasePath, defaultUsername),
		"permissions": map[string][]string{
			"/": {"*"},
		},
	}
	userAsJSON, err := json.Marshal(rawUser)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// the previous password must be preserved
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.HasPassword)
	// update the user with an empty password field, the password will be unset
	rawUser["password"] = ""
	userAsJSON, err = json.Marshal(rawUser)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, user.HasPassword)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestSaveErrors checks provider save failures.
func TestSaveErrors(t *testing.T) {
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf
:= config.GetProviderConf() + providerConf.NamingRules = 1 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + recCode := "recovery code" + recoveryCodes := []dataprovider.RecoveryCode{ + { + Secret: kms.NewPlainSecret(recCode), + Used: false, + }, + } + + u := getTestUser() + u.Username = "user@example.com" + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Password = u.Password + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolSSH, common.ProtocolHTTP}, + } + user.Filters.RecoveryCodes = recoveryCodes + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.TOTPConfig.Enabled) + assert.Len(t, user.Filters.RecoveryCodes, 1) + + a := getTestAdmin() + a.Username = "admin@example.com" + admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + admin.Email = admin.Username + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + admin.Password = a.Password + admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + } + admin.Filters.RecoveryCodes = recoveryCodes + err = dataprovider.UpdateAdmin(&admin, "", "", "") + assert.NoError(t, err) + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, admin.Filters.TOTPConfig.Enabled) + assert.Len(t, admin.Filters.RecoveryCodes, 1) + + r := 
getTestRole() + r.Name = "role@mycompany" + role, _, err := httpdtest.AddRole(r, http.StatusCreated) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + providerConf.BackupsPath = backupsPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + if config.GetProviderConf().Driver == dataprovider.MemoryDataProviderName { + return + } + + _, resp, err := httpdtest.UpdateRole(role, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "the following characters are allowed") + + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form := getLoginForm(a.Username, a.Password, csrfToken) + req, err := http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) + cookie, err := getCookieFromResponse(rr) + assert.NoError(t, err) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form = make(url.Values) + form.Set("recovery_code", recCode) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + + loginCookie, 
csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(u.Username, u.Password, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + cookie, err = getCookieFromResponse(rr) + assert.NoError(t, err) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form = make(url.Values) + form.Set("recovery_code", recCode) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveRole(role, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserBaseDir(t *testing.T) { + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.UsersBaseDir = homeBasePath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + u := getTestUser() + u.HomeDir = "" + user, _, err := 
httpdtest.AddUser(u, http.StatusCreated) + if assert.Error(t, err) { + assert.EqualError(t, err, "home dir mismatch") + } + assert.Equal(t, filepath.Join(providerConf.UsersBaseDir, u.Username), user.HomeDir) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + providerConf.BackupsPath = backupsPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestQuotaTrackingDisabled(t *testing.T) { + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.TrackQuota = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + // user quota scan must fail + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + _, err = httpdtest.StartQuotaScan(user, http.StatusForbidden) + assert.NoError(t, err) + _, err = httpdtest.UpdateQuotaUsage(user, "", http.StatusForbidden) + assert.NoError(t, err) + _, err = httpdtest.UpdateTransferQuotaUsage(user, "", http.StatusForbidden) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + // folder quota scan must fail + folder := vfs.BaseVirtualFolder{ + Name: "folder_quota_test", + MappedPath: filepath.Clean(os.TempDir()), + } + folder, resp, err := httpdtest.AddFolder(folder, http.StatusCreated) + assert.NoError(t, err, string(resp)) + _, err = httpdtest.StartFolderQuotaScan(folder, http.StatusForbidden) + assert.NoError(t, err) + _, err = httpdtest.UpdateFolderQuotaUsage(folder, "", http.StatusForbidden) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + 
err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + providerConf.BackupsPath = backupsPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestProviderErrors(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + userAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + userWebToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + token, _, err := httpdtest.GetToken(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + testServerToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + httpdtest.SetJWTToken(token) + err = dataprovider.Close() + assert.NoError(t, err) + _, _, err = httpdtest.GetUserByUsername("na", http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = httpdtest.GetUsers(1, 0, http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = httpdtest.GetGroups(1, 0, http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = httpdtest.GetAdmins(1, 0, http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = httpdtest.GetAPIKeys(1, 0, http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = httpdtest.GetEventActions(1, 0, http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = httpdtest.GetEventRules(1, 0, http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = httpdtest.GetIPListEntries(dataprovider.IPListTypeDefender, "", "", dataprovider.OrderASC, 10, http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = httpdtest.GetRoles(1, 0, http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = 
httpdtest.UpdateRole(getTestRole(), http.StatusInternalServerError) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, userSharesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, userAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodPost, userSharesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, userAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + // password reset errors + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form := make(url.Values) + form.Set("username", "username") + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorGetUser) + + getJSONShares := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + req, err := http.NewRequest(http.MethodGet, webClientSharesPath+jsonAPISuffix, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, userWebToken) + executeRequest(req) + } + getJSONShares() + + req, err = http.NewRequest(http.MethodGet, webClientSharePath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, userWebToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, err = http.NewRequest(http.MethodGet, webClientSharePath+"/shareID", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, userWebToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, err = http.NewRequest(http.MethodPost, 
webClientSharePath+"/shareID", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, userWebToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + _, _, err = httpdtest.UpdateUser(dataprovider.User{BaseUser: sdk.BaseUser{Username: "auser"}}, http.StatusInternalServerError, "") + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(dataprovider.User{BaseUser: sdk.BaseUser{Username: "auser"}}, http.StatusInternalServerError) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: "aname"}, http.StatusInternalServerError) + assert.NoError(t, err) + status, _, err := httpdtest.GetStatus(http.StatusOK) + if assert.NoError(t, err) { + assert.False(t, status.DataProvider.IsActive) + } + _, _, err = httpdtest.Dumpdata("backup.json", "", "", http.StatusInternalServerError) + assert.NoError(t, err) + _, _, err = httpdtest.GetFolders(0, 0, http.StatusInternalServerError) + assert.NoError(t, err) + user = getTestUser() + user.ID = 1 + backupData := dataprovider.BackupData{ + Version: dataprovider.DumpVersion, + } + backupData.Configs = &dataprovider.Configs{} + backupData.Users = append(backupData.Users, user) + backupContent, err := json.Marshal(backupData) + assert.NoError(t, err) + backupFilePath := filepath.Join(backupsPath, "backup.json") + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + backupData.Configs = nil + backupContent, err = json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + backupData.Folders = append(backupData.Folders, vfs.BaseVirtualFolder{Name: "testFolder", MappedPath: filepath.Clean(os.TempDir())}) + backupContent, err = 
json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + backupData.Users = nil + backupData.Folders = nil + backupData.Groups = append(backupData.Groups, getTestGroup()) + backupContent, err = json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + backupData.Groups = nil + backupData.Admins = append(backupData.Admins, getTestAdmin()) + backupContent, err = json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + backupData.Users = nil + backupData.Folders = nil + backupData.Admins = nil + backupData.APIKeys = append(backupData.APIKeys, dataprovider.APIKey{ + Name: "name", + KeyID: util.GenerateUniqueID(), + Key: fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()), + Scope: dataprovider.APIKeyScopeUser, + }) + backupContent, err = json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + backupData.APIKeys = nil + backupData.Shares = append(backupData.Shares, dataprovider.Share{ + Name: util.GenerateUniqueID(), + ShareID: util.GenerateUniqueID(), + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + Username: defaultUsername, + }) + backupContent, err = json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, 
os.ModePerm) + assert.NoError(t, err) + _, resp, err := httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err, string(resp)) + backupData = dataprovider.BackupData{ + EventActions: []dataprovider.BaseEventAction{ + { + Name: "quota_reset", + Type: dataprovider.ActionTypeFolderQuotaReset, + }, + }, + Version: dataprovider.DumpVersion, + } + backupContent, err = json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + backupData = dataprovider.BackupData{ + EventRules: []dataprovider.EventRule{ + { + Name: "quota_reset", + Trigger: dataprovider.EventTriggerSchedule, + Conditions: dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + Hours: "2", + DayOfWeek: "1", + DayOfMonth: "2", + Month: "3", + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: "unknown action", + }, + Order: 1, + }, + }, + }, + }, + Version: dataprovider.DumpVersion, + } + backupContent, err = json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + backupData = dataprovider.BackupData{ + Roles: []dataprovider.Role{ + { + Name: "role1", + }, + }, + Version: dataprovider.DumpVersion, + } + backupContent, err = json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + backupData = dataprovider.BackupData{ + IPLists: []dataprovider.IPListEntry{ + { + IPOrNet: "192.168.1.1/24", + Type: 
dataprovider.IPListTypeRateLimiterSafeList, + Mode: dataprovider.ListModeAllow, + }, + }, + Version: dataprovider.DumpVersion, + } + backupContent, err = json.Marshal(backupData) + assert.NoError(t, err) + err = os.WriteFile(backupFilePath, backupContent, os.ModePerm) + assert.NoError(t, err) + _, _, err = httpdtest.Loaddata(backupFilePath, "", "", http.StatusInternalServerError) + assert.NoError(t, err) + + err = os.Remove(backupFilePath) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, webUserPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodGet, webTemplateUser+"?from=auser", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodGet, webGroupPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodGet, webAdminPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodGet, webTemplateUser, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodGet, path.Join(webGroupPath, "groupname"), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, "grpname"), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = 
executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodGet, webTemplateFolder+"?from=afolder", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventActionPath, "actionname"), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, "actionname"), bytes.NewBuffer(nil)) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + getJSONActions := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + req, err := http.NewRequest(http.MethodGet, webAdminEventActionsPath+jsonAPISuffix, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + executeRequest(req) + } + getJSONActions() + + req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventRulePath, "rulename"), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventRulePath, "rulename"), bytes.NewBuffer(nil)) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + getJSONRules := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + req, err := http.NewRequest(http.MethodGet, webAdminEventRulesPath+jsonAPISuffix, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + executeRequest(req) + } + 
getJSONRules() + + req, err = http.NewRequest(http.MethodGet, webAdminEventRulePath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, testServerToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.BackupsPath = backupsPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + httpdtest.SetJWTToken("") +} + +func TestFolders(t *testing.T) { + folder := vfs.BaseVirtualFolder{ + Name: "name", + MappedPath: "relative path", + Users: []string{"1", "2", "3"}, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret("asecret"), + }, + }, + } + _, _, err := httpdtest.AddFolder(folder, http.StatusBadRequest) + assert.NoError(t, err) + folder.MappedPath = filepath.Clean(os.TempDir()) + folder1, resp, err := httpdtest.AddFolder(folder, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Equal(t, folder.Name, folder1.Name) + assert.Equal(t, folder.MappedPath, folder1.MappedPath) + assert.Equal(t, 0, folder1.UsedQuotaFiles) + assert.Equal(t, int64(0), folder1.UsedQuotaSize) + assert.Equal(t, int64(0), folder1.LastQuotaUpdate) + assert.Equal(t, sdkkms.SecretStatusSecretBox, folder1.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, folder1.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, folder1.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, folder1.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.Len(t, folder1.Users, 0) + // adding a duplicate folder must fail + _, _, err = httpdtest.AddFolder(folder, http.StatusCreated) + assert.Error(t, err) + folder.MappedPath = filepath.Join(os.TempDir(), "vfolder") + folder.Name = filepath.Base(folder.MappedPath) + folder.UsedQuotaFiles = 1 + folder.UsedQuotaSize = 345 + 
folder.LastQuotaUpdate = 10 + folder2, _, err := httpdtest.AddFolder(folder, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Equal(t, 1, folder2.UsedQuotaFiles) + assert.Equal(t, int64(345), folder2.UsedQuotaSize) + assert.Equal(t, int64(10), folder2.LastQuotaUpdate) + assert.Len(t, folder2.Users, 0) + folders, _, err := httpdtest.GetFolders(0, 0, http.StatusOK) + assert.NoError(t, err) + numResults := len(folders) + assert.GreaterOrEqual(t, numResults, 2) + found := false + for _, f := range folders { + if f.Name == folder1.Name { + found = true + assert.Equal(t, folder1.MappedPath, f.MappedPath) + assert.Equal(t, sdkkms.SecretStatusSecretBox, f.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, f.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.Len(t, f.Users, 0) + } + } + assert.True(t, found) + folders, _, err = httpdtest.GetFolders(0, 1, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folders, numResults-1) + folders, _, err = httpdtest.GetFolders(1, 0, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folders, 1) + f, _, err := httpdtest.GetFolderByName(folder1.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, folder1.Name, f.Name) + assert.Equal(t, folder1.MappedPath, f.MappedPath) + assert.Equal(t, sdkkms.SecretStatusSecretBox, f.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, f.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Empty(t, f.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.Len(t, f.Users, 0) + f, _, err = httpdtest.GetFolderByName(folder2.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, folder2.Name, f.Name) + assert.Equal(t, folder2.MappedPath, f.MappedPath) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{ + Name: 
"invalid", + }, http.StatusNotFound) + assert.NoError(t, err) + _, _, err = httpdtest.UpdateFolder(vfs.BaseVirtualFolder{Name: "notfound"}, http.StatusNotFound) + assert.NoError(t, err) + folder1.MappedPath = "a/relative/path" + _, _, err = httpdtest.UpdateFolder(folder1, http.StatusBadRequest) + assert.NoError(t, err) + folder1.MappedPath = filepath.Join(os.TempDir(), "updated") + folder1.Description = "updated folder description" + f, resp, err = httpdtest.UpdateFolder(folder1, http.StatusOK) + assert.NoError(t, err, string(resp)) + assert.Equal(t, folder1.MappedPath, f.MappedPath) + assert.Equal(t, folder1.Description, f.Description) + + _, err = httpdtest.RemoveFolder(folder1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(folder2, http.StatusOK) + assert.NoError(t, err) +} + +func TestFolderRelations(t *testing.T) { + mappedPath := filepath.Join(os.TempDir(), "mapped_path") + name := filepath.Base(mappedPath) + u := getTestUser() + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: name, + }, + VirtualPath: "/mountu", + }) + _, resp, err := httpdtest.AddUser(u, http.StatusInternalServerError) + assert.NoError(t, err, string(resp)) + g := getTestGroup() + g.VirtualFolders = append(g.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: name, + }, + VirtualPath: "/mountg", + }) + _, resp, err = httpdtest.AddGroup(g, http.StatusInternalServerError) + assert.NoError(t, err, string(resp)) + f := vfs.BaseVirtualFolder{ + Name: name, + MappedPath: mappedPath, + } + folder, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + assert.Len(t, folder.Users, 0) + assert.Len(t, folder.Groups, 0) + + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + group, resp, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + folder, _, err = 
httpdtest.GetFolderByName(folder.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 1) + assert.Len(t, folder.Groups, 1) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.VirtualFolders, 1) { + assert.Equal(t, mappedPath, user.VirtualFolders[0].MappedPath) + } + + group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, group.VirtualFolders, 1) { + assert.Equal(t, mappedPath, group.VirtualFolders[0].MappedPath) + } + // update the folder and check the modified field on user and group + mappedPath = filepath.Join(os.TempDir(), "mapped_path_updated") + folder.MappedPath = mappedPath + _, _, err = httpdtest.UpdateFolder(folder, http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.VirtualFolders, 1) { + assert.Equal(t, mappedPath, user.VirtualFolders[0].MappedPath) + } + + group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, group.VirtualFolders, 1) { + assert.Equal(t, mappedPath, group.VirtualFolders[0].MappedPath) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) + folder, _, err = httpdtest.GetFolderByName(folder.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 0) + assert.Len(t, folder.Groups, 0) + + user, resp, err = httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Len(t, user.VirtualFolders, 1) + group, resp, err = httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Len(t, group.VirtualFolders, 1) + + folder, _, err = httpdtest.GetFolderByName(folder.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, folder.Users, 1) + 
assert.Len(t, folder.Groups, 1) + + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.VirtualFolders, 0) + + group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, group.VirtualFolders, 0) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) +} + +func TestDumpdata(t *testing.T) { + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.BackupsPath = backupsPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + _, rawResp, err := httpdtest.Dumpdata("", "", "", http.StatusBadRequest) + assert.NoError(t, err, string(rawResp)) + _, _, err = httpdtest.Dumpdata(filepath.Join(backupsPath, "backup.json"), "", "", http.StatusBadRequest) + assert.NoError(t, err) + _, rawResp, err = httpdtest.Dumpdata("../backup.json", "", "", http.StatusBadRequest) + assert.NoError(t, err, string(rawResp)) + _, rawResp, err = httpdtest.Dumpdata("backup.json", "", "0", http.StatusOK) + assert.NoError(t, err, string(rawResp)) + response, _, err := httpdtest.Dumpdata("", "1", "0", http.StatusOK) + assert.NoError(t, err) + _, ok := response["admins"] + assert.True(t, ok) + _, ok = response["users"] + assert.True(t, ok) + _, ok = response["groups"] + assert.True(t, ok) + _, ok = response["folders"] + assert.True(t, ok) + _, ok = response["api_keys"] + assert.True(t, ok) + _, ok = response["shares"] + assert.True(t, ok) + _, ok = response["version"] + assert.True(t, ok) + _, rawResp, err = httpdtest.Dumpdata("backup.json", "", "1", http.StatusOK) + assert.NoError(t, err, string(rawResp)) + err = 
os.Remove(filepath.Join(backupsPath, "backup.json")) + assert.NoError(t, err) + if runtime.GOOS != osWindows { + err = os.Chmod(backupsPath, 0001) + assert.NoError(t, err) + _, _, err = httpdtest.Dumpdata("bck.json", "", "", http.StatusForbidden) + assert.NoError(t, err) + // subdir cannot be created + _, _, err = httpdtest.Dumpdata(filepath.Join("subdir", "bck.json"), "", "", http.StatusForbidden) + assert.NoError(t, err) + err = os.Chmod(backupsPath, 0755) + assert.NoError(t, err) + } + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + providerConf.BackupsPath = backupsPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestDefenderAPI(t *testing.T) { + oldConfig := config.GetCommonConfig() + + drivers := []string{common.DefenderDriverMemory} + if isDbDefenderSupported() { + drivers = append(drivers, common.DefenderDriverProvider) + } + + for _, driver := range drivers { + cfg := config.GetCommonConfig() + cfg.DefenderConfig.Enabled = true + cfg.DefenderConfig.Driver = driver + cfg.DefenderConfig.Threshold = 3 + cfg.DefenderConfig.ScoreLimitExceeded = 2 + cfg.DefenderConfig.ScoreNoAuth = 0 + + err := common.Initialize(cfg, 0) + assert.NoError(t, err) + + ip := "::1" + + hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK) + assert.NoError(t, err) + assert.Len(t, hosts, 0) + + _, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusNotFound) + assert.NoError(t, err) + + common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventNoLoginTried) + hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK) + assert.NoError(t, err) + assert.Len(t, hosts, 0) + common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventUserNotFound) + hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, hosts, 1) { + host := hosts[0] + assert.Empty(t, 
host.GetBanTime())
			assert.Equal(t, 2, host.Score)
			assert.Equal(t, ip, host.IP)
		}
		host, _, err := httpdtest.GetDefenderHostByIP(ip, http.StatusOK)
		assert.NoError(t, err)
		assert.Empty(t, host.GetBanTime())
		assert.Equal(t, 2, host.Score)

		// a second event pushes the score past the threshold: the host gets
		// banned (ban time set, score reset to 0)
		common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventUserNotFound)
		hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
		assert.NoError(t, err)
		if assert.Len(t, hosts, 1) {
			host := hosts[0]
			assert.NotEmpty(t, host.GetBanTime())
			assert.Equal(t, 0, host.Score)
			assert.Equal(t, ip, host.IP)
		}
		host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusOK)
		assert.NoError(t, err)
		assert.NotEmpty(t, host.GetBanTime())
		assert.Equal(t, 0, host.Score)

		_, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusOK)
		assert.NoError(t, err)

		_, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound)
		assert.NoError(t, err)

		common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventUserNotFound)
		common.AddDefenderEvent(ip, common.ProtocolHTTP, common.HostEventUserNotFound)
		hosts, _, err = httpdtest.GetDefenderHosts(http.StatusOK)
		assert.NoError(t, err)
		assert.Len(t, hosts, 1)

		_, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusOK)
		assert.NoError(t, err)

		host, _, err = httpdtest.GetDefenderHostByIP(ip, http.StatusNotFound)
		assert.NoError(t, err)
		_, err = httpdtest.RemoveDefenderHostByIP(ip, http.StatusNotFound)
		assert.NoError(t, err)

		// malformed IPs must be rejected with 400
		host, _, err = httpdtest.GetDefenderHostByIP("invalid_ip", http.StatusBadRequest)
		assert.NoError(t, err)
		_, err = httpdtest.RemoveDefenderHostByIP("invalid_ip", http.StatusBadRequest)
		assert.NoError(t, err)
		if driver == common.DefenderDriverProvider {
			err = dataprovider.CleanupDefender(util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Hour)))
			assert.NoError(t, err)
		}
	}

	err := common.Initialize(oldConfig, 0)
	require.NoError(t, err)
}

// TestDefenderAPIErrors checks that the defender REST API returns 500 when
// the data provider is closed (provider-backed driver only).
func TestDefenderAPIErrors(t *testing.T) {
	if isDbDefenderSupported() {
		oldConfig := config.GetCommonConfig()

		cfg := config.GetCommonConfig()
		cfg.DefenderConfig.Enabled = true
		cfg.DefenderConfig.Driver = common.DefenderDriverProvider
		err := common.Initialize(cfg, 0)
		require.NoError(t, err)

		token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
		assert.NoError(t, err)

		// closing the provider makes the defender backend unavailable
		err = dataprovider.Close()
		assert.NoError(t, err)

		req, err := http.NewRequest(http.MethodGet, defenderHosts, nil)
		assert.NoError(t, err)
		setBearerForReq(req, token)
		rr := executeRequest(req)
		checkResponseCode(t, http.StatusInternalServerError, rr)

		// re-initialize the provider for the following tests
		err = config.LoadConfig(configDir, "")
		assert.NoError(t, err)
		providerConf := config.GetProviderConf()
		providerConf.BackupsPath = backupsPath
		err = dataprovider.Initialize(providerConf, configDir, true)
		assert.NoError(t, err)

		err = common.Initialize(oldConfig, 0)
		require.NoError(t, err)
	}
}

// TestRestoreShares verifies that restoring shares from a backup preserves
// the usage/timestamp fields exactly as dumped.
func TestRestoreShares(t *testing.T) {
	// shares should be restored preserving the UsedTokens, CreatedAt, LastUseAt, UpdatedAt,
	// and ExpiresAt, so an expired share can be restored while we cannot create an already
	// expired share
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	share := dataprovider.Share{
		ShareID:     shortuuid.New(),
		Name:        "share name",
		Description: "share description",
		Scope:       dataprovider.ShareScopeRead,
		Paths:       []string{"/"},
		Username:    user.Username,
		CreatedAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(-144 * time.Hour)),
		UpdatedAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(-96 * time.Hour)),
		LastUseAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(-64 * time.Hour)),
		ExpiresAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(-48 * time.Hour)),
		MaxTokens:   10,
		UsedTokens:  8,
		AllowFrom:   []string{"127.0.0.0/8"},
	}
	backupData := dataprovider.BackupData{
		Version: dataprovider.DumpVersion,
	}
	backupData.Shares = append(backupData.Shares, share)
	backupContent, err := json.Marshal(backupData)
	assert.NoError(t, err)
	_, _, err = httpdtest.LoaddataFromPostBody(backupContent, "0", "0", http.StatusOK)
	assert.NoError(t, err)
	// the restored share must match the dumped one field by field
	shareGet, err := dataprovider.ShareExists(share.ShareID, user.Username)
	assert.NoError(t, err)
	assert.Equal(t, share, shareGet)

	// restoring again with updated fields must overwrite the existing share
	share.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-142 * time.Hour))
	share.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-92 * time.Hour))
	share.LastUseAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-62 * time.Hour))
	share.UsedTokens = 6
	backupData.Shares = []dataprovider.Share{share}
	backupContent, err = json.Marshal(backupData)
	assert.NoError(t, err)
	_, _, err = httpdtest.LoaddataFromPostBody(backupContent, "0", "0", http.StatusOK)
	assert.NoError(t, err)
	shareGet, err = dataprovider.ShareExists(share.ShareID, user.Username)
	assert.NoError(t, err)
	assert.Equal(t, share, shareGet)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoaddataFromPostBody restores a full backup (users, groups, admins,
// roles, folders, API keys, shares) posted in the request body and verifies
// validation errors plus the relationships between the restored objects.
func TestLoaddataFromPostBody(t *testing.T) {
	mappedPath := filepath.Join(os.TempDir(), "restored_folder")
	folderName := filepath.Base(mappedPath)
	role := getTestRole()
	role.ID = 1
	role.Name = "test_restored_role"
	group := getTestGroup()
	group.ID = 1
	group.Name = "test_group_restored"
	user := getTestUser()
	user.ID = 1
	user.Username = "test_user_restored"
	user.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user.Role = role.Name
	admin := getTestAdmin()
	admin.ID = 1
	admin.Username = "test_admin_restored"
	admin.Permissions = []string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers,
		dataprovider.PermAdminDeleteUsers, dataprovider.PermAdminViewUsers}
	admin.Role = role.Name
	backupData := dataprovider.BackupData{
		Version: dataprovider.DumpVersion,
	}
	backupData.Users = append(backupData.Users, user)
	backupData.Groups
= append(backupData.Groups, group)
	backupData.Admins = append(backupData.Admins, admin)
	backupData.Roles = append(backupData.Roles, role)
	// two folders with the same name: the second definition wins on restore
	backupData.Folders = []vfs.BaseVirtualFolder{
		{
			Name:            folderName,
			MappedPath:      mappedPath,
			UsedQuotaSize:   123,
			UsedQuotaFiles:  456,
			LastQuotaUpdate: 789,
			Users:           []string{"user"},
		},
		{
			Name:       folderName,
			MappedPath: mappedPath + "1",
		},
	}
	// empty API key/share are invalid: restoring them must fail with 500 below
	backupData.APIKeys = append(backupData.APIKeys, dataprovider.APIKey{})
	backupData.Shares = append(backupData.Shares, dataprovider.Share{})
	backupContent, err := json.Marshal(backupData)
	assert.NoError(t, err)
	// invalid body / query parameters / JSON are rejected
	_, _, err = httpdtest.LoaddataFromPostBody(nil, "0", "0", http.StatusBadRequest)
	assert.NoError(t, err)
	_, _, err = httpdtest.LoaddataFromPostBody(backupContent, "a", "0", http.StatusBadRequest)
	assert.NoError(t, err)
	_, _, err = httpdtest.LoaddataFromPostBody([]byte("invalid content"), "0", "0", http.StatusBadRequest)
	assert.NoError(t, err)
	_, _, err = httpdtest.LoaddataFromPostBody(backupContent, "0", "0", http.StatusInternalServerError)
	assert.NoError(t, err)

	// replace the invalid API key and share with valid ones and restore again
	keyID := util.GenerateUniqueID()
	backupData.APIKeys = []dataprovider.APIKey{
		{
			Name:  "test key",
			Scope: dataprovider.APIKeyScopeAdmin,
			KeyID: keyID,
			Key:   fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()),
		},
	}
	backupData.Shares = []dataprovider.Share{
		{
			ShareID:  keyID,
			Name:     keyID,
			Scope:    dataprovider.ShareScopeWrite,
			Paths:    []string{"/"},
			Username: user.Username,
		},
	}
	backupContent, err = json.Marshal(backupData)
	assert.NoError(t, err)
	_, resp, err := httpdtest.LoaddataFromPostBody(backupContent, "0", "0", http.StatusOK)
	assert.NoError(t, err, string(resp))
	// verify the restored user kept its role and group membership
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, role.Name, user.Role)
	if assert.Len(t, user.Groups, 1) {
		assert.Equal(t, sdk.GroupTypePrimary, user.Groups[0].Type)
		assert.Equal(t, group.Name, user.Groups[0].Name)
	}
	role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, role.Admins, 1)
	assert.Len(t, role.Users, 1)
	_, err = dataprovider.ShareExists(keyID, user.Username)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)

	group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)

	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, role.Name, admin.Role)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	// the role must now reference no admins/users
	role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, role.Admins, 0)
	assert.Len(t, role.Users, 0)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)
	apiKey, _, err := httpdtest.GetAPIKeyByID(keyID, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK)
	assert.NoError(t, err)

	// the second folder definition in the dump must have won
	folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, mappedPath+"1", folder.MappedPath)
	assert.Equal(t, int64(123), folder.UsedQuotaSize)
	assert.Equal(t, 456, folder.UsedQuotaFiles)
	assert.Equal(t, int64(789), folder.LastQuotaUpdate)
	assert.Len(t, folder.Users, 0)
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
}

// TestLoaddata restores a complete version-14 backup from a file on disk and
// verifies every restored object type (configs, users, groups, admins, roles,
// folders, API keys, shares, event actions/rules, IP list entries), then
// cross-checks the dumpdata output against the restored state.
func TestLoaddata(t *testing.T) {
	err := dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
	mappedPath := filepath.Join(os.TempDir(), "restored_folder")
	folderName := filepath.Base(mappedPath)
	folderDesc := "restored folder desc"
	user := getTestUser()
	user.ID = 1
	user.Username = "test_user_restore"
	user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{
BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: "/vuserpath",
	})
	group := getTestGroup()
	group.ID = 1
	group.Name = "test_group_restore"
	group.VirtualFolders = append(group.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: "/vgrouppath",
	})
	role := getTestRole()
	role.ID = 1
	role.Name = "test_role_restore"
	user.Groups = append(user.Groups, sdk.GroupMapping{
		Name: group.Name,
		Type: sdk.GroupTypePrimary,
	})
	admin := getTestAdmin()
	admin.ID = 1
	admin.Username = "test_admin_restore"
	admin.Groups = []dataprovider.AdminGroupMapping{
		{
			Name: group.Name,
		},
	}
	ipListEntry := dataprovider.IPListEntry{
		IPOrNet:     "172.16.2.4/32",
		Description: "entry desc",
		Type:        dataprovider.IPListTypeDefender,
		Mode:        dataprovider.ListModeDeny,
		Protocols:   3,
	}
	apiKey := dataprovider.APIKey{
		Name:  util.GenerateUniqueID(),
		Scope: dataprovider.APIKeyScopeAdmin,
		KeyID: util.GenerateUniqueID(),
		Key:   fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()),
	}
	share := dataprovider.Share{
		ShareID:  util.GenerateUniqueID(),
		Name:     util.GenerateUniqueID(),
		Scope:    dataprovider.ShareScopeRead,
		Paths:    []string{"/"},
		Username: user.Username,
	}
	// HTTP event action with a plain-text secret: after restore the secret
	// must be re-encrypted (checked later via GetStatus)
	action := dataprovider.BaseEventAction{
		ID:   81,
		Name: "test_restore_action",
		Type: dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint:      "https://localhost:4567/action",
				Username:      defaultUsername,
				Password:      kms.NewPlainSecret(defaultPassword),
				Timeout:       10,
				SkipTLSVerify: true,
				Method:        http.MethodPost,
				Body:          `{"event":"{{.Event}}","name":"{{.Name}}"}`,
			},
		},
	}
	rule := dataprovider.EventRule{
		ID:          100,
		Name:        "test_rule_restore",
		Description: "",
		Trigger:     dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"download"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action.Name,
				},
				Order: 1,
			},
		},
	}
	configs := dataprovider.Configs{
		SFTPD: &dataprovider.SFTPDConfigs{
			HostKeyAlgos:   []string{ssh.KeyAlgoRSA, ssh.CertAlgoRSAv01},
			PublicKeyAlgos: []string{ssh.InsecureKeyAlgoDSA}, //nolint:staticcheck
		},
		SMTP: &dataprovider.SMTPConfigs{
			Host: "mail.example.com",
			Port: 587,
			From: "from@example.net",
		},
	}
	// version 14 backup: exercises the older-dump-version restore path
	backupData := dataprovider.BackupData{
		Version: 14,
	}
	backupData.Configs = &configs
	backupData.Users = append(backupData.Users, user)
	backupData.Roles = append(backupData.Roles, role)
	backupData.Groups = append(backupData.Groups, group)
	backupData.Admins = append(backupData.Admins, admin)
	// duplicate folder name: the second definition wins on restore
	backupData.Folders = []vfs.BaseVirtualFolder{
		{
			Name:            folderName,
			MappedPath:      mappedPath + "1",
			UsedQuotaSize:   123,
			UsedQuotaFiles:  456,
			LastQuotaUpdate: 789,
			Users:           []string{"user"},
		},
		{
			MappedPath:  mappedPath,
			Name:        folderName,
			Description: folderDesc,
		},
	}
	backupData.APIKeys = append(backupData.APIKeys, apiKey)
	backupData.Shares = append(backupData.Shares, share)
	backupData.EventActions = append(backupData.EventActions, action)
	backupData.EventRules = append(backupData.EventRules, rule)
	backupData.IPLists = append(backupData.IPLists, ipListEntry)
	backupContent, err := json.Marshal(backupData)
	assert.NoError(t, err)
	backupFilePath := filepath.Join(backupsPath, "backup.json")
	err = os.WriteFile(backupFilePath, backupContent, os.ModePerm)
	assert.NoError(t, err)
	// invalid query parameters and a non-absolute/missing input file: 400
	_, _, err = httpdtest.Loaddata(backupFilePath, "a", "", http.StatusBadRequest)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata(backupFilePath, "", "a", http.StatusBadRequest)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata("backup.json", "1", "", http.StatusBadRequest)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata(backupFilePath+"a", "1", "", http.StatusBadRequest)
	assert.NoError(t, err)
	if runtime.GOOS != osWindows {
		// unreadable backup file must be rejected (not enforced on Windows)
		err = os.Chmod(backupFilePath, 0111)
		assert.NoError(t, err)
		_, _, err = httpdtest.Loaddata(backupFilePath, "1", "", http.StatusBadRequest)
		assert.NoError(t, err)
		err = os.Chmod(backupFilePath, 0644)
		assert.NoError(t, err)
	}
	// add objects from backup
	_, resp, err := httpdtest.Loaddata(backupFilePath, "1", "", http.StatusOK)
	assert.NoError(t, err, string(resp))
	// update from backup
	_, _, err = httpdtest.Loaddata(backupFilePath, "2", "", http.StatusOK)
	assert.NoError(t, err)
	configsGet, err := dataprovider.GetConfigs()
	assert.NoError(t, err)
	assert.Equal(t, configs.SMTP, configsGet.SMTP)
	// the certificate algo is stripped from HostKeyAlgos on restore
	assert.Equal(t, []string{ssh.KeyAlgoRSA}, configsGet.SFTPD.HostKeyAlgos)
	assert.Equal(t, []string{ssh.InsecureKeyAlgoDSA}, configsGet.SFTPD.PublicKeyAlgos) //nolint:staticcheck
	assert.Len(t, configsGet.SFTPD.KexAlgorithms, 0)
	assert.Len(t, configsGet.SFTPD.Ciphers, 0)
	assert.Len(t, configsGet.SFTPD.MACs, 0)
	assert.Greater(t, configsGet.UpdatedAt, int64(0))
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, user.VirtualFolders, 1)
	assert.Len(t, user.Groups, 1)
	_, err = dataprovider.ShareExists(share.ShareID, user.Username)
	assert.NoError(t, err)

	role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK)
	assert.NoError(t, err)

	group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group.VirtualFolders, 1)

	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, admin.Groups, 1)

	apiKey, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusOK)
	assert.NoError(t, err)

	action, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK)
	assert.NoError(t, err)

	entry, _, err := httpdtest.GetIPListEntry(ipListEntry.IPOrNet,
ipListEntry.Type, http.StatusOK)
	assert.NoError(t, err)
	assert.Greater(t, entry.CreatedAt, int64(0))
	assert.Greater(t, entry.UpdatedAt, int64(0))
	assert.Equal(t, ipListEntry.Description, entry.Description)
	assert.Equal(t, ipListEntry.Protocols, entry.Protocols)
	assert.Equal(t, ipListEntry.Mode, entry.Mode)

	rule, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, 1, rule.Status)
	if assert.Len(t, rule.Actions, 1) {
		// the plain-text password in the dump must be stored encrypted
		// (secretbox status) after restore
		if assert.NotNil(t, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password) {
			assert.Equal(t, sdkkms.SecretStatusSecretBox, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetStatus())
			assert.NotEmpty(t, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetPayload())
			assert.Empty(t, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetKey())
			assert.Empty(t, rule.Actions[0].BaseEventAction.Options.HTTPConfig.Password.GetAdditionalData())
		}
	}

	// dump everything and check that the restored objects round-trip
	response, _, err := httpdtest.Dumpdata("", "1", "0", http.StatusOK)
	assert.NoError(t, err)
	var dumpedData dataprovider.BackupData
	data, err := json.Marshal(response)
	assert.NoError(t, err)
	err = json.Unmarshal(data, &dumpedData)
	assert.NoError(t, err)
	found := false
	if assert.GreaterOrEqual(t, len(dumpedData.Users), 1) {
		for _, u := range dumpedData.Users {
			if u.Username == user.Username {
				found = true
				assert.Equal(t, len(user.VirtualFolders), len(u.VirtualFolders))
				assert.Equal(t, len(user.Groups), len(u.Groups))
			}
		}
	}
	assert.True(t, found)
	found = false
	if assert.GreaterOrEqual(t, len(dumpedData.Admins), 1) {
		for _, a := range dumpedData.Admins {
			if a.Username == admin.Username {
				found = true
				assert.Equal(t, len(admin.Groups), len(a.Groups))
			}
		}
	}
	assert.True(t, found)
	if assert.Len(t, dumpedData.Groups, 1) {
		assert.Equal(t, len(group.VirtualFolders), len(dumpedData.Groups[0].VirtualFolders))
	}
	if assert.Len(t, dumpedData.EventActions, 1) {
		assert.Equal(t, action.Name, dumpedData.EventActions[0].Name)
	}
	if assert.Len(t, dumpedData.EventRules, 1) {
		assert.Equal(t, rule.Name, dumpedData.EventRules[0].Name)
		assert.Len(t, dumpedData.EventRules[0].Actions, 1)
	}
	found = false
	for _, r := range dumpedData.Roles {
		if r.Name == role.Name {
			found = true
		}
	}
	assert.True(t, found)
	// the second folder definition from the backup must have won
	folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, mappedPath, folder.MappedPath)
	assert.Equal(t, int64(123), folder.UsedQuotaSize)
	assert.Equal(t, 456, folder.UsedQuotaFiles)
	assert.Equal(t, int64(789), folder.LastQuotaUpdate)
	assert.Equal(t, folderDesc, folder.Description)
	assert.Len(t, folder.Users, 1)
	// a users-scoped dump must contain only users
	response, _, err = httpdtest.Dumpdata("", "1", "0", http.StatusOK, dataprovider.DumpScopeUsers)
	assert.NoError(t, err)
	dumpedData = dataprovider.BackupData{}
	data, err = json.Marshal(response)
	assert.NoError(t, err)
	err = json.Unmarshal(data, &dumpedData)
	assert.NoError(t, err)
	assert.Greater(t, len(dumpedData.Users), 0)
	assert.Len(t, dumpedData.Admins, 0)
	assert.Len(t, dumpedData.Folders, 0)
	assert.Len(t, dumpedData.Groups, 0)
	assert.Len(t, dumpedData.Roles, 0)
	assert.Len(t, dumpedData.EventRules, 0)
	assert.Len(t, dumpedData.EventActions, 0)
	assert.Len(t, dumpedData.IPLists, 0)

	// cleanup all the restored objects
	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventRule(rule, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK)
	assert.NoError(t, err)

	// oversized (> 20 MiB) and non-JSON backup files must be rejected
	err = os.Remove(backupFilePath)
	assert.NoError(t, err)
	err = createTestFile(backupFilePath, 20*1048576+1)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata(backupFilePath, "1", "0", http.StatusBadRequest)
	assert.NoError(t, err)
	err = os.Remove(backupFilePath)
	assert.NoError(t, err)
	err = createTestFile(backupFilePath, 65535)
	assert.NoError(t, err)
	_, _, err = httpdtest.Loaddata(backupFilePath, "1", "0", http.StatusBadRequest)
	assert.NoError(t, err)
	err = os.Remove(backupFilePath)
	assert.NoError(t, err)
	err = dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
}

// TestLoaddataConvertActions verifies that event action placeholders in the
// old "{{Name}}" form are migrated to the "{{.Name}}" template form when
// restoring a backup older than the current dump version, and left untouched
// when the backup is already at the current version.
func TestLoaddataConvertActions(t *testing.T) {
	a1 := dataprovider.BaseEventAction{
		Name: xid.New().String(),
		Type: dataprovider.ActionTypeEmail,
		Options: dataprovider.BaseEventActionOptions{
			EmailConfig: dataprovider.EventActionEmailConfig{
				Recipients: []string{"failure@example.com"},
				Subject:    `Failed "{{Event}}" from "{{Name}}"`,
				Body:       "Object name: {{ObjectName}} object type: {{ObjectType}}, IP: {{IP}}",
			},
		},
	}
	a2 := dataprovider.BaseEventAction{
		Name: xid.New().String(),
		Type: dataprovider.ActionTypeFilesystem,
		Options: dataprovider.BaseEventActionOptions{
			FsConfig: dataprovider.EventActionFilesystemConfig{
				Type: dataprovider.FilesystemActionRename,
				Renames: []dataprovider.RenameConfig{
					{
						KeyValue: dataprovider.KeyValue{
							Key:   "/{{VirtualDirPath}}/{{ObjectName}}",
							Value: "/{{ObjectName}}_renamed",
						},
					},
				},
			},
		},
	}
	// version 16 is older than the current dump version: migration applies
	backupData := dataprovider.BackupData{
		EventActions: []dataprovider.BaseEventAction{a1, a2},
		Version:      16,
	}
	backupContent, err := json.Marshal(backupData)
	assert.NoError(t, err)
	backupFilePath := filepath.Join(backupsPath, "backup.json")
	err = os.WriteFile(backupFilePath, backupContent, os.ModePerm)
assert.NoError(t, err)
	_, resp, err := httpdtest.Loaddata(backupFilePath, "1", "2", http.StatusOK)
	assert.NoError(t, err, string(resp))
	// Check that actions are migrated.
	action1, _, err := httpdtest.GetEventActionByName(a1.Name, http.StatusOK)
	assert.NoError(t, err)
	action2, _, err := httpdtest.GetEventActionByName(a2.Name, http.StatusOK)
	assert.NoError(t, err)
	// old "{{Name}}" placeholders must be converted to "{{.Name}}"
	assert.Equal(t, `Failed "{{.Event}}" from "{{.Name}}"`, action1.Options.EmailConfig.Subject)
	assert.Equal(t, `Object name: {{.ObjectName}} object type: {{.ObjectType}}, IP: {{.IP}}`, action1.Options.EmailConfig.Body)
	assert.Equal(t, `/{{.VirtualDirPath}}/{{.ObjectName}}`, action2.Options.FsConfig.Renames[0].Key)
	assert.Equal(t, `/{{.ObjectName}}_renamed`, action2.Options.FsConfig.Renames[0].Value)
	// If we restore a backup from the current version actions are not migrated.
	backupData = dataprovider.BackupData{
		EventActions: []dataprovider.BaseEventAction{a1, a2},
		Version:      dataprovider.DumpVersion,
	}
	backupContent, err = json.Marshal(backupData)
	assert.NoError(t, err)
	backupFilePath = filepath.Join(backupsPath, "backup.json")
	err = os.WriteFile(backupFilePath, backupContent, os.ModePerm)
	assert.NoError(t, err)
	_, resp, err = httpdtest.Loaddata(backupFilePath, "1", "2", http.StatusOK)
	assert.NoError(t, err, string(resp))
	action1, _, err = httpdtest.GetEventActionByName(a1.Name, http.StatusOK)
	assert.NoError(t, err)
	action2, _, err = httpdtest.GetEventActionByName(a2.Name, http.StatusOK)
	assert.NoError(t, err)
	// placeholders are preserved verbatim for a current-version backup
	assert.Equal(t, `Failed "{{Event}}" from "{{Name}}"`, action1.Options.EmailConfig.Subject)
	assert.Equal(t, `Object name: {{ObjectName}} object type: {{ObjectType}}, IP: {{IP}}`, action1.Options.EmailConfig.Body)
	assert.Equal(t, `/{{VirtualDirPath}}/{{ObjectName}}`, action2.Options.FsConfig.Renames[0].Key)
	assert.Equal(t, `/{{ObjectName}}_renamed`, action2.Options.FsConfig.Renames[0].Value)
	// Cleanup.
	_, err = httpdtest.RemoveEventAction(action1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveEventAction(action2, http.StatusOK)
	assert.NoError(t, err)
	actions, _, err := httpdtest.GetEventActions(0, 0, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, actions, 0)
}

// TestLoaddataMode verifies the loaddata "mode" parameter: mode 0 adds and
// updates, mode 1 adds without updating existing objects, mode 2 also
// disconnects users whose definition changed.
// NOTE(review): this function continues past the visible range; the
// remainder is outside this chunk.
func TestLoaddataMode(t *testing.T) {
	err := dataprovider.UpdateConfigs(nil, "", "", "")
	assert.NoError(t, err)
	mappedPath := filepath.Join(os.TempDir(), "restored_fold")
	folderName := filepath.Base(mappedPath)
	configs := dataprovider.Configs{
		SFTPD: &dataprovider.SFTPDConfigs{
			PublicKeyAlgos: []string{ssh.KeyAlgoRSA},
		},
	}
	role := getTestRole()
	role.ID = 1
	role.Name = "test_role_load"
	role.Description = ""
	user := getTestUser()
	user.ID = 1
	user.Username = "test_user_restore"
	user.Role = role.Name
	group := getTestGroup()
	group.ID = 1
	group.Name = "test_group_restore"
	user.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	admin := getTestAdmin()
	admin.ID = 1
	admin.Username = "test_admin_restore"
	apiKey := dataprovider.APIKey{
		Name:        util.GenerateUniqueID(),
		Scope:       dataprovider.APIKeyScopeAdmin,
		KeyID:       util.GenerateUniqueID(),
		Key:         fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()),
		Description: "desc",
	}
	share := dataprovider.Share{
		ShareID:  util.GenerateUniqueID(),
		Name:     util.GenerateUniqueID(),
		Scope:    dataprovider.ShareScopeRead,
		Paths:    []string{"/"},
		Username: user.Username,
	}
	action := dataprovider.BaseEventAction{
		ID:          81,
		Name:        "test_restore_action_data_mode",
		Description: "action desc",
		Type:        dataprovider.ActionTypeHTTP,
		Options: dataprovider.BaseEventActionOptions{
			HTTPConfig: dataprovider.EventActionHTTPConfig{
				Endpoint:      "https://localhost:4567/mode",
				Username:      defaultUsername,
				Password:      kms.NewPlainSecret(defaultPassword),
				Timeout:       10,
				SkipTLSVerify: true,
				Method:        http.MethodPost,
				Body:
`{"event":"{{.Event}}","name":"{{.Name}}"}`,
			},
		},
	}
	rule := dataprovider.EventRule{
		ID:          100,
		Name:        "test_rule_restore_data_mode",
		Description: "rule desc",
		Trigger:     dataprovider.EventTriggerFsEvent,
		Conditions: dataprovider.EventConditions{
			FsEvents: []string{"mkdir"},
		},
		Actions: []dataprovider.EventAction{
			{
				BaseEventAction: dataprovider.BaseEventAction{
					Name: action.Name,
				},
				Order: 1,
			},
		},
	}
	ipListEntry := dataprovider.IPListEntry{
		IPOrNet:     "10.8.3.9/32",
		Description: "note",
		Type:        dataprovider.IPListTypeDefender,
		Mode:        dataprovider.ListModeDeny,
		Protocols:   7,
	}
	backupData := dataprovider.BackupData{
		Version: dataprovider.DumpVersion,
	}
	backupData.Configs = &configs
	backupData.Users = append(backupData.Users, user)
	backupData.Groups = append(backupData.Groups, group)
	backupData.Admins = append(backupData.Admins, admin)
	backupData.EventActions = append(backupData.EventActions, action)
	backupData.EventRules = append(backupData.EventRules, rule)
	backupData.Roles = append(backupData.Roles, role)
	// duplicate folder name: the second definition wins on restore
	backupData.Folders = []vfs.BaseVirtualFolder{
		{
			Name:            folderName,
			MappedPath:      mappedPath,
			UsedQuotaSize:   123,
			UsedQuotaFiles:  456,
			LastQuotaUpdate: 789,
			Users:           []string{"user"},
		},
		{
			MappedPath: mappedPath + "1",
			Name:       folderName,
		},
	}
	backupData.APIKeys = append(backupData.APIKeys, apiKey)
	backupData.Shares = append(backupData.Shares, share)
	backupData.IPLists = append(backupData.IPLists, ipListEntry)
	backupContent, _ := json.Marshal(backupData)
	backupFilePath := filepath.Join(backupsPath, "backup.json")
	err = os.WriteFile(backupFilePath, backupContent, os.ModePerm)
	assert.NoError(t, err)
	// mode 0: add and update
	_, _, err = httpdtest.Loaddata(backupFilePath, "0", "0", http.StatusOK)
	assert.NoError(t, err)
	configs, err = dataprovider.GetConfigs()
	assert.NoError(t, err)
	assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1)
	folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, mappedPath+"1", folder.MappedPath)
	assert.Equal(t, int64(123), folder.UsedQuotaSize)
	assert.Equal(t, 456, folder.UsedQuotaFiles)
	assert.Equal(t, int64(789), folder.LastQuotaUpdate)
	assert.Len(t, folder.Users, 0)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, role.Name, user.Role)
	// locally modify each restored object so the next restores can show
	// whether mode 1 preserves or mode 0/2 overwrite the local changes
	oldUploadBandwidth := user.UploadBandwidth
	user.UploadBandwidth = oldUploadBandwidth + 128
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, role.Users, 1)
	assert.Len(t, role.Admins, 0)
	assert.Empty(t, role.Description)
	role.Description = "role desc"
	_, _, err = httpdtest.UpdateRole(role, http.StatusOK)
	assert.NoError(t, err)
	role.Description = ""
	group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, group.Users, 1)
	oldGroupDesc := group.Description
	group.Description = "new group description"
	group, _, err = httpdtest.UpdateGroup(group, http.StatusOK)
	assert.NoError(t, err)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	oldInfo := admin.AdditionalInfo
	oldDesc := admin.Description
	admin.AdditionalInfo = "newInfo"
	admin.Description = "newDesc"
	admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	apiKey, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusOK)
	assert.NoError(t, err)
	oldAPIKeyDesc := apiKey.Description
	apiKey.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now())
	apiKey.Description = "new desc"
	apiKey, _, err = httpdtest.UpdateAPIKey(apiKey, http.StatusOK)
	assert.NoError(t, err)
	share.Description = "test desc"
	err = dataprovider.UpdateShare(&share, "", "", "")
	assert.NoError(t, err)

	action, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK)
	assert.NoError(t, err)
	oldActionDesc := action.Description
	action.Description = "new action description"
	action, _, err = httpdtest.UpdateEventAction(action, http.StatusOK)
	assert.NoError(t, err)

	rule, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, 0, rule.Status)
	oldRuleDesc := rule.Description
	rule.Description = "new rule description"
	rule, _, err = httpdtest.UpdateEventRule(rule, http.StatusOK)
	assert.NoError(t, err)

	entry, _, err := httpdtest.GetIPListEntry(ipListEntry.IPOrNet, ipListEntry.Type, http.StatusOK)
	assert.NoError(t, err)
	oldEntryDesc := entry.Description
	entry.Description = "new note"
	entry, _, err = httpdtest.UpdateIPListEntry(entry, http.StatusOK)
	assert.NoError(t, err)

	configs.SFTPD.PublicKeyAlgos = append(configs.SFTPD.PublicKeyAlgos, ssh.InsecureKeyAlgoDSA) //nolint:staticcheck
	err = dataprovider.UpdateConfigs(&configs, "", "", "")
	assert.NoError(t, err)
	backupData.Configs = &configs
	backupData.Folders = []vfs.BaseVirtualFolder{
		{
			MappedPath: mappedPath,
			Name:       folderName,
		},
	}
	// mode 1: existing objects must not be updated, local changes survive
	_, _, err = httpdtest.Loaddata(backupFilePath, "0", "1", http.StatusOK)
	assert.NoError(t, err)
	configs, err = dataprovider.GetConfigs()
	assert.NoError(t, err)
	assert.Len(t, configs.SFTPD.PublicKeyAlgos, 2)
	group, _, err = httpdtest.GetGroupByName(group.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldGroupDesc, group.Description)
	folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, mappedPath+"1", folder.MappedPath)
	assert.Equal(t, int64(123), folder.UsedQuotaSize)
	assert.Equal(t, 456, folder.UsedQuotaFiles)
	assert.Equal(t, int64(789), folder.LastQuotaUpdate)
	assert.Len(t, folder.Users, 0)
	action, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldActionDesc, action.Description)
	rule, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldRuleDesc, rule.Description)
	entry, _, err = httpdtest.GetIPListEntry(ipListEntry.IPOrNet, ipListEntry.Type, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldEntryDesc, entry.Description)

	// register a fake open connection for the user so we can verify that
	// mode 2 closes it when the user definition is restored
	c := common.NewBaseConnection("connID", common.ProtocolFTP, "", "", user)
	fakeConn := &fakeConnection{
		BaseConnection: c,
	}
	err = common.Connections.Add(fakeConn)
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 1)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldUploadBandwidth, user.UploadBandwidth)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, oldInfo, admin.AdditionalInfo)
	assert.NotEqual(t, oldDesc, admin.Description)

	apiKey, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusOK)
	assert.NoError(t, err)
	assert.NotEqual(t, int64(0), apiKey.ExpiresAt)
	assert.NotEqual(t, oldAPIKeyDesc, apiKey.Description)

	share, err = dataprovider.ShareExists(share.ShareID, user.Username)
	assert.NoError(t, err)
	assert.NotEmpty(t, share.Description)

	role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK)
	assert.NoError(t, err)
	assert.Len(t, role.Users, 1)
	assert.Len(t, role.Admins, 0)
	assert.NotEmpty(t, role.Description)

	_, _, err = httpdtest.Loaddata(backupFilePath, "0", "2", http.StatusOK)
	assert.NoError(t, err)
	// mode 2 will update the user and close the previous connection
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, oldUploadBandwidth, user.UploadBandwidth)
	configs, err =
dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1) + // the group is referenced + _, err = httpdtest.RemoveGroup(group, http.StatusBadRequest) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveRole(role, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK) + assert.NoError(t, err) + err = os.Remove(backupFilePath) + assert.NoError(t, err) + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) +} + +func TestRateLimiter(t *testing.T) { + oldConfig := config.GetCommonConfig() + + cfg := config.GetCommonConfig() + cfg.RateLimitersConfig = []common.RateLimiterConfig{ + { + Average: 1, + Period: 1000, + Burst: 1, + Type: 1, + Protocols: []string{common.ProtocolHTTP}, + }, + } + + err := common.Initialize(cfg, 0) + assert.NoError(t, err) + + client := &http.Client{ + Timeout: 5 * time.Second, + } + resp, err := client.Get(httpBaseURL + healthzPath) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + err = resp.Body.Close() + assert.NoError(t, err) + + resp, err = client.Get(httpBaseURL + healthzPath) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + assert.Equal(t, "1", resp.Header.Get("Retry-After")) + assert.NotEmpty(t, resp.Header.Get("X-Retry-In")) + err = resp.Body.Close() + assert.NoError(t, err) + + resp, err = 
client.Get(httpBaseURL + webLoginPath) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + assert.Equal(t, "1", resp.Header.Get("Retry-After")) + assert.NotEmpty(t, resp.Header.Get("X-Retry-In")) + err = resp.Body.Close() + assert.NoError(t, err) + + resp, err = client.Get(httpBaseURL + webClientLoginPath) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + assert.Equal(t, "1", resp.Header.Get("Retry-After")) + assert.NotEmpty(t, resp.Header.Get("X-Retry-In")) + err = resp.Body.Close() + assert.NoError(t, err) + + err = common.Initialize(oldConfig, 0) + assert.NoError(t, err) +} + +func TestHTTPSConnection(t *testing.T) { + client := &http.Client{ + Timeout: 5 * time.Second, + } + resp, err := client.Get("https://localhost:8443" + healthzPath) + if assert.Error(t, err) { + if !strings.Contains(err.Error(), "certificate is not valid") && + !strings.Contains(err.Error(), "certificate signed by unknown authority") && + !strings.Contains(err.Error(), "certificate is not standards compliant") { + assert.Fail(t, err.Error()) + } + } else { + resp.Body.Close() + } +} + +// test using mock http server + +func TestBasicUserHandlingMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusConflict, rr) + user.MaxSessions = 10 + user.UploadBandwidth = 128 + user.Permissions["/"] = 
[]string{dataprovider.PermAny, dataprovider.PermDelete, dataprovider.PermDownload} + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, userPath+"/"+user.Username, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, userPath+"/"+user.Username, nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + var updatedUser dataprovider.User + err = render.DecodeJSON(rr.Body, &updatedUser) + assert.NoError(t, err) + assert.Equal(t, user.MaxSessions, updatedUser.MaxSessions) + assert.Equal(t, user.UploadBandwidth, updatedUser.UploadBandwidth) + assert.Equal(t, 1, len(updatedUser.Permissions["/"])) + assert.True(t, slices.Contains(updatedUser.Permissions["/"], dataprovider.PermAny)) + req, _ = http.NewRequest(http.MethodDelete, userPath+"/"+user.Username, nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestAddUserNoUsernameMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + user.Username = "" + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) +} + +func TestAddUserInvalidHomeDirMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + user.HomeDir = "relative_path" + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) +} + +func 
TestAddUserInvalidPermsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + user.Permissions["/"] = []string{} + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) +} + +func TestAddFolderInvalidJsonMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer([]byte("invalid json"))) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) +} + +func TestAddEventRuleInvalidJsonMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, eventActionsPath, bytes.NewBuffer([]byte("invalid json"))) + require.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + req, err = http.NewRequest(http.MethodPost, eventRulesPath, bytes.NewBuffer([]byte("invalid json"))) + require.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) +} + +func TestAddRoleInvalidJsonMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, rolesPath, bytes.NewBuffer([]byte("{"))) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) +} + +func TestIPListEntriesErrorsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + 
assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, ipListsPath+"/a/b", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "invalid list type") + req, err = http.NewRequest(http.MethodGet, ipListsPath+"/invalid", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "invalid list type") + + reqBody := bytes.NewBuffer([]byte("{")) + req, err = http.NewRequest(http.MethodPost, ipListsPath+"/2", reqBody) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + entry := dataprovider.IPListEntry{ + IPOrNet: "172.120.1.1/32", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 0, + } + _, _, err = httpdtest.AddIPListEntry(entry, http.StatusCreated) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPut, path.Join(ipListsPath, "1", url.PathEscape(entry.IPOrNet)), reqBody) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + _, err = httpdtest.RemoveIPListEntry(entry, http.StatusOK) + assert.NoError(t, err) +} + +func TestRoleErrorsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + reqBody := bytes.NewBuffer([]byte("{")) + req, err := http.NewRequest(http.MethodGet, rolesPath+"?limit=a", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + role, _, err := httpdtest.AddRole(getTestRole(), http.StatusCreated) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPut, path.Join(rolesPath, role.Name), reqBody) + 
assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodPut, path.Join(rolesPath, "missing_role"), reqBody) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + _, err = httpdtest.RemoveRole(role, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveRole(role, http.StatusNotFound) + assert.NoError(t, err) +} + +func TestEventRuleErrorsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + reqBody := bytes.NewBuffer([]byte("invalid json body")) + + req, err := http.NewRequest(http.MethodGet, eventActionsPath+"?limit=b", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, eventRulesPath+"?limit=c", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + a := dataprovider.BaseEventAction{ + Name: "action_name", + Description: "test description", + Type: dataprovider.ActionTypeBackup, + Options: dataprovider.BaseEventActionOptions{}, + } + action, _, err := httpdtest.AddEventAction(a, http.StatusCreated) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPut, path.Join(eventActionsPath, action.Name), reqBody) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + r := dataprovider.EventRule{ + Name: "test_event_rule", + Trigger: dataprovider.EventTriggerSchedule, + Conditions: dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + Hours: "2", + DayOfWeek: "*", + DayOfMonth: "*", + Month: "*", + }, + }, + }, + Actions: []dataprovider.EventAction{ + { + 
BaseEventAction: dataprovider.BaseEventAction{ + Name: action.Name, + }, + Order: 1, + }, + }, + } + rule, _, err := httpdtest.AddEventRule(r, http.StatusCreated) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPut, path.Join(eventRulesPath, rule.Name), reqBody) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + rule.Actions[0].Name = "misssing action name" + asJSON, err := json.Marshal(rule) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, path.Join(eventRulesPath, rule.Name), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + _, err = httpdtest.RemoveEventRule(rule, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveEventAction(action, http.StatusOK) + assert.NoError(t, err) +} + +func TestGroupErrorsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + reqBody := bytes.NewBuffer([]byte("not a json string")) + + req, err := http.NewRequest(http.MethodPost, groupPath, reqBody) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, groupPath+"?limit=d", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + group, _, err := httpdtest.AddGroup(getTestGroup(), http.StatusCreated) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPut, path.Join(groupPath, group.Name), reqBody) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) +} + +func 
TestUpdateFolderInvalidJsonMock(t *testing.T) { + folder := vfs.BaseVirtualFolder{ + Name: "name", + MappedPath: filepath.Clean(os.TempDir()), + } + folder, resp, err := httpdtest.AddFolder(folder, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPut, path.Join(folderPath, folder.Name), bytes.NewBuffer([]byte("not a json"))) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) +} + +func TestAddUserInvalidJsonMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer([]byte("invalid json"))) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) +} + +func TestAddAdminInvalidJsonMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer([]byte("..."))) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) +} + +func TestAddAdminNoPasswordMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + admin := getTestAdmin() + admin.Password = "" + asJSON, err := json.Marshal(admin) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "please set a password") +} + +func TestAdminTwoFactorLogin(t 
*testing.T) { + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Password = altAdminPassword + admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) + assert.NoError(t, err) + admin1 := getTestAdmin() + admin1.Username = altAdminUsername + "1" + admin1.Password = altAdminPassword + var permissions []string + for _, p := range admin1.GetValidPerms() { + if p != dataprovider.PermAdminAny && p != dataprovider.PermAdminDisableMFA { + permissions = append(permissions, p) + } + } + admin1.Permissions = permissions + admin1, _, err = httpdtest.AddAdmin(admin1, http.StatusCreated) + assert.NoError(t, err) + // enable two factor authentication + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], admin.Username) + assert.NoError(t, err) + altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword) + assert.NoError(t, err) + adminTOTPConfig := dataprovider.AdminTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + } + asJSON, err := json.Marshal(adminTOTPConfig) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, admin.Filters.TOTPConfig.Enabled) + + req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var recCodes []recoveryCode + err = json.Unmarshal(rr.Body.Bytes(), &recCodes) + assert.NoError(t, err) + assert.Len(t, recCodes, 12) + + admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, admin.Filters.RecoveryCodes, 12) + 
for _, c := range admin.Filters.RecoveryCodes { + assert.Empty(t, c.Secret.GetAdditionalData()) + assert.Empty(t, c.Secret.GetKey()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, c.Secret.GetStatus()) + assert.NotEmpty(t, c.Secret.GetPayload()) + } + + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorRecoveryPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form := getLoginForm(altAdminUsername, altAdminPassword, csrfToken) + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) + cookie, err := getCookieFromResponse(rr) + assert.NoError(t, err) + + // without a cookie + req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorRecoveryPath, nil) + 
assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webAdminTwoFactorRecoveryPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + // any other page will be redirected to the two factor auth page + req, err = http.NewRequest(http.MethodGet, webUsersPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) + // a partial token cannot be used for user pages + req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + + passcode, err := generateTOTPPasscode(key.Secret()) + assert.NoError(t, err) + form = make(url.Values) + form.Set("passcode", passcode) + // no csrf + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + form.Set("passcode", "invalid_passcode") + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, 
bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + form.Set("passcode", "") + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + form.Set("passcode", passcode) + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webUsersPath, rr.Header().Get("Location")) + // the same cookie cannot be reused + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusNotFound, rr.Code) + // get a new cookie and login using a recovery code + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + 
setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) + cookie, err = getCookieFromResponse(rr) + assert.NoError(t, err) + + form = make(url.Values) + recoveryCode := recCodes[0].Code + form.Set("recovery_code", recoveryCode) + // no csrf + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + form.Set("recovery_code", "") + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + form.Set("recovery_code", recoveryCode) + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webUsersPath, rr.Header().Get("Location")) + authenticatedCookie, err := getCookieFromResponse(rr) + assert.NoError(t, 
err) + //render MFA page + req, err = http.NewRequest(http.MethodGet, webAdminMFAPath, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, authenticatedCookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // check that the recovery code was marked as used + req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + recCodes = nil + err = json.Unmarshal(rr.Body.Bytes(), &recCodes) + assert.NoError(t, err) + assert.Len(t, recCodes, 12) + found := false + for _, rc := range recCodes { + if rc.Code == recoveryCode { + found = true + assert.True(t, rc.Used) + } else { + assert.False(t, rc.Used) + } + } + assert.True(t, found) + // the same recovery code cannot be reused + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) + cookie, err = getCookieFromResponse(rr) + assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form = make(url.Values) + form.Set("recovery_code", recoveryCode) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + form.Set("recovery_code", "invalid_recovery_code") + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) + cookie, err = getCookieFromResponse(rr) + assert.NoError(t, err) + + // disable TOTP + altToken1, err := getJWTAPITokenFromTestServer(admin1.Username, altAdminPassword) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil) + assert.NoError(t, err) + setBearerForReq(req, altToken1) + 
rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "two-factor authentication is not enabled") + + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form = make(url.Values) + form.Set("recovery_code", recoveryCode) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18n2FADisabled) + + form.Set("passcode", passcode) + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18n2FADisabled) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveAdmin(admin1, http.StatusOK) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) 
+ setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + req, err = http.NewRequest(http.MethodGet, webAdminMFAPath, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, authenticatedCookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) +} + +func TestAdminTOTP(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Password = altAdminPassword + // TOTPConfig will be ignored on add + admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{ + Enabled: true, + ConfigName: "config", + Secret: kms.NewEmptySecret(), + } + asJSON, err := json.Marshal(admin) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK) + assert.NoError(t, err) + assert.False(t, admin.Filters.TOTPConfig.Enabled) + assert.Len(t, admin.Filters.RecoveryCodes, 0) + + altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword) + assert.NoError(t, err) + + req, err 
= http.NewRequest(http.MethodGet, adminTOTPConfigsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var configs []mfa.TOTPConfig + err = json.Unmarshal(rr.Body.Bytes(), &configs) + assert.NoError(t, err, rr.Body.String()) + assert.Len(t, configs, len(mfa.GetAvailableTOTPConfigs())) + totpConfig := configs[0] + totpReq := generateTOTPRequest{ + ConfigName: totpConfig.Name, + } + asJSON, err = json.Marshal(totpReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, adminTOTPGeneratePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var totpGenResp generateTOTPResponse + err = json.Unmarshal(rr.Body.Bytes(), &totpGenResp) + assert.NoError(t, err) + assert.NotEmpty(t, totpGenResp.Secret) + assert.NotEmpty(t, totpGenResp.QRCode) + + passcode, err := generateTOTPPasscode(totpGenResp.Secret) + assert.NoError(t, err) + validateReq := validateTOTPRequest{ + ConfigName: totpGenResp.ConfigName, + Passcode: passcode, + Secret: totpGenResp.Secret, + } + asJSON, err = json.Marshal(validateReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, adminTOTPValidatePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // the same passcode cannot be reused + req, err = http.NewRequest(http.MethodPost, adminTOTPValidatePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "this passcode was already used") + + adminTOTPConfig := dataprovider.AdminTOTPConfig{ + Enabled: true, + ConfigName: totpGenResp.ConfigName, + Secret: kms.NewPlainSecret(totpGenResp.Secret), + } + asJSON, err = 
json.Marshal(adminTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK) + assert.NoError(t, err) + assert.True(t, admin.Filters.TOTPConfig.Enabled) + assert.Equal(t, totpGenResp.ConfigName, admin.Filters.TOTPConfig.ConfigName) + assert.Empty(t, admin.Filters.TOTPConfig.Secret.GetKey()) + assert.Empty(t, admin.Filters.TOTPConfig.Secret.GetAdditionalData()) + assert.NotEmpty(t, admin.Filters.TOTPConfig.Secret.GetPayload()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, admin.Filters.TOTPConfig.Secret.GetStatus()) + admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{ + Enabled: false, + ConfigName: util.GenerateUniqueID(), + Secret: kms.NewEmptySecret(), + } + admin.Filters.RecoveryCodes = []dataprovider.RecoveryCode{ + { + Secret: kms.NewEmptySecret(), + }, + } + admin, resp, err := httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err, string(resp)) + assert.True(t, admin.Filters.TOTPConfig.Enabled) + assert.Len(t, admin.Filters.RecoveryCodes, 12) + // if we use token we cannot get or generate recovery codes + req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + // now the same but with altToken + req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var recCodes 
[]recoveryCode + err = json.Unmarshal(rr.Body.Bytes(), &recCodes) + assert.NoError(t, err) + assert.Len(t, recCodes, 12) + // regenerate recovery codes + req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // check that recovery codes are different + req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var newRecCodes []recoveryCode + err = json.Unmarshal(rr.Body.Bytes(), &newRecCodes) + assert.NoError(t, err) + assert.Len(t, newRecCodes, 12) + assert.NotEqual(t, recCodes, newRecCodes) + // disable 2FA, the update admin API should not work + admin.Filters.TOTPConfig.Enabled = false + admin.Filters.RecoveryCodes = nil + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, altAdminUsername, admin.Username) + assert.True(t, admin.Filters.TOTPConfig.Enabled) + assert.Len(t, admin.Filters.RecoveryCodes, 12) + // use the dedicated API + req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK) + assert.NoError(t, err) + assert.False(t, admin.Filters.TOTPConfig.Enabled) + assert.Len(t, admin.Filters.RecoveryCodes, 0) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(adminPath, altAdminUsername), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodPut, adminPath+"/"+altAdminUsername+"/2fa/disable", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, admin2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, admin2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestChangeAdminPwdInvalidJsonMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPut, adminPwdPath, bytes.NewBuffer([]byte("{"))) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) +} + +func TestSMTPConfig(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 3525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + smtpTestURL := path.Join(webConfigsPath, "smtp", "test") + tokenHeader := "X-CSRF-TOKEN" + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer([]byte("{"))) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req.Header.Set(tokenHeader, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + testReq := 
make(map[string]any) + testReq["host"] = smtpCfg.Host + testReq["port"] = 3525 + testReq["from"] = "from@example.com" + asJSON, err := json.Marshal(testReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set(tokenHeader, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + testReq["recipient"] = "example@example.com" + asJSON, err = json.Marshal(testReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set(tokenHeader, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + configs := dataprovider.Configs{ + SMTP: &dataprovider.SMTPConfigs{ + Host: "127.0.0.1", + Port: 3535, + User: "user@example.com", + Password: kms.NewPlainSecret(defaultPassword), + }, + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + + testReq["password"] = redactedSecret + asJSON, err = json.Marshal(testReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set(tokenHeader, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + assert.Contains(t, rr.Body.String(), "server does not support SMTP AUTH") + + testReq["password"] = "" + testReq["auth_type"] = 3 + testReq["oauth2"] = smtp.OAuth2Config{ + ClientSecret: redactedSecret, + RefreshToken: redactedSecret, + } + asJSON, err = json.Marshal(testReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, smtpTestURL, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set(tokenHeader, csrfToken) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "smtp oauth2: client id is required") + + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) +} + +func TestOAuth2TokenRequest(t *testing.T) { + tokenHeader := "X-CSRF-TOKEN" + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer([]byte("{"))) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req.Header.Set(tokenHeader, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + testReq := make(map[string]any) + testReq["client_secret"] = redactedSecret + asJSON, err := json.Marshal(testReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set(tokenHeader, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "base redirect url is required") + + testReq["base_redirect_url"] = "http://localhost:8081" + asJSON, err = json.Marshal(testReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set(tokenHeader, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestMFAPermission(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + webToken, err := 
getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webClientMFAPath, nil) + assert.NoError(t, err) + req.RequestURI = webClientMFAPath + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + user.Filters.WebClient = []string{sdk.WebClientMFADisabled} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientMFAPath, nil) + assert.NoError(t, err) + req.RequestURI = webClientMFAPath + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebUserTwoFactorLogin(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + // enable two factor authentication + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + adminToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + userTOTPConfig := dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolHTTP}, + } + asJSON, err := json.Marshal(userTOTPConfig) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) 
+ assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var recCodes []recoveryCode + err = json.Unmarshal(rr.Body.Bytes(), &recCodes) + assert.NoError(t, err) + assert.Len(t, recCodes, 12) + + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Filters.RecoveryCodes, 12) + for _, c := range user.Filters.RecoveryCodes { + assert.Empty(t, c.Secret.GetAdditionalData()) + assert.Empty(t, c.Secret.GetKey()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, c.Secret.GetStatus()) + assert.NotEmpty(t, c.Secret.GetPayload()) + } + + req, err = http.NewRequest(http.MethodGet, webClientTwoFactorPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, webClientTwoFactorRecoveryPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form := getLoginForm(defaultUsername, defaultPassword, csrfToken) + // CSRF verification fails if there is no cookie + req, err = 
http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + cookie, err := getCookieFromResponse(rr) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientTwoFactorPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + // without a cookie + req, err = http.NewRequest(http.MethodGet, webClientTwoFactorPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // invalid IP address + req, err = http.NewRequest(http.MethodGet, webClientTwoFactorPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = "6.7.8.9:4567" + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, webClientTwoFactorRecoveryPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // any other page will be redirected to the two factor auth page + req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + 
assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + // a partial token cannot be used for admin pages + req, err = http.NewRequest(http.MethodGet, webUsersPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) + + passcode, err := generateTOTPPasscode(key.Secret()) + assert.NoError(t, err) + form = make(url.Values) + form.Set("passcode", passcode) + + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorPath, cookie) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + form.Set("passcode", "invalid_user_passcode") + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + form.Set("passcode", "") + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + 
form.Set("passcode", passcode) + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientFilesPath, rr.Header().Get("Location")) + // the same cookie cannot be reused + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusNotFound, rr.Code) + // get a new cookie and login using a recovery code + loginCookie, csrfToken, err = getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + cookie, err = getCookieFromResponse(rr) + assert.NoError(t, err) + + form = make(url.Values) + recoveryCode := recCodes[0].Code + form.Set("recovery_code", recoveryCode) + // no csrf + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + 
assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + form.Set("recovery_code", "") + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + form.Set("recovery_code", recoveryCode) + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientFilesPath, rr.Header().Get("Location")) + authenticatedCookie, err := getCookieFromResponse(rr) + assert.NoError(t, err) + //render MFA page + req, err = http.NewRequest(http.MethodGet, webClientMFAPath, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, authenticatedCookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // get MFA qrcode + req, err = http.NewRequest(http.MethodGet, path.Join(webClientMFAPath, "qrcode?url="+url.QueryEscape(key.URL())), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, authenticatedCookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "image/png", rr.Header().Get("Content-Type")) + // invalid MFA url + req, err = http.NewRequest(http.MethodGet, path.Join(webClientMFAPath, 
"qrcode?url="+url.QueryEscape("http://foo\x7f.eu")), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, authenticatedCookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + // check that the recovery code was marked as used + req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + recCodes = nil + err = json.Unmarshal(rr.Body.Bytes(), &recCodes) + assert.NoError(t, err) + assert.Len(t, recCodes, 12) + found := false + for _, rc := range recCodes { + if rc.Code == recoveryCode { + found = true + assert.True(t, rc.Used) + } else { + assert.False(t, rc.Used) + } + } + assert.True(t, found) + // the same recovery code cannot be reused + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + cookie, err = getCookieFromResponse(rr) + assert.NoError(t, err) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form = make(url.Values) + form.Set("recovery_code", recoveryCode) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + form.Set("recovery_code", "invalid_user_recovery_code") + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + cookie, err = getCookieFromResponse(rr) + assert.NoError(t, err) + + // disable TOTP + req, err = http.NewRequest(http.MethodPut, userPath+"/"+user.Username+"/2fa/disable", nil) + assert.NoError(t, err) + setBearerForReq(req, adminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodPut, userPath+"/"+user.Username+"/2fa/disable", nil) + assert.NoError(t, err) + setBearerForReq(req, adminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "two-factor authentication is not enabled") + + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) 
+ form = make(url.Values) + form.Set("recovery_code", recoveryCode) + form.Set("passcode", passcode) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18n2FADisabled) + + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18n2FADisabled) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + req, err = http.NewRequest(http.MethodGet, 
webClientMFAPath, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, authenticatedCookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) +} + +func TestWebUserTwoFactoryLoginRedirect(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + userTOTPConfig := dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolHTTP}, + } + asJSON, err := json.Marshal(userTOTPConfig) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form := getLoginForm(defaultUsername, defaultPassword, csrfToken) + uri := webClientFilesPath + "?path=%2F" + loginURI := webClientLoginPath + "?next=" + url.QueryEscape(uri) + expectedURI := webClientTwoFactorPath + "?next=" + url.QueryEscape(uri) + req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = loginURI + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, expectedURI, rr.Header().Get("Location")) + cookie, err := getCookieFromResponse(rr) + assert.NoError(t, err) + // test unsafe redirects + loginCookie, csrfToken, err = 
getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken) + externalURI := webClientLoginPath + "?next=" + url.QueryEscape("https://example.com") + req, err = http.NewRequest(http.MethodPost, externalURI, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = externalURI + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken) + internalURI := webClientLoginPath + "?next=" + url.QueryEscape(webClientMFAPath) + req, err = http.NewRequest(http.MethodPost, internalURI, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = internalURI + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + // render two factor page + req, err = http.NewRequest(http.MethodGet, expectedURI, nil) + assert.NoError(t, err) + req.RequestURI = expectedURI + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), fmt.Sprintf("action=%q", expectedURI)) + // login with the passcode + csrfToken, err = getCSRFTokenFromInternalPageMock(expectedURI, cookie) + assert.NoError(t, err) + passcode, err := generateTOTPPasscode(key.Secret()) + assert.NoError(t, err) + form = make(url.Values) + form.Set("passcode", passcode) + form.Set(csrfFormToken, csrfToken) 
+ req, err = http.NewRequest(http.MethodPost, expectedURI, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = expectedURI + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, uri, rr.Header().Get("Location")) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSearchEvents(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, fsEventsPath+"?limit=10&order=ASC&fs_provider=0", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + events := make([]map[string]any, 0) + err = json.Unmarshal(rr.Body.Bytes(), &events) + assert.NoError(t, err) + if assert.Len(t, events, 1) { + ev := events[0] + for _, field := range []string{"id", "timestamp", "action", "username", "fs_path", "status", "protocol", + "ip", "session_id", "fs_provider", "bucket", "endpoint", "open_flags", "role", "instance_id"} { + _, ok := ev[field] + assert.True(t, ok, field) + } + } + req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?limit=10&order=ASC&role=role1", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + events = nil + err = json.Unmarshal(rr.Body.Bytes(), &events) + assert.NoError(t, err) + assert.Len(t, events, 1) + // CSV export + req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?limit=10&order=ASC&csv_export=true", nil) + assert.NoError(t, err) + setBearerForReq(req, 
token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "text/csv", rr.Header().Get("Content-Type")) + // the test eventsearcher plugin returns error if start_timestamp < 0 + req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?start_timestamp=-1&end_timestamp=123456&statuses=1,2", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + // CSV export with error + exportFunc := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + + req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?start_timestamp=-2&csv_export=true", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + } + exportFunc() + + req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?limit=e", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, providerEventsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + events = make([]map[string]any, 0) + err = json.Unmarshal(rr.Body.Bytes(), &events) + assert.NoError(t, err) + if assert.Len(t, events, 1) { + ev := events[0] + for _, field := range []string{"id", "timestamp", "action", "username", "object_type", "object_name", + "object_data", "role", "instance_id"} { + _, ok := ev[field] + assert.True(t, ok, field) + } + } + req, err = http.NewRequest(http.MethodGet, providerEventsPath+"?omit_object_data=true", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + events = make([]map[string]any, 0) + err = json.Unmarshal(rr.Body.Bytes(), &events) + assert.NoError(t, err) + if assert.Len(t, events, 1) { + ev := events[0] + field := "object_data" + _, ok 
:= ev[field] + assert.False(t, ok, field) + } + // CSV export + req, err = http.NewRequest(http.MethodGet, providerEventsPath+"?csv_export=true", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "text/csv", rr.Header().Get("Content-Type")) + + // the test eventsearcher plugin returns error if start_timestamp < 0 + req, err = http.NewRequest(http.MethodGet, providerEventsPath+"?start_timestamp=-1", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + // CSV export with error + exportFunc = func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + + req, err = http.NewRequest(http.MethodGet, providerEventsPath+"?start_timestamp=-1&csv_export=true", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + } + exportFunc() + + req, err = http.NewRequest(http.MethodGet, logEventsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + events = make([]map[string]any, 0) + err = json.Unmarshal(rr.Body.Bytes(), &events) + assert.NoError(t, err) + if assert.Len(t, events, 1) { + ev := events[0] + for _, field := range []string{"id", "timestamp", "event", "ip", "message", "role", "instance_id"} { + _, ok := ev[field] + assert.True(t, ok, field) + } + } + req, err = http.NewRequest(http.MethodGet, logEventsPath+"?events=a,1", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // CSV export + req, err = http.NewRequest(http.MethodGet, logEventsPath+"?csv_export=true", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "text/csv", rr.Header().Get("Content-Type")) + 
// the test eventsearcher plugin returns error if start_timestamp < 0 + req, err = http.NewRequest(http.MethodGet, logEventsPath+"?start_timestamp=-1", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + // CSV export with error + exportFunc = func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + + req, err = http.NewRequest(http.MethodGet, logEventsPath+"?start_timestamp=-1&csv_export=true", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + } + exportFunc() + + req, err = http.NewRequest(http.MethodGet, providerEventsPath+"?limit=2000", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, logEventsPath+"?limit=2000", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?start_timestamp=a", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?end_timestamp=a", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?order=ASSC", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, fsEventsPath+"?statuses=a,b", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, 
fsEventsPath+"?fs_provider=a", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, webEventsPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestMFAErrors(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + assert.False(t, user.Filters.TOTPConfig.Enabled) + userToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + adminToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + + // invalid config name + totpReq := generateTOTPRequest{ + ConfigName: "invalid config name", + } + asJSON, err := json.Marshal(totpReq) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userTOTPGeneratePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, userToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + // invalid JSON + invalidJSON := []byte("not a JSON") + req, err = http.NewRequest(http.MethodPost, userTOTPGeneratePath, bytes.NewBuffer(invalidJSON)) + assert.NoError(t, err) + setBearerForReq(req, userToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(invalidJSON)) + assert.NoError(t, err) + setBearerForReq(req, userToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(invalidJSON)) + assert.NoError(t, err) + setBearerForReq(req, adminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodPost, 
adminTOTPValidatePath, bytes.NewBuffer(invalidJSON)) + assert.NoError(t, err) + setBearerForReq(req, adminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + // invalid TOTP config name + userTOTPConfig := dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: "missing name", + Secret: kms.NewPlainSecret(xid.New().String()), + Protocols: []string{common.ProtocolSSH}, + } + asJSON, err = json.Marshal(userTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, userToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "totp: config name") + // invalid TOTP secret + userTOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: mfa.GetAvailableTOTPConfigNames()[0], + Secret: nil, + Protocols: []string{common.ProtocolSSH}, + } + asJSON, err = json.Marshal(userTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, userToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "totp: secret is mandatory") + // no protocol + userTOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: mfa.GetAvailableTOTPConfigNames()[0], + Secret: kms.NewPlainSecret(xid.New().String()), + Protocols: nil, + } + asJSON, err = json.Marshal(userTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, userToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "totp: specify at least one protocol") + // invalid protocol + userTOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: 
true, + ConfigName: mfa.GetAvailableTOTPConfigNames()[0], + Secret: kms.NewPlainSecret(xid.New().String()), + Protocols: []string{common.ProtocolWebDAV}, + } + asJSON, err = json.Marshal(userTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, userToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "totp: invalid protocol") + + adminTOTPConfig := dataprovider.AdminTOTPConfig{ + Enabled: true, + ConfigName: "", + Secret: kms.NewPlainSecret("secret"), + } + asJSON, err = json.Marshal(adminTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, adminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "totp: config name is mandatory") + + adminTOTPConfig = dataprovider.AdminTOTPConfig{ + Enabled: true, + ConfigName: mfa.GetAvailableTOTPConfigNames()[0], + Secret: nil, + } + asJSON, err = json.Marshal(adminTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, adminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "totp: secret is mandatory") + + // invalid TOTP secret status + userTOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: mfa.GetAvailableTOTPConfigNames()[0], + Secret: kms.NewSecret(sdkkms.SecretStatusRedacted, "", "", ""), + Protocols: []string{common.ProtocolSSH}, + } + asJSON, err = json.Marshal(userTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, 
userToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + // previous secret will be preserved and we have no secret saved + assert.Contains(t, rr.Body.String(), "totp: secret is mandatory") + + req, err = http.NewRequest(http.MethodPost, adminTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, adminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "totp: secret is mandatory") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestMFAInvalidSecret(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + userToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + user.Password = defaultPassword + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: mfa.GetAvailableTOTPConfigNames()[0], + Secret: kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", user.Username), + Protocols: []string{common.ProtocolSSH, common.ProtocolHTTP}, + } + user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, dataprovider.RecoveryCode{ + Used: false, + Secret: kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", user.Username), + }) + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, userToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + assert.Contains(t, rr.Body.String(), "Unable to decrypt recovery codes") + + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form := getLoginForm(defaultUsername, defaultPassword, csrfToken) + 
req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientTwoFactorPath, rr.Header().Get("Location")) + cookie, err := getCookieFromResponse(rr) + assert.NoError(t, err) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorPath, cookie) + assert.NoError(t, err) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("passcode", "123456") + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("recovery_code", "RC-123456") + req, err = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, userTokenPath), nil) + assert.NoError(t, err) + req.Header.Set("X-SFTPGO-OTP", "authcode") + req.SetBasicAuth(defaultUsername, defaultPassword) + resp, err := httpclient.GetHTTPClient().Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, 
resp.StatusCode) + err = resp.Body.Close() + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Password = altAdminPassword + admin, _, err = httpdtest.AddAdmin(admin, http.StatusCreated) + assert.NoError(t, err) + + admin.Password = altAdminPassword + admin.Filters.TOTPConfig = dataprovider.AdminTOTPConfig{ + Enabled: true, + ConfigName: mfa.GetAvailableTOTPConfigNames()[0], + Secret: kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", user.Username), + } + admin.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, dataprovider.RecoveryCode{ + Used: false, + Secret: kms.NewSecret(sdkkms.SecretStatusSecretBox, "payload", "key", user.Username), + }) + err = dataprovider.UpdateAdmin(&admin, "", "", "") + assert.NoError(t, err) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) + req, err = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminTwoFactorPath, rr.Header().Get("Location")) + cookie, err = getCookieFromResponse(rr) + assert.NoError(t, err) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("passcode", "123456") + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = 
defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminTwoFactorRecoveryPath, cookie) + assert.NoError(t, err) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("recovery_code", "RC-123456") + req, err = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, cookie) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v%v", httpBaseURL, tokenPath), nil) + assert.NoError(t, err) + req.Header.Set("X-SFTPGO-OTP", "auth-code") + req.SetBasicAuth(altAdminUsername, altAdminPassword) + resp, err = httpclient.GetHTTPClient().Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + err = resp.Body.Close() + assert.NoError(t, err) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) +} + +func TestWebUserTOTP(t *testing.T) { + u := getTestUser() + // TOTPConfig will be ignored on add + u.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: "", + Secret: kms.NewEmptySecret(), + Protocols: []string{common.ProtocolSSH}, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + assert.False(t, user.Filters.TOTPConfig.Enabled) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, userTOTPConfigsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var configs 
[]mfa.TOTPConfig + err = json.Unmarshal(rr.Body.Bytes(), &configs) + assert.NoError(t, err, rr.Body.String()) + assert.Len(t, configs, len(mfa.GetAvailableTOTPConfigs())) + totpConfig := configs[0] + totpReq := generateTOTPRequest{ + ConfigName: totpConfig.Name, + } + asJSON, err := json.Marshal(totpReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userTOTPGeneratePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var totpGenResp generateTOTPResponse + err = json.Unmarshal(rr.Body.Bytes(), &totpGenResp) + assert.NoError(t, err) + assert.NotEmpty(t, totpGenResp.Secret) + assert.NotEmpty(t, totpGenResp.QRCode) + + passcode, err := generateTOTPPasscode(totpGenResp.Secret) + assert.NoError(t, err) + validateReq := validateTOTPRequest{ + ConfigName: totpGenResp.ConfigName, + Passcode: passcode, + Secret: totpGenResp.Secret, + } + asJSON, err = json.Marshal(validateReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userTOTPValidatePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // the same passcode cannot be reused + req, err = http.NewRequest(http.MethodPost, userTOTPValidatePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "this passcode was already used") + + userTOTPConfig := dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: totpGenResp.ConfigName, + Secret: kms.NewPlainSecret(totpGenResp.Secret), + Protocols: []string{common.ProtocolSSH}, + } + asJSON, err = json.Marshal(userTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, 
token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + totpCfg := user.Filters.TOTPConfig + assert.True(t, totpCfg.Enabled) + secretPayload := totpCfg.Secret.GetPayload() + assert.Equal(t, totpGenResp.ConfigName, totpCfg.ConfigName) + assert.Empty(t, totpCfg.Secret.GetKey()) + assert.Empty(t, totpCfg.Secret.GetAdditionalData()) + assert.NotEmpty(t, secretPayload) + assert.Equal(t, sdkkms.SecretStatusSecretBox, totpCfg.Secret.GetStatus()) + assert.Len(t, totpCfg.Protocols, 1) + assert.Contains(t, totpCfg.Protocols, common.ProtocolSSH) + // update protocols only + userTOTPConfig = dataprovider.UserTOTPConfig{ + Protocols: []string{common.ProtocolSSH, common.ProtocolFTP}, + Secret: kms.NewEmptySecret(), + } + asJSON, err = json.Marshal(userTOTPConfig) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + // update the user, TOTP should not be affected + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: false, + Secret: kms.NewEmptySecret(), + } + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.TOTPConfig.Enabled) + assert.Equal(t, totpCfg.ConfigName, user.Filters.TOTPConfig.ConfigName) + assert.Empty(t, user.Filters.TOTPConfig.Secret.GetKey()) + assert.Empty(t, user.Filters.TOTPConfig.Secret.GetAdditionalData()) + assert.Equal(t, secretPayload, user.Filters.TOTPConfig.Secret.GetPayload()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) + assert.Len(t, user.Filters.TOTPConfig.Protocols, 2) + assert.Contains(t, 
user.Filters.TOTPConfig.Protocols, common.ProtocolSSH) + assert.Contains(t, user.Filters.TOTPConfig.Protocols, common.ProtocolFTP) + + req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var recCodes []recoveryCode + err = json.Unmarshal(rr.Body.Bytes(), &recCodes) + assert.NoError(t, err) + assert.Len(t, recCodes, 12) + // regenerate recovery codes + req, err = http.NewRequest(http.MethodPost, user2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // check that recovery codes are different + req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var newRecCodes []recoveryCode + err = json.Unmarshal(rr.Body.Bytes(), &newRecCodes) + assert.NoError(t, err) + assert.Len(t, newRecCodes, 12) + assert.NotEqual(t, recCodes, newRecCodes) + // disable 2FA, the update user API should not work + adminToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user.Filters.TOTPConfig.Enabled = false + user.Filters.RecoveryCodes = nil + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Equal(t, defaultUsername, user.Username) + assert.True(t, user.Filters.TOTPConfig.Enabled) + assert.Len(t, user.Filters.RecoveryCodes, 12) + // use the dedicated API + req, err = http.NewRequest(http.MethodPut, userPath+"/"+defaultUsername+"/2fa/disable", nil) + assert.NoError(t, err) + setBearerForReq(req, adminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.False(t, 
user.Filters.TOTPConfig.Enabled) + assert.Len(t, user.Filters.RecoveryCodes, 0) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPut, userPath+"/"+defaultUsername+"/2fa/disable", nil) + assert.NoError(t, err) + setBearerForReq(req, adminToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, user2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, user2FARecoveryCodesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, userTOTPSavePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestWebAPIChangeUserProfileMock(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + assert.False(t, user.Filters.AllowAPIKeyAuth) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + // invalid json + req, err := http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer([]byte("{"))) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + email := "userapi@example.com" + additionalEmails := []string{"userapi1@example.com"} + description := "user API description" + profileReq := make(map[string]any) + profileReq["allow_api_key_auth"] = true + profileReq["email"] = email + profileReq["description"] = description + profileReq["public_keys"] = []string{testPubKey, testPubKey1} + 
profileReq["tls_certs"] = []string{httpsCert} + profileReq["additional_emails"] = additionalEmails + asJSON, err := json.Marshal(profileReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + profileReq = make(map[string]any) + req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = json.Unmarshal(rr.Body.Bytes(), &profileReq) + assert.NoError(t, err) + assert.Equal(t, email, profileReq["email"].(string)) + assert.Len(t, profileReq["additional_emails"].([]interface{}), 1) + assert.Equal(t, description, profileReq["description"].(string)) + assert.True(t, profileReq["allow_api_key_auth"].(bool)) + val, ok := profileReq["public_keys"].([]any) + if assert.True(t, ok, profileReq) { + assert.Len(t, val, 2) + } + val, ok = profileReq["tls_certs"].([]any) + if assert.True(t, ok, profileReq) { + assert.Len(t, val, 1) + } + // set an invalid email + profileReq = make(map[string]any) + profileReq["email"] = "notavalidemail" + asJSON, err = json.Marshal(profileReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Validation error: email") + // set an invalid additional email + profileReq = make(map[string]any) + profileReq["additional_emails"] = []string{"not an email"} + asJSON, err = json.Marshal(profileReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Validation error: email") + // set an invalid public key + profileReq = make(map[string]any) + profileReq["public_keys"] = []string{"not a public key"} + asJSON, err = json.Marshal(profileReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Validation error: error parsing public key") + // set an invalid TLS certificate + profileReq = make(map[string]any) + profileReq["tls_certs"] = []string{"not a TLS cert"} + asJSON, err = json.Marshal(profileReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Validation error: invalid TLS certificate") + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientPubKeyChangeDisabled, + sdk.WebClientTLSCertChangeDisabled} + user.Email = email + user.Description = description + user.Filters.AllowAPIKeyAuth = true + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + profileReq = make(map[string]any) + profileReq["allow_api_key_auth"] = false + profileReq["email"] = email + profileReq["description"] = description + "_mod" //nolint:goconst + profileReq["public_keys"] = []string{testPubKey} + profileReq["tls_certs"] = []string{} + asJSON, err = json.Marshal(profileReq) + assert.NoError(t, err) + req, err = 
http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), "Profile updated") + // check that api key auth and public keys were not changed + profileReq = make(map[string]any) + req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = json.Unmarshal(rr.Body.Bytes(), &profileReq) + assert.NoError(t, err) + assert.Equal(t, email, profileReq["email"].(string)) + assert.Equal(t, description+"_mod", profileReq["description"].(string)) + assert.True(t, profileReq["allow_api_key_auth"].(bool)) + val, ok = profileReq["public_keys"].([]any) + if assert.True(t, ok, profileReq) { + assert.Len(t, val, 2) + } + val, ok = profileReq["tls_certs"].([]any) + if assert.True(t, ok, profileReq) { + assert.Len(t, val, 1) + } + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled} + user.Description = description + "_mod" + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + profileReq = make(map[string]any) + profileReq["allow_api_key_auth"] = false + profileReq["email"] = "newemail@apiuser.com" + profileReq["description"] = description + profileReq["public_keys"] = []string{testPubKey} + asJSON, err = json.Marshal(profileReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + profileReq = 
make(map[string]any) + req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = json.Unmarshal(rr.Body.Bytes(), &profileReq) + assert.NoError(t, err) + assert.Equal(t, email, profileReq["email"].(string)) + assert.Equal(t, description+"_mod", profileReq["description"].(string)) + assert.True(t, profileReq["allow_api_key_auth"].(bool)) + assert.Len(t, profileReq["public_keys"].([]any), 1) + // finally disable all profile permissions + user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled, + sdk.WebClientPubKeyChangeDisabled, sdk.WebClientTLSCertChangeDisabled} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "You are not allowed to change anything") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPut, userProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestPermGroupOverride(t *testing.T) { + g := getTestGroup() + g.UserSettings.Filters.WebClient = []string{sdk.WebClientPasswordChangeDisabled} + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + 
{ + Name: group.Name, + Type: sdk.GroupTypePrimary, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + pwd := make(map[string]string) + pwd["current_password"] = defaultPassword + pwd["new_password"] = altAdminPassword + asJSON, err := json.Marshal(pwd) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + group.UserSettings.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP} + group, _, err = httpdtest.UpdateGroup(group, http.StatusOK) + assert.NoError(t, err) + + token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols") + + req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + assert.NoError(t, err) + req.RequestURI = webClientFilesPath + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nError2FARequired) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) +} + +func TestWebAPIChangeUserPwdMock(t *testing.T) { + 
user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, userProfilePath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // invalid json + req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer([]byte("{"))) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + pwd := make(map[string]string) + pwd["current_password"] = defaultPassword + pwd["new_password"] = defaultPassword + asJSON, err := json.Marshal(pwd) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "the new password must be different from the current one") + + pwd["new_password"] = altAdminPassword + asJSON, err = json.Marshal(pwd) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, userProfilePath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, altAdminPassword) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + // remove the change password permission + user.Filters.WebClient = []string{sdk.WebClientPasswordChangeDisabled} + 
user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Len(t, user.Filters.WebClient, 1) + assert.Contains(t, user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) + + token, err = getJWTAPIUserTokenFromTestServer(defaultUsername, altAdminPassword) + assert.NoError(t, err) + assert.NotEmpty(t, token) + + pwd["current_password"] = altAdminPassword + pwd["new_password"] = defaultPassword + asJSON, err = json.Marshal(pwd) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userPwdPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginInvalidPasswordMock(t *testing.T) { + _, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass+"1") + assert.Error(t, err) + // now a login with no credentials + req, _ := http.NewRequest(http.MethodGet, "/api/v2/token", nil) + rr := executeRequest(req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) +} + +func TestWebAPIChangeAdminProfileMock(t *testing.T) { + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Password = altAdminPassword + admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) + assert.NoError(t, err) + assert.False(t, admin.Filters.AllowAPIKeyAuth) + + token, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword) + assert.NoError(t, err) + // invalid json + req, err := http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer([]byte("{"))) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + email := "adminapi@example.com" + description := "admin API description" + profileReq := make(map[string]any) + 
profileReq["allow_api_key_auth"] = true + profileReq["email"] = email + profileReq["description"] = description + asJSON, err := json.Marshal(profileReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), "Profile updated") + + profileReq = make(map[string]any) + req, err = http.NewRequest(http.MethodGet, adminProfilePath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = json.Unmarshal(rr.Body.Bytes(), &profileReq) + assert.NoError(t, err) + assert.Equal(t, email, profileReq["email"].(string)) + assert.Equal(t, description, profileReq["description"].(string)) + assert.True(t, profileReq["allow_api_key_auth"].(bool)) + // set an invalid email + profileReq["email"] = "admin_invalid_email" + asJSON, err = json.Marshal(profileReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Validation error: email") + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, adminProfilePath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPut, adminProfilePath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestChangeAdminPwdMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, 
defaultTokenAuthPass) + assert.NoError(t, err) + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Password = altAdminPassword + admin.Permissions = []string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminDeleteUsers} + asJSON, err := json.Marshal(admin) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + pwd := make(map[string]string) + pwd["current_password"] = altAdminPassword + pwd["new_password"] = defaultTokenAuthPass + asJSON, err = json.Marshal(pwd) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, adminPwdPath, bytes.NewBuffer(asJSON)) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // try using the old token + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + _, err = getJWTAPITokenFromTestServer(altAdminUsername, altAdminPassword) + assert.Error(t, err) + + altToken, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, adminPwdPath, bytes.NewBuffer(asJSON)) + setBearerForReq(req, 
altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) // current password does not match + + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(adminPath, altAdminUsername), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestUpdateAdminMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + _, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass) + assert.Error(t, err) + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Permissions = []string{dataprovider.PermAdminAny} + asJSON, err := json.Marshal(admin) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + _, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass) + assert.NoError(t, err) + + req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, "abc"), bytes.NewBuffer(asJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername), bytes.NewBuffer([]byte("no json"))) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + admin.Permissions = nil + asJSON, err = json.Marshal(admin) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername), bytes.NewBuffer(asJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + admin = getTestAdmin() 
+ admin.Status = 0 + asJSON, err = json.Marshal(admin) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, defaultTokenAuthUser), bytes.NewBuffer(asJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "you cannot disable yourself") + admin.Status = 1 + admin.Permissions = []string{dataprovider.PermAdminAddUsers} + asJSON, err = json.Marshal(admin) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, defaultTokenAuthUser), bytes.NewBuffer(asJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "you cannot change your permissions") + admin.Permissions = []string{dataprovider.PermAdminAny} + admin.Role = "missing role" + asJSON, err = json.Marshal(admin) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, defaultTokenAuthUser), bytes.NewBuffer(asJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "you cannot add/change your role") + admin.Role = "" + + altToken, err := getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass) + assert.NoError(t, err) + admin.Password = "" // it must remain unchanged + admin.Permissions = []string{dataprovider.PermAdminAny} + asJSON, err = json.Marshal(admin) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername), bytes.NewBuffer(asJSON)) + setBearerForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + _, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass) + assert.NoError(t, err) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(adminPath, altAdminUsername), nil) + setBearerForReq(req, token) + rr = 
executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestAdminLastLoginWithAPIKey(t *testing.T) { + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Filters.AllowAPIKeyAuth = true + admin, resp, err := httpdtest.AddAdmin(admin, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Equal(t, int64(0), admin.LastLogin) + + apiKey := dataprovider.APIKey{ + Name: "admin API key", + Scope: dataprovider.APIKeyScopeAdmin, + Admin: altAdminUsername, + LastUseAt: 123, + } + + apiKey, resp, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Equal(t, int64(0), apiKey.LastUseAt) + + req, err := http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, admin.Username) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, admin.LastLogin, int64(0)) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserLastLoginWithAPIKey(t *testing.T) { + user := getTestUser() + user.Filters.AllowAPIKeyAuth = true + user, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + assert.Equal(t, int64(0), user.LastLogin) + + apiKey := dataprovider.APIKey{ + Name: "user API key", + Scope: dataprovider.APIKeyScopeUser, + User: user.Username, + } + + apiKey, resp, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + req, err := http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.LastLogin, int64(0)) + + _, err = 
httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestAdminHandlingWithAPIKeys(t *testing.T) { + sysAdmin, _, err := httpdtest.GetAdminByUsername(defaultTokenAuthUser, http.StatusOK) + assert.NoError(t, err) + sysAdmin.Filters.AllowAPIKeyAuth = true + sysAdmin, _, err = httpdtest.UpdateAdmin(sysAdmin, http.StatusOK) + assert.NoError(t, err) + + apiKey := dataprovider.APIKey{ + Name: "test admin API key", + Scope: dataprovider.APIKeyScopeAdmin, + Admin: defaultTokenAuthUser, + } + + apiKey, resp, err := httpdtest.AddAPIKey(apiKey, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + admin := getTestAdmin() + admin.Username = altAdminUsername + asJSON, err := json.Marshal(admin) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + _, err = getJWTAPITokenFromTestServer(altAdminUsername, defaultTokenAuthPass) + assert.NoError(t, err) + + admin.Filters.AllowAPIKeyAuth = true + asJSON, err = json.Marshal(admin) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, path.Join(adminPath, altAdminUsername), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(adminPath, altAdminUsername), nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var adminGet dataprovider.Admin + err = json.Unmarshal(rr.Body.Bytes(), &adminGet) + assert.NoError(t, err) + assert.True(t, adminGet.Filters.AllowAPIKeyAuth) + + req, err = http.NewRequest(http.MethodPut, path.Join(adminPath, defaultTokenAuthUser), bytes.NewBuffer(asJSON)) + 
assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "updating the admin impersonated with an API key is not allowed") + // changing the password for the impersonated admin is not allowed + pwd := make(map[string]string) + pwd["current_password"] = defaultTokenAuthPass + pwd["new_password"] = altAdminPassword + asJSON, err = json.Marshal(pwd) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, adminPwdPath, bytes.NewBuffer(asJSON)) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "API key authentication is not allowed") + + req, err = http.NewRequest(http.MethodDelete, path.Join(adminPath, defaultTokenAuthUser), nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "you cannot delete yourself") + + req, err = http.NewRequest(http.MethodDelete, path.Join(adminPath, altAdminUsername), nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + _, err = httpdtest.RemoveAPIKey(apiKey, http.StatusOK) + assert.NoError(t, err) + + dbAdmin, err := dataprovider.AdminExists(defaultTokenAuthUser) + assert.NoError(t, err) + dbAdmin.Filters.AllowAPIKeyAuth = false + err = dataprovider.UpdateAdmin(&dbAdmin, "", "", "") + assert.NoError(t, err) + sysAdmin, _, err = httpdtest.GetAdminByUsername(defaultTokenAuthUser, http.StatusOK) + assert.NoError(t, err) + assert.False(t, sysAdmin.Filters.AllowAPIKeyAuth) +} + +func TestUserHandlingWithAPIKey(t *testing.T) { + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Filters.AllowAPIKeyAuth = true + admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) + 
assert.NoError(t, err) + + apiKey := dataprovider.APIKey{ + Name: "test admin API key", + Scope: dataprovider.APIKeyScopeAdmin, + Admin: admin.Username, + } + + apiKey, _, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated) + assert.NoError(t, err) + + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + user.Filters.DisableFsChecks = true + user.Description = "desc" + userAsJSON = getUserAsJSON(t, user) + req, err = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON)) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var updatedUser dataprovider.User + err = json.Unmarshal(rr.Body.Bytes(), &updatedUser) + assert.NoError(t, err) + assert.True(t, updatedUser.Filters.DisableFsChecks) + assert.Equal(t, user.Description, updatedUser.Description) + + req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + _, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusNotFound) + assert.NoError(t, err) +} + +func TestUpdateUserQuotaUsageMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + var user 
dataprovider.User + u := getTestUser() + usedQuotaFiles := 1 + usedQuotaSize := int64(65535) + u.UsedQuotaFiles = usedQuotaFiles + u.UsedQuotaSize = usedQuotaSize + u.QuotaFiles = 100 + userAsJSON := getUserAsJSON(t, u) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage"), bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + assert.Equal(t, usedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, usedQuotaSize, user.UsedQuotaSize) + // now update only quota size + u.UsedQuotaFiles = 0 + userAsJSON = getUserAsJSON(t, u) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage")+"?mode=add", bytes.NewBuffer(userAsJSON)) //nolint:goconst + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + assert.Equal(t, usedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, usedQuotaSize*2, user.UsedQuotaSize) + // only quota files + u.UsedQuotaFiles = usedQuotaFiles + u.UsedQuotaSize = 0 + userAsJSON = getUserAsJSON(t, u) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage")+"?mode=add", 
bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + assert.Equal(t, usedQuotaFiles*2, user.UsedQuotaFiles) + assert.Equal(t, usedQuotaSize*2, user.UsedQuotaSize) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage"), bytes.NewBuffer([]byte("string"))) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.True(t, common.QuotaScans.AddUserQuotaScan(user.Username, "")) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "users", u.Username, "usage"), bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusConflict, rr) + assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username)) + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestUserPermissionsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + user.Permissions = make(map[string][]string) + user.Permissions["/somedir"] = []string{dataprovider.PermAny} + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + user.Permissions[".."] = []string{dataprovider.PermAny} + 
userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + user.Permissions["/somedir"] = []string{"invalid"} + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + delete(user.Permissions, "/somedir") + user.Permissions["/somedir/.."] = []string{dataprovider.PermAny} + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + delete(user.Permissions, "/somedir/..") + user.Permissions["not_abs_path"] = []string{dataprovider.PermAny} + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + delete(user.Permissions, "not_abs_path") + user.Permissions["/somedir/../otherdir/"] = []string{dataprovider.PermListItems} + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusOK, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var updatedUser dataprovider.User + err = render.DecodeJSON(rr.Body, &updatedUser) + assert.NoError(t, err) + if val, ok := updatedUser.Permissions["/otherdir"]; ok { + assert.True(t, slices.Contains(val, dataprovider.PermListItems)) + assert.Equal(t, 1, len(val)) + } else { + assert.Fail(t, "expected dir not found in permissions") + } + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestUpdateUserInvalidJsonMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer([]byte("Invalid json"))) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestUpdateUserInvalidParamsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, 
token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + user.HomeDir = "" + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + userID := user.ID + user.ID = 0 + user.CreatedAt = 0 + userAsJSON = getUserAsJSON(t, user) + req, _ = http.NewRequest(http.MethodPut, path.Join(userPath, user.Username), bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + user.ID = userID + req, _ = http.NewRequest(http.MethodPut, userPath+"/0", bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestGetAdminsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + admin := getTestAdmin() + admin.Username = altAdminUsername + asJSON, err := json.Marshal(admin) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, adminPath, bytes.NewBuffer(asJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=510&offset=0&order=ASC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var admins []dataprovider.Admin + err = render.DecodeJSON(rr.Body, &admins) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(admins), 1) + firtAdmin := admins[0].Username + req, _ = 
http.NewRequest(http.MethodGet, adminPath+"?limit=510&offset=0&order=DESC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + admins = nil + err = render.DecodeJSON(rr.Body, &admins) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(admins), 1) + assert.NotEqual(t, firtAdmin, admins[0].Username) + + req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=510&offset=1&order=ASC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + admins = nil + err = render.DecodeJSON(rr.Body, &admins) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(admins), 1) + assert.NotEqual(t, firtAdmin, admins[0].Username) + + req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=a&offset=0&order=ASC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=1&offset=aa&order=ASC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + req, _ = http.NewRequest(http.MethodGet, adminPath+"?limit=1&offset=0&order=ASCa", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(adminPath, admin.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestGetUsersMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + req, _ = 
http.NewRequest(http.MethodGet, userPath+"?limit=510&offset=0&order=ASC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var users []dataprovider.User + err = render.DecodeJSON(rr.Body, &users) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(users), 1) + req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=aa&offset=0&order=ASC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=a&order=ASC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + req, _ = http.NewRequest(http.MethodGet, userPath+"?limit=1&offset=0&order=ASCc", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestDeleteUserInvalidParamsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodDelete, userPath+"/0", nil) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestGetQuotaScansMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, quotaScanPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestStartQuotaScanMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := 
getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + _, err = os.Stat(user.HomeDir) + if err == nil { + err = os.Remove(user.HomeDir) + assert.NoError(t, err) + } + // simulate a duplicate quota scan + common.QuotaScans.AddUserQuotaScan(user.Username, "") + req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusConflict, rr) + assert.True(t, common.QuotaScans.RemoveUserQuotaScan(user.Username)) + + req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + + waitForUsersQuotaScan(t, token) + + _, err = os.Stat(user.HomeDir) + if err != nil && errors.Is(err, fs.ErrNotExist) { + err = os.MkdirAll(user.HomeDir, os.ModePerm) + assert.NoError(t, err) + } + req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + + waitForUsersQuotaScan(t, token) + + req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + + waitForUsersQuotaScan(t, token) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUpdateFolderQuotaUsageMock(t *testing.T) { + 
token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + mappedPath := filepath.Join(os.TempDir(), "vfolder") + folderName := filepath.Base(mappedPath) + f := vfs.BaseVirtualFolder{ + MappedPath: mappedPath, + Name: folderName, + } + usedQuotaFiles := 1 + usedQuotaSize := int64(65535) + f.UsedQuotaFiles = usedQuotaFiles + f.UsedQuotaSize = usedQuotaSize + var folder vfs.BaseVirtualFolder + folderAsJSON, err := json.Marshal(f) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &folder) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage"), bytes.NewBuffer(folderAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var folderGet vfs.BaseVirtualFolder + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folderGet) + assert.NoError(t, err) + assert.Equal(t, usedQuotaFiles, folderGet.UsedQuotaFiles) + assert.Equal(t, usedQuotaSize, folderGet.UsedQuotaSize) + // now update only quota size + f.UsedQuotaFiles = 0 + folderAsJSON, err = json.Marshal(f) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage")+"?mode=add", + bytes.NewBuffer(folderAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + folderGet = vfs.BaseVirtualFolder{} + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, 
rr) + err = render.DecodeJSON(rr.Body, &folderGet) + assert.NoError(t, err) + assert.Equal(t, usedQuotaFiles, folderGet.UsedQuotaFiles) + assert.Equal(t, usedQuotaSize*2, folderGet.UsedQuotaSize) + // now update only quota files + f.UsedQuotaSize = 0 + f.UsedQuotaFiles = 1 + folderAsJSON, err = json.Marshal(f) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage")+"?mode=add", + bytes.NewBuffer(folderAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + folderGet = vfs.BaseVirtualFolder{} + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folderGet) + assert.NoError(t, err) + assert.Equal(t, usedQuotaFiles*2, folderGet.UsedQuotaFiles) + assert.Equal(t, usedQuotaSize*2, folderGet.UsedQuotaSize) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage"), + bytes.NewBuffer([]byte("not a json"))) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + assert.True(t, common.QuotaScans.AddVFolderQuotaScan(folderName)) + req, _ = http.NewRequest(http.MethodPut, path.Join(quotasBasePath, "folders", folder.Name, "usage"), + bytes.NewBuffer(folderAsJSON)) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusConflict, rr) + assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName)) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestStartFolderQuotaScanMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + 
mappedPath := filepath.Join(os.TempDir(), "vfolder") + folderName := filepath.Base(mappedPath) + folder := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + folderAsJSON, err := json.Marshal(folder) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + _, err = os.Stat(mappedPath) + if err == nil { + err = os.Remove(mappedPath) + assert.NoError(t, err) + } + // simulate a duplicate quota scan + common.QuotaScans.AddVFolderQuotaScan(folderName) + req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "folders", folder.Name, "scan"), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusConflict, rr) + assert.True(t, common.QuotaScans.RemoveVFolderQuotaScan(folderName)) + // and now a real quota scan + _, err = os.Stat(mappedPath) + if err != nil && errors.Is(err, fs.ErrNotExist) { + err = os.MkdirAll(mappedPath, os.ModePerm) + assert.NoError(t, err) + } + req, _ = http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "folders", folder.Name, "scan"), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + waitForFoldersQuotaScanPath(t, token) + // cleanup + req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = os.RemoveAll(folderPath) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestStartQuotaScanNonExistentUserMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + user := getTestUser() + + req, _ := http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "users", user.Username, "scan"), nil) + 
setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestStartQuotaScanNonExistentFolderMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + folder := vfs.BaseVirtualFolder{ + Name: "afolder", + } + req, _ := http.NewRequest(http.MethodPost, path.Join(quotasBasePath, "folders", folder.Name, "scan"), nil) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestGetFoldersMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + mappedPath := filepath.Join(os.TempDir(), "vfolder") + folderName := filepath.Base(mappedPath) + folder := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + folderAsJSON, err := json.Marshal(folder) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON)) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &folder) + assert.NoError(t, err) + + var folders []vfs.BaseVirtualFolder + url, err := url.Parse(folderPath + "?limit=510&offset=0&order=DESC") + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodGet, url.String(), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folders) + assert.NoError(t, err) + assert.GreaterOrEqual(t, len(folders), 1) + req, _ = http.NewRequest(http.MethodGet, folderPath+"?limit=a&offset=0&order=ASC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + req, _ = http.NewRequest(http.MethodGet, folderPath+"?limit=1&offset=a&order=ASC", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusBadRequest, rr) + req, _ = http.NewRequest(http.MethodGet, folderPath+"?limit=1&offset=0&order=ASCV", nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestGetVersionMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, versionPath, nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodGet, versionPath, nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, _ = http.NewRequest(http.MethodGet, versionPath, nil) + setBearerForReq(req, "abcde") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) +} + +func TestGetConnectionsMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, activeConnectionsPath, nil) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestGetStatusMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestDeleteActiveConnectionMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil) + 
setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + req.Header.Set(dataprovider.NodeTokenHeader, "Bearer abc") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "the provided token cannot be authenticated") + req, err = http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID?node=node1", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestNotFoundMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, "/non/existing/path", nil) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestMethodNotAllowedMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, activeConnectionsPath, nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusMethodNotAllowed, rr) +} + +func TestHealthCheck(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, healthzPath, nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "ok", rr.Body.String()) +} + +func TestGetWebRootMock(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "/", nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + req, _ = http.NewRequest(http.MethodGet, webBasePath, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + req, _ = http.NewRequest(http.MethodGet, webBasePathAdmin, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) + req, _ = 
http.NewRequest(http.MethodGet, webBasePathClient, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) +} + +func TestWebNotFoundURI(t *testing.T) { + urlString := httpBaseURL + webBasePath + "/a" + req, err := http.NewRequest(http.MethodGet, urlString, nil) + assert.NoError(t, err) + resp, err := httpclient.GetHTTPClient().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + + req, err = http.NewRequest(http.MethodGet, urlString, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, "invalid token") + resp, err = httpclient.GetHTTPClient().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + + urlString = httpBaseURL + webBasePathClient + "/a" + req, err = http.NewRequest(http.MethodGet, urlString, nil) + assert.NoError(t, err) + resp, err = httpclient.GetHTTPClient().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + + req, err = http.NewRequest(http.MethodGet, urlString, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, "invalid client token") + resp, err = httpclient.GetHTTPClient().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) +} + +func TestLogout(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, logoutPath, nil) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, token) + rr = 
executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "Your token is no longer valid") +} + +func TestDefenderAPIInvalidIDMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, path.Join(defenderHosts, "abc"), nil) // not hex id + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "invalid host id") +} + +func TestTokenHeaderCookie(t *testing.T) { + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + + req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil) + setJWTCookieForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "no token found") + + req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + setBearerForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) + + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestTokenAudience(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + + req, _ := http.NewRequest(http.MethodGet, 
serverStatusPath, nil) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "Your token audience is not valid") + + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + setJWTCookieForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) +} + +func TestWebAPILoginMock(t *testing.T) { + _, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + _, err = getJWTAPIUserTokenFromTestServer(defaultUsername+"1", defaultPassword) + assert.Error(t, err) + _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword+"1") + assert.Error(t, err) + apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + // a web token is not valid for API usage + req, err := http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "Your token audience is not valid") + + req, err = http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, 
http.StatusOK, rr) + // API token is not valid for web usage + req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil) + setJWTCookieForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + + req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebClientLoginMock(t *testing.T) { + _, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + // a web token is not valid for API or WebAdmin usage + req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "Your token audience is not valid") + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) + // bearer should not work + req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil) + setBearerForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + req, _ = http.NewRequest(http.MethodGet, webClientPingPath, nil) + req.RemoteAddr = defaultRemoteAddr + setBearerForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, 
http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + // now try to render client pages + req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, _ = http.NewRequest(http.MethodGet, webClientPingPath, nil) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // now logout + req, _ = http.NewRequest(http.MethodGet, webClientLogoutPath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + req, _ = http.NewRequest(http.MethodGet, webClientPingPath, nil) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + // get a new token and use it after removing the user + webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + apiUserToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + 
assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + req, _ = http.NewRequest(http.MethodGet, webClientProfilePath, nil) + setJWTCookieForReq(req, webToken) + req.RemoteAddr = defaultRemoteAddr + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorGetUser) + + req, _ = http.NewRequest(http.MethodGet, webClientDirsPath, nil) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorDirListUser) + + form := make(url.Values) + form.Set("files", `[]`) + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorGetUser) + + req, _ = http.NewRequest(http.MethodGet, userDirsPath, nil) + setBearerForReq(req, apiUserToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Unable to retrieve your user") + + req, _ = http.NewRequest(http.MethodGet, userFilesPath, nil) + setBearerForReq(req, apiUserToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Unable to retrieve your user") + + req, _ = http.NewRequest(http.MethodPost, userStreamZipPath, bytes.NewBuffer([]byte(`{}`))) + setBearerForReq(req, apiUserToken) + rr = 
executeRequest(req)
	// NOTE(review): the two lines above complete the statement begun on the
	// previous (unseen) line; this is the tail of the preceding test, which
	// exercises API/web endpoints after the backing user has been removed.
	checkResponseCode(t, http.StatusNotFound, rr)
	assert.Contains(t, rr.Body.String(), "Unable to retrieve your user")

	// updating the profile of a deleted user must fail server-side
	form = make(url.Values)
	form.Set("public_keys", testPubKey)
	form.Set(csrfFormToken, csrfToken)
	req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
}

// TestWebClientLoginErrorsMock verifies that the web client login form
// reports invalid credentials for an empty form and an invalid CSRF token
// when credentials are supplied without a valid CSRF token.
func TestWebClientLoginErrorsMock(t *testing.T) {
	// empty credentials: the page is re-rendered (200) with an error message
	form := getLoginForm("", "", "")
	req, _ := http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr := executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials)

	// valid credentials but no CSRF token
	form = getLoginForm(defaultUsername, defaultPassword, "")
	req, _ = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rr = executeRequest(req)
	assert.Equal(t, http.StatusOK, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)
}

// TestWebClientMaxConnections checks that web client requests are rejected
// once common.Config.MaxTotalConnections is reached.
func TestWebClientMaxConnections(t *testing.T) {
	oldValue := common.Config.MaxTotalConnections
	common.Config.MaxTotalConnections = 1

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	// with no other connections the web client is reachable
	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	setJWTCookieForReq(req, webToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// now add a fake connection
	fs := vfs.NewOsFs("id", os.TempDir(), "", nil)
	connection := &httpd.Connection{
		BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user),
	}
	err = common.Connections.Add(connection)
	assert.NoError(t, err)

	// the limit is now exceeded: login and file listing must be denied
	_, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.Error(t, err)

	req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error())

	// cleanup and restore the original limit
	common.Connections.Remove(connection.GetID())
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	assert.Len(t, common.Connections.GetStats(""), 0)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())

	common.Config.MaxTotalConnections = oldValue
}

// TestTokenInvalidIPAddress verifies that a JWT issued for one client IP is
// rejected when presented from a different remote address.
func TestTokenInvalidIPAddress(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)

	// web cookie token from an unexpected IP: redirected to the login page
	req, err := http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	req.RequestURI = webClientFilesPath
	setJWTCookieForReq(req, webToken)
	req.RemoteAddr = "1.1.1.2"
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusFound, rr)

	apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)

	// bearer API token from an unexpected IP: 401
	req, err = http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil)
	assert.NoError(t, err)
	req.RemoteAddr = "2.2.2.2"
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusUnauthorized, rr)
	assert.Contains(t, rr.Body.String(), "Your token is not valid")

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestDefender(t *testing.T) { + oldConfig := config.GetCommonConfig() + + cfg := config.GetCommonConfig() + cfg.DefenderConfig.Enabled = true + cfg.DefenderConfig.Threshold = 3 + cfg.DefenderConfig.ScoreLimitExceeded = 2 + + err := common.Initialize(cfg, 0) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + remoteAddr := "172.16.5.6:9876" + + webAdminToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, defaultPassword, remoteAddr) + assert.NoError(t, err) + + req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.RequestURI = webClientFilesPath + setJWTCookieForReq(req, webToken) + req.RemoteAddr = remoteAddr + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + for i := 0; i < 3; i++ { + _, err = getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, "wrong pwd", remoteAddr) + assert.Error(t, err) + } + + _, err = getJWTWebClientTokenFromTestServerWithAddr(defaultUsername, defaultPassword, remoteAddr) + assert.Error(t, err) + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.RequestURI = webClientFilesPath + setJWTCookieForReq(req, webToken) + req.RemoteAddr = remoteAddr + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorIPForbidden) + + req, _ = http.NewRequest(http.MethodGet, webUsersPath, nil) + req.RequestURI = webUsersPath + setJWTCookieForReq(req, webAdminToken) + req.RemoteAddr = remoteAddr + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorIPForbidden) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + 
req.Header.Set("X-Real-IP", "127.0.0.1:2345") + setJWTCookieForReq(req, webToken) + req.RemoteAddr = remoteAddr + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "your IP address is blocked") + // requests for static files should be always allowed + req, err = http.NewRequest(http.MethodGet, "/static/favicon.png", nil) + assert.NoError(t, err) + req.RemoteAddr = remoteAddr + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Empty(t, rr.Header().Get("Cache-Control")) + + req, err = http.NewRequest(http.MethodGet, "/.well-known/acme-challenge/foo", nil) + assert.NoError(t, err) + req.RemoteAddr = remoteAddr + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Equal(t, "no-cache, no-store, max-age=0, must-revalidate, private", rr.Header().Get("Cache-Control")) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = common.Initialize(oldConfig, 0) + assert.NoError(t, err) +} + +func TestPostConnectHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + common.Config.PostConnectHook = postConnectPath + + u := getTestUser() + u.Filters.AllowAPIKeyAuth = true + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + apiKey, _, err := httpdtest.AddAPIKey(dataprovider.APIKey{ + Name: "name", + Scope: dataprovider.APIKeyScopeUser, + User: user.Username, + }, http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(postConnectPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + + _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, userDirsPath, 
nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + err = os.WriteFile(postConnectPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + + _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + + _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.PostConnectHook = "" +} + +func TestMaxSessions(t *testing.T) { + u := getTestUser() + u.MaxSessions = 1 + u.Email = "user@session.com" + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + apiToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + // now add a fake connection + fs := vfs.NewOsFs("id", os.TempDir(), "", nil) + connection := &httpd.Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user), + } + err = common.Connections.Add(connection) + assert.NoError(t, err) + _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + // try an user API call + req, err := http.NewRequest(http.MethodGet, userDirsPath+"/?path=%2F", nil) + assert.NoError(t, err) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, 
rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + // web client requests + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("files", `[]`) + req, err = http.NewRequest(http.MethodPost, webClientDownloadZipPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), util.I18nError429Message) + + req, err = http.NewRequest(http.MethodGet, webClientDirsPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorDirList429) + + req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=p", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), util.I18nError429Message) + + req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=file", nil) //nolint:goconst + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), util.I18nError429Message) + + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=file", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), util.I18nError429Message) + + // test reset password + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 3525, + From: 
"notification@example.com", + TemplatesPath: "templates", + } + err = smtpCfg.Initialize(configDir, true) + assert.NoError(t, err) + + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user.Username) + lastResetCode = "" + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("password", defaultPassword) + form.Set("confirm_password", defaultPassword) + form.Set("code", lastResetCode) + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nError429Message) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + common.Connections.Remove(connection.GetID()) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestMaxTransfers(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + 
common.Config.MaxPerHostConnections = 2

	// wait for any connection left over from previous tests to drain
	assert.Eventually(t, func() bool {
		return common.Connections.GetClientConnections() == 0
	}, 1000*time.Millisecond, 50*time.Millisecond)
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)

	// create a password-protected read/write share for later upload attempts
	share := dataprovider.Share{
		Name:     "test share",
		Scope:    dataprovider.ShareScopeReadWrite,
		Paths:    []string{"/"},
		Password: defaultPassword,
	}
	asJSON, err := json.Marshal(share)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	objectID := rr.Header().Get("X-Object-ID")
	assert.NotEmpty(t, objectID)

	// uploads still succeed while no transfers are in progress
	fileName := "testfile.txt"
	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path="+fileName, bytes.NewBuffer([]byte(" ")))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)

	// open two SFTP uploads to consume the per-host transfer budget
	conn, sftpClient, err := getSftpClient(user)
	assert.NoError(t, err)
	defer conn.Close()
	defer sftpClient.Close()

	f1, err := sftpClient.Create("file1")
	assert.NoError(t, err)
	f2, err := sftpClient.Create("file2")
	assert.NoError(t, err)
	_, err = f1.Write([]byte(" "))
	assert.NoError(t, err)
	_, err = f2.Write([]byte(" "))
	assert.NoError(t, err)

	// multipart upload via the user API must now be rejected with 409
	body := new(bytes.Buffer)
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("filenames", "filepre")
	assert.NoError(t, err)
	_, err = part.Write([]byte("file content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader := bytes.NewReader(body.Bytes())
	_, err = reader.Seek(0, io.SeekStart)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, userFilesPath, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusConflict, rr)

	// file download page renders, but shows the 403 transfer-denied message
	req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+fileName, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError403Message)

	req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path="+fileName, bytes.NewBuffer([]byte(" ")))
	assert.NoError(t, err)
	setBearerForReq(req, webAPIToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)

	// uploads to the share (basic auth) are also rejected with 409
	body = new(bytes.Buffer)
	writer = multipart.NewWriter(body)
	part1, err := writer.CreateFormFile("filenames", "file11.txt")
	assert.NoError(t, err)
	_, err = part1.Write([]byte("file11 content"))
	assert.NoError(t, err)
	part2, err := writer.CreateFormFile("filenames", "file22.txt")
	assert.NoError(t, err)
	_, err = part2.Write([]byte("file22 content"))
	assert.NoError(t, err)
	err = writer.Close()
	assert.NoError(t, err)
	reader = bytes.NewReader(body.Bytes())
	req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader)
	assert.NoError(t, err)
	req.Header.Add("Content-Type", writer.FormDataContentType())
	req.SetBasicAuth(defaultUsername, defaultPassword)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusConflict, rr)

	// finish the SFTP transfers and clean up
	err = f1.Close()
	assert.NoError(t, err)
	err = f2.Close()
	assert.NoError(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	assert.Eventually(t, func() bool {
		return len(common.Connections.GetStats("")) == 0
	}, 1000*time.Millisecond,
50*time.Millisecond) + assert.Eventually(t, func() bool { + return common.Connections.GetTotalTransfers() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + common.Config.MaxPerHostConnections = oldValue +} + +func TestWebConfigsMock(t *testing.T) { + acmeConfig := config.GetACMEConfig() + acmeConfig.CertsPath = filepath.Clean(os.TempDir()) + err := acme.Initialize(acmeConfig, configDir, true) + require.NoError(t, err) + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webConfigsPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + form := make(url.Values) + b, contentType, err := getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + // parse form error + csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webConfigsPath+"?p=p%C3%AO%GH", bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + // save SFTP configs + form.Set("sftp_host_key_algos", ssh.KeyAlgoRSA) + form.Add("sftp_host_key_algos", ssh.InsecureCertAlgoDSAv01) //nolint:staticcheck + form.Set("sftp_pub_key_algos", ssh.InsecureKeyAlgoDSA) //nolint:staticcheck + form.Set("form_action", "sftp_submit") + b, contentType, err = 
getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) // invalid algo + form.Set("sftp_host_key_algos", ssh.KeyAlgoRSA) + form.Add("sftp_host_key_algos", ssh.CertAlgoRSAv01) + form.Set("sftp_pub_key_algos", ssh.InsecureKeyAlgoDSA) //nolint:staticcheck + form.Set("sftp_kex_algos", "diffie-hellman-group18-sha512") + form.Add("sftp_kex_algos", ssh.KeyExchangeDH16SHA512) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + // check SFTP configs + configs, err := dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Len(t, configs.SFTPD.HostKeyAlgos, 1) + assert.Contains(t, configs.SFTPD.HostKeyAlgos, ssh.KeyAlgoRSA) + assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1) + assert.Contains(t, configs.SFTPD.PublicKeyAlgos, ssh.InsecureKeyAlgoDSA) //nolint:staticcheck + assert.Len(t, configs.SFTPD.KexAlgorithms, 1) + assert.Contains(t, configs.SFTPD.KexAlgorithms, ssh.KeyExchangeDH16SHA512) + // invalid form action + form.Set("form_action", "") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nError400Message) + 
// test SMTP configs + form.Set("form_action", "smtp_submit") + form.Set("smtp_host", "mail.example.net") + form.Set("smtp_from", "Example ") + form.Set("smtp_username", defaultUsername) + form.Set("smtp_password", defaultPassword) + form.Set("smtp_domain", "localdomain") + form.Set("smtp_auth", "100") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) // invalid smtp_auth + // set valid parameters + form.Set("smtp_port", "a") // converted to 587 + form.Set("smtp_auth", "1") + form.Set("smtp_encryption", "2") + form.Set("smtp_debug", "checked") + form.Set("smtp_oauth2_provider", "1") + form.Set("smtp_oauth2_client_id", "123") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + // check + configs, err = dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Len(t, configs.SFTPD.HostKeyAlgos, 1) + assert.Contains(t, configs.SFTPD.HostKeyAlgos, ssh.KeyAlgoRSA) + assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1) + assert.Contains(t, configs.SFTPD.PublicKeyAlgos, ssh.InsecureKeyAlgoDSA) //nolint:staticcheck + assert.Equal(t, "mail.example.net", configs.SMTP.Host) + assert.Equal(t, 587, configs.SMTP.Port) + assert.Equal(t, "Example ", configs.SMTP.From) + assert.Equal(t, defaultUsername, configs.SMTP.User) + assert.Equal(t, 1, configs.SMTP.Debug) + assert.Equal(t, "", configs.SMTP.OAuth2.ClientID) + err = 
configs.SMTP.Password.Decrypt() + assert.NoError(t, err) + assert.Equal(t, defaultPassword, configs.SMTP.Password.GetPayload()) + assert.Equal(t, 1, configs.SMTP.AuthType) + assert.Equal(t, 2, configs.SMTP.Encryption) + assert.Equal(t, "localdomain", configs.SMTP.Domain) + // set a redacted password, the current password must be preserved + form.Set("smtp_password", redactedSecret) + form.Set("smtp_auth", "") + configs.SMTP.AuthType = 0 // empty will be converted to 0 + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + updatedConfigs, err := dataprovider.GetConfigs() + assert.NoError(t, err) + encryptedPayload := updatedConfigs.SMTP.Password.GetPayload() + secretKey := updatedConfigs.SMTP.Password.GetKey() + err = updatedConfigs.SMTP.Password.Decrypt() + assert.NoError(t, err) + assert.Equal(t, configs.SFTPD, updatedConfigs.SFTPD) + assert.Equal(t, configs.SMTP, updatedConfigs.SMTP) + // now set an undecryptable password + updatedConfigs.SMTP.Password = kms.NewSecret(sdkkms.SecretStatusSecretBox, encryptedPayload, secretKey, "") + err = dataprovider.UpdateConfigs(&updatedConfigs, "", "", "") + assert.NoError(t, err) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + form.Set("form_action", "acme_submit") + form.Set("acme_port", "") // on error will be set to 80 + form.Set("acme_protocols", "1") + 
form.Add("acme_protocols", "2") + form.Add("acme_protocols", "3") + form.Set("acme_domain", "example.com") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + // no email set, validation will fail + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidEmail) + form.Set("acme_domain", "") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + // check + configs, err = dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Len(t, configs.SFTPD.HostKeyAlgos, 1) + assert.Contains(t, configs.SFTPD.HostKeyAlgos, ssh.KeyAlgoRSA) + assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1) + assert.Contains(t, configs.SFTPD.PublicKeyAlgos, ssh.InsecureKeyAlgoDSA) //nolint:staticcheck + assert.Equal(t, 80, configs.ACME.HTTP01Challenge.Port) + assert.Equal(t, 7, configs.ACME.Protocols) + assert.Empty(t, configs.ACME.Domain) + assert.Empty(t, configs.ACME.Email) + assert.True(t, configs.ACME.HasProtocol(common.ProtocolFTP)) + assert.True(t, configs.ACME.HasProtocol(common.ProtocolWebDAV)) + assert.True(t, configs.ACME.HasProtocol(common.ProtocolHTTP)) + // create certificate files, so no real ACME call is done + domain := "acme.example.com" + crtPath := filepath.Join(acmeConfig.CertsPath, util.SanitizeDomain(domain)+".crt") + keyPath := filepath.Join(acmeConfig.CertsPath, util.SanitizeDomain(domain)+".key") + err = os.WriteFile(crtPath, nil, 0666) + assert.NoError(t, err) + err = 
os.WriteFile(keyPath, nil, 0666) + assert.NoError(t, err) + form.Set("acme_port", "402") + form.Set("acme_protocols", "1") + form.Add("acme_protocols", "1000") + form.Set("acme_domain", domain) + form.Set("acme_email", "email@example.com") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + configs, err = dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Len(t, configs.SFTPD.HostKeyAlgos, 1) + assert.Len(t, configs.SFTPD.PublicKeyAlgos, 1) + assert.Equal(t, 402, configs.ACME.HTTP01Challenge.Port) + assert.Equal(t, 1, configs.ACME.Protocols) + assert.Equal(t, domain, configs.ACME.Domain) + assert.Equal(t, "email@example.com", configs.ACME.Email) + assert.False(t, configs.ACME.HasProtocol(common.ProtocolFTP)) + assert.False(t, configs.ACME.HasProtocol(common.ProtocolWebDAV)) + assert.True(t, configs.ACME.HasProtocol(common.ProtocolHTTP)) + + err = os.Remove(crtPath) + assert.NoError(t, err) + err = os.Remove(keyPath) + assert.NoError(t, err) + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) +} + +func TestBrandingConfigMock(t *testing.T) { + err := dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) + + webClientLogoPath := "/static/branding/webclient/logo.png" + webClientFaviconPath := "/static/branding/webclient/favicon.png" + webAdminLogoPath := "/static/branding/webadmin/logo.png" + webAdminFaviconPath := "/static/branding/webadmin/favicon.png" + // no custom log or favicon was set + for _, p := range []string{webClientLogoPath, webClientFaviconPath, webAdminLogoPath, webAdminFaviconPath} { + req, err := http.NewRequest(http.MethodGet, p, nil) + assert.NoError(t, err) + rr := 
executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + } + + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, webToken) + assert.NoError(t, err) + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("form_action", "branding_submit") + form.Set("branding_webadmin_name", "Custom WebAdmin") + form.Set("branding_webadmin_short_name", "WebAdmin") + form.Set("branding_webadmin_disclaimer_name", "Admin disclaimer") + form.Set("branding_webadmin_disclaimer_url", "invalid, not a URL") + form.Set("branding_webclient_name", "Custom WebClient") + form.Set("branding_webclient_short_name", "WebClient") + form.Set("branding_webclient_disclaimer_name", "Client disclaimer") + form.Set("branding_webclient_disclaimer_url", "https://example.com") + b, contentType, err := getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidDisclaimerURL) + + form.Set("branding_webadmin_disclaimer_url", "https://example.net") + tmpFile := filepath.Join(os.TempDir(), util.GenerateUniqueID()+".png") + err = createTestPNG(tmpFile, 512, 512, color.RGBA{100, 200, 200, 0xff}) + assert.NoError(t, err) + + b, contentType, err = getMultipartFormData(form, "branding_webadmin_logo", tmpFile) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + // check + configs, err := 
dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Equal(t, "Custom WebAdmin", configs.Branding.WebAdmin.Name) + assert.Equal(t, "WebAdmin", configs.Branding.WebAdmin.ShortName) + assert.Equal(t, "Admin disclaimer", configs.Branding.WebAdmin.DisclaimerName) + assert.Equal(t, "https://example.net", configs.Branding.WebAdmin.DisclaimerURL) + assert.Equal(t, "Custom WebClient", configs.Branding.WebClient.Name) + assert.Equal(t, "WebClient", configs.Branding.WebClient.ShortName) + assert.Equal(t, "Client disclaimer", configs.Branding.WebClient.DisclaimerName) + assert.Equal(t, "https://example.com", configs.Branding.WebClient.DisclaimerURL) + assert.Greater(t, len(configs.Branding.WebAdmin.Logo), 0) + assert.Len(t, configs.Branding.WebAdmin.Favicon, 0) + assert.Len(t, configs.Branding.WebClient.Logo, 0) + assert.Len(t, configs.Branding.WebClient.Favicon, 0) + + err = createTestPNG(tmpFile, 256, 256, color.RGBA{120, 220, 220, 0xff}) + assert.NoError(t, err) + form.Set("branding_webadmin_logo_remove", "0") // 0 preserves WebAdmin logo + b, contentType, err = getMultipartFormData(form, "branding_webadmin_favicon", tmpFile) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + configs, err = dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Equal(t, "Custom WebAdmin", configs.Branding.WebAdmin.Name) + assert.Equal(t, "WebAdmin", configs.Branding.WebAdmin.ShortName) + assert.Equal(t, "Admin disclaimer", configs.Branding.WebAdmin.DisclaimerName) + assert.Equal(t, "https://example.net", configs.Branding.WebAdmin.DisclaimerURL) + assert.Equal(t, "Custom WebClient", configs.Branding.WebClient.Name) + assert.Equal(t, "WebClient", configs.Branding.WebClient.ShortName) + assert.Equal(t, 
"Client disclaimer", configs.Branding.WebClient.DisclaimerName) + assert.Equal(t, "https://example.com", configs.Branding.WebClient.DisclaimerURL) + assert.Greater(t, len(configs.Branding.WebAdmin.Logo), 0) + assert.Greater(t, len(configs.Branding.WebAdmin.Favicon), 0) + assert.Len(t, configs.Branding.WebClient.Logo, 0) + assert.Len(t, configs.Branding.WebClient.Favicon, 0) + + err = createTestPNG(tmpFile, 256, 256, color.RGBA{80, 90, 110, 0xff}) + assert.NoError(t, err) + b, contentType, err = getMultipartFormData(form, "branding_webclient_logo", tmpFile) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + configs, err = dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Greater(t, len(configs.Branding.WebClient.Logo), 0) + + err = createTestPNG(tmpFile, 256, 256, color.RGBA{120, 50, 120, 0xff}) + assert.NoError(t, err) + b, contentType, err = getMultipartFormData(form, "branding_webclient_favicon", tmpFile) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + configs, err = dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Greater(t, len(configs.Branding.WebClient.Favicon), 0) + + for _, p := range []string{webClientLogoPath, webClientFaviconPath, webAdminLogoPath, webAdminFaviconPath} { + req, err := http.NewRequest(http.MethodGet, p, nil) + assert.NoError(t, err) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + } + // remove images + form.Set("branding_webadmin_logo_remove", "1") 
+ form.Set("branding_webclient_logo_remove", "1") + form.Set("branding_webadmin_favicon_remove", "1") + form.Set("branding_webclient_favicon_remove", "1") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nConfigsOK) + configs, err = dataprovider.GetConfigs() + assert.NoError(t, err) + assert.Len(t, configs.Branding.WebAdmin.Logo, 0) + assert.Len(t, configs.Branding.WebAdmin.Favicon, 0) + assert.Len(t, configs.Branding.WebClient.Logo, 0) + assert.Len(t, configs.Branding.WebClient.Favicon, 0) + for _, p := range []string{webClientLogoPath, webClientFaviconPath, webAdminLogoPath, webAdminFaviconPath} { + req, err := http.NewRequest(http.MethodGet, p, nil) + assert.NoError(t, err) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + } + form.Del("branding_webadmin_logo_remove") + form.Del("branding_webclient_logo_remove") + form.Del("branding_webadmin_favicon_remove") + form.Del("branding_webclient_favicon_remove") + // image too large + err = createTestPNG(tmpFile, 768, 512, color.RGBA{120, 50, 120, 0xff}) + assert.NoError(t, err) + b, contentType, err = getMultipartFormData(form, "branding_webclient_logo", tmpFile) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidPNGSize) + // not a png image + err = createTestFile(tmpFile, 128) + assert.NoError(t, err) + b, contentType, err = getMultipartFormData(form, "branding_webclient_logo", tmpFile) + 
assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webConfigsPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidPNG) + + err = os.Remove(tmpFile) + assert.NoError(t, err) + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) +} + +func TestSFTPLoopError(t *testing.T) { + user1 := getTestUser() + user2 := getTestUser() + user1.Username += "1" + user1.Email = "user1@test.com" + user2.Username += "2" + user1.FsConfig = vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user2.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + } + + user2.FsConfig.Provider = sdk.SFTPFilesystemProvider + user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user1.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + } + + user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + // test reset password + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 3525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err = smtpCfg.Initialize(configDir, true) + assert.NoError(t, err) + + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user1.Username) + lastResetCode = "" + req, err := http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = 
defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("password", defaultPassword) + form.Set("confirm_password", defaultPassword) + form.Set("code", lastResetCode) + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorLoginAfterReset) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user2, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user2.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginInvalidFs(t *testing.T) { + u := getTestUser() + u.Filters.AllowAPIKeyAuth = true + u.FsConfig.Provider = sdk.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "test" + u.FsConfig.GCSConfig.UploadPartSize = 1 + u.FsConfig.GCSConfig.UploadPartMaxTime = 10 + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + apiKey, _, err := httpdtest.AddAPIKey(dataprovider.APIKey{ + Name: "testk", + Scope: dataprovider.APIKeyScopeUser, + User: u.Username, + }, http.StatusCreated) + assert.NoError(t, err) + + _, err = 
getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + + _, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + + req, err := http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr := executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebClientChangePwd(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webChangeClientPwdPath, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + form := make(url.Values) + form.Set("current_password", defaultPassword) + form.Set("new_password1", defaultPassword) + form.Set("new_password2", defaultPassword) + // no csrf token + req, err = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeClientPwdPath, webToken) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdNoDifferent) + + form.Set("current_password", defaultPassword+"2") + form.Set("new_password1", defaultPassword+"1") + form.Set("new_password2", defaultPassword+"1") + req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdCurrentNoMatch) + + form.Set("current_password", defaultPassword) + form.Set("new_password1", defaultPassword+"1") + form.Set("new_password2", defaultPassword+"1") + req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + + req, err = http.NewRequest(http.MethodGet, webClientPingPath, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + + _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.Error(t, err) + _, err = getJWTWebClientTokenFromTestServer(defaultUsername+"1", defaultPassword+"1") + assert.Error(t, err) + _, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword+"1") + assert.NoError(t, err) + + // remove the change password permission + user.Filters.WebClient = []string{sdk.WebClientPasswordChangeDisabled} + user, _, 
err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + assert.Len(t, user.Filters.WebClient, 1) + assert.Contains(t, user.Filters.WebClient, sdk.WebClientPasswordChangeDisabled) + + webToken, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword+"1") + assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + form.Set("current_password", defaultPassword+"1") + form.Set("new_password1", defaultPassword) + form.Set("new_password2", defaultPassword) + req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPreDownloadHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oldExecuteOn := common.Config.Actions.ExecuteOn + oldHook := common.Config.Actions.Hook + + common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} + common.Config.Actions.Hook = preActionPath + + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(preActionPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + + testFileName := "testfile" + testFileContents := []byte("file contents") + err = os.MkdirAll(filepath.Join(user.GetHomeDir()), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName), testFileContents, os.ModePerm) + assert.NoError(t, err) + + webToken, err := 
getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) //nolint:goconst + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, testFileContents, rr.Body.Bytes()) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, testFileContents, rr.Body.Bytes()) + + err = os.WriteFile(preActionPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError403Message) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "permission denied") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.Actions.ExecuteOn = oldExecuteOn + common.Config.Actions.Hook = oldHook +} + +func TestPreUploadHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oldExecuteOn := common.Config.Actions.ExecuteOn + oldHook := common.Config.Actions.Hook + + common.Config.Actions.ExecuteOn = []string{common.OperationPreUpload} + 
common.Config.Actions.Hook = preActionPath + + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(preActionPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", "filepre") + assert.NoError(t, err) + _, err = part.Write([]byte("file content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=filepre", + bytes.NewBuffer([]byte("single upload content"))) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + err = os.WriteFile(preActionPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=filepre", + bytes.NewBuffer([]byte("single upload content"))) + assert.NoError(t, err) + req.Header.Add("Content-Type", 
writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.Actions.ExecuteOn = oldExecuteOn + common.Config.Actions.Hook = oldHook +} + +func TestShareUsage(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + testFileName := "testfile.dat" + testFileSize := int64(65536) + testFilePath := filepath.Join(user.GetHomeDir(), testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: "test share", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + Password: defaultPassword, + MaxTokens: 2, + ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Second)), + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/unknownid", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + req.SetBasicAuth(defaultUsername, "wrong password") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = 
executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + time.Sleep(2 * time.Second) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID, nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"_mod", nil) + assert.NoError(t, err) + req.RequestURI = webClientPubSharesPath + "/" + objectID + "_mod" + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + share.ExpiresAt = 0 + jsonReq := make(map[string]any) + jsonReq["name"] = share.Name + jsonReq["scope"] = share.Scope + jsonReq["paths"] = share.Paths + jsonReq["password"] = share.Password + jsonReq["max_tokens"] = share.MaxTokens + jsonReq["expires_at"] = share.ExpiresAt + asJSON, err = json.Marshal(jsonReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userSharesPath+"/"+objectID, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID, nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "invalid share scope") + + share.MaxTokens = 3 + share.Scope = dataprovider.ShareScopeWrite + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userSharesPath+"/"+objectID, bytes.NewBuffer(asJSON)) 
+ assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part1, err := writer.CreateFormFile("filenames", "file1.txt") + assert.NoError(t, err) + _, err = part1.Write([]byte("file1 content")) + assert.NoError(t, err) + part2, err := writer.CreateFormFile("filenames", "file2.txt") + assert.NoError(t, err) + _, err = part2.Write([]byte("file2 content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Unable to parse multipart form") + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + // set the proper content type + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Allowed usage exceeded") + + share.MaxTokens = 6 + share.Scope = dataprovider.ShareScopeWrite + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, userSharesPath+"/"+objectID, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.SetBasicAuth(defaultUsername, 
defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webClientPubSharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + share, err = dataprovider.ShareExists(objectID, user.Username) + assert.NoError(t, err) + assert.Equal(t, 6, share.UsedTokens) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + share.MaxTokens = 0 + err = dataprovider.UpdateShare(&share, user.Username, "", "") + assert.NoError(t, err) + + user.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "permission denied") + + user.Permissions["/"] = []string{dataprovider.PermAny} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + body = new(bytes.Buffer) + writer = multipart.NewWriter(body) + part, err := writer.CreateFormFile("filename", "file1.txt") + assert.NoError(t, err) + _, err = 
part.Write([]byte("file content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader = bytes.NewReader(body.Bytes()) + + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "No files uploaded!") + + user.Filters.WebClient = []string{sdk.WebClientSharesDisabled} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + user.Filters.WebClient = []string{sdk.WebClientShareNoPasswordDisabled} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + share.Password = "" + err = dataprovider.UpdateShare(&share, user.Username, "", "") + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "sharing without a password was disabled") + + user.Filters.WebClient = []string{sdk.WebClientInfoChangeDisabled} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + share.Scope = dataprovider.ShareScopeRead + share.Paths = []string{"/missing1", "/missing2"} + err = dataprovider.UpdateShare(&share, user.Username, "", "") + assert.NoError(t, err) + + defer func() { + rcv := recover() 
+ assert.Equal(t, http.ErrAbortHandler, rcv) + + share, err = dataprovider.ShareExists(objectID, user.Username) + assert.NoError(t, err) + assert.Equal(t, 6, share.UsedTokens) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + }() + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + executeRequest(req) +} + +func TestSharePasswordPolicy(t *testing.T) { + g := getTestGroup() + g.UserSettings.Filters.PasswordStrength = 70 + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: g.Name, + Type: sdk.GroupTypePrimary, + }, + } + u.Password = rand.Text() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, u.Password) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: util.GenerateUniqueID(), + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + Password: defaultPassword, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "insecure password") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) +} + +func TestShareMaxExpiration(t *testing.T) { + u := getTestUser() + u.Filters.MaxSharesExpiration = 5 + u.Filters.DefaultSharesExpiration = 10 + _, resp, err := httpdtest.AddUser(u, http.StatusBadRequest) + assert.NoError(t, err) + assert.Contains(t, string(resp), "must be 
less than or equal to max shares expiration") + + u.Filters.DefaultSharesExpiration = 0 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + webClientToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + s := dataprovider.Share{ + Name: "test share", + Scope: dataprovider.ShareScopeRead, + Password: defaultPassword, + Paths: []string{"/"}, + ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(u.Filters.MaxSharesExpiration+2))), + } + asJSON, err := json.Marshal(s) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "share must expire before") + + req, err = http.NewRequest(http.MethodPut, path.Join(userSharesPath, "shareID"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // expiresAt is mandatory + s.ExpiresAt = 0 + asJSON, err = json.Marshal(s) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "share must expire before") + + s.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(2 * time.Hour)) + asJSON, err = json.Marshal(s) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + 
shareID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, shareID) + + s.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(u.Filters.MaxSharesExpiration+2))) + asJSON, err = json.Marshal(s) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, path.Join(userSharesPath, shareID), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "share must expire before") + + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientSharePath, webClientToken) + assert.NoError(t, err) + form := make(url.Values) + form.Set("name", s.Name) + form.Set("scope", strconv.Itoa(int(s.Scope))) + form.Set("max_tokens", "0") + form.Set("paths[0][path]", "/") + form.Set("expiration_date", time.Now().Add(24*time.Hour*time.Duration(u.Filters.MaxSharesExpiration+2)).Format("2006-01-02 15:04:05")) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webClientToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareExpirationOutOfRange) + + req, err = http.NewRequest(http.MethodPost, path.Join(webClientSharePath, shareID), bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webClientToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareExpirationOutOfRange) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = 
os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webClientToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorGetUser) +} + +func TestWebClientShareCredentials(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + shareRead := dataprovider.Share{ + Name: "test share read", + Scope: dataprovider.ShareScopeRead, + Password: defaultPassword, + Paths: []string{"/"}, + } + + shareWrite := dataprovider.Share{ + Name: "test share write", + Scope: dataprovider.ShareScopeReadWrite, + Password: defaultPassword, + Paths: []string{"/"}, + } + + asJSON, err := json.Marshal(shareRead) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + shareReadID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, shareReadID) + + asJSON, err = json.Marshal(shareWrite) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + shareWriteID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, shareWriteID) + + uri := path.Join(webClientPubSharesPath, shareReadID, "browse") + req, err = http.NewRequest(http.MethodGet, uri, nil) + assert.NoError(t, err) + req.RequestURI = uri + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusFound, rr) + location := rr.Header().Get("Location") + assert.Contains(t, location, url.QueryEscape(uri)) + // get the login form + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + req.RequestURI = uri + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // now set the user token, it is not valid for the share + req, err = http.NewRequest(http.MethodGet, uri, nil) + assert.NoError(t, err) + req.RequestURI = uri + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + // get a share token + form := make(url.Values) + form.Set("share_password", defaultPassword) + loginURI := path.Join(webClientPubSharesPath, shareReadID, "login") + req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + // set the CSRF token + loginCookie, csrfToken, err := getCSRFTokenMock(loginURI, defaultRemoteAddr) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nShareLoginOK) + cookie := rr.Header().Get("Set-Cookie") + cookie = strings.TrimPrefix(cookie, "jwt=") + assert.NotEmpty(t, cookie) + req, err = http.NewRequest(http.MethodGet, uri, nil) + assert.NoError(t, err) + req.RequestURI = uri + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // get the 
download page + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareReadID, "download?a=b"), nil) + assert.NoError(t, err) + req.RequestURI = uri + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // get the download page for a missing share + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, "invalidshareid", "download"), nil) + assert.NoError(t, err) + req.RequestURI = uri + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // the same cookie will not work for the other share + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareWriteID, "browse"), nil) + assert.NoError(t, err) + req.RequestURI = uri + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + // IP address does not match + req, err = http.NewRequest(http.MethodGet, uri, nil) + assert.NoError(t, err) + req.RequestURI = uri + setJWTCookieForReq(req, cookie) + req.RemoteAddr = "1.2.3.4:1234" + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + // logout to a different share, the cookie is not valid. 
+ req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareWriteID, "logout"), nil) + assert.NoError(t, err) + req.RequestURI = uri + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + // logout + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareReadID, "logout"), nil) + assert.NoError(t, err) + req.RequestURI = uri + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + // the cookie is no longer valid + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, shareReadID, "download?b=c"), nil) + assert.NoError(t, err) + req.RequestURI = uri + setJWTCookieForReq(req, cookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Contains(t, rr.Header().Get("Location"), "/login") + + // try to login with invalid credentials + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + form.Set("share_password", "") + req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + // login with the next param set + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + form.Set("share_password", defaultPassword) + nextURI := path.Join(webClientPubSharesPath, shareReadID, "browse") + loginURI = path.Join(webClientPubSharesPath, shareReadID, fmt.Sprintf("login?next=%s", 
url.QueryEscape(nextURI))) + req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, nextURI, rr.Header().Get("Location")) + // try to login to a missing share + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + loginURI = path.Join(webClientPubSharesPath, "missing", "login") + req, err = http.NewRequest(http.MethodPost, loginURI, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestShareMaxSessions(t *testing.T) { + u := getTestUser() + u.MaxSessions = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: "test share max sessions read", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + 
req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil) //nolint:goconst + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // add a fake connection + fs := vfs.NewOsFs("id", os.TempDir(), "", nil) + connection := &httpd.Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolHTTP, "", "", user), + } + err = common.Connections.Add(connection) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/dirs", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/dirs", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=file.pdf"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), util.I18nError429Message) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"/browse", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError429Message) + + req, err = http.NewRequest(http.MethodPost, webClientPubSharesPath+"/"+objectID+"/browse/exist", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "invalid share scope") + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID+"/files?path=afile", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, 
http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + form := make(url.Values) + form.Set("files", `[]`) + req, err = http.NewRequest(http.MethodPost, webClientPubSharesPath+"/"+objectID+"/partial", + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), util.I18nError429Message) + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + req, err = http.NewRequest(http.MethodDelete, userSharesPath+"/"+objectID, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + // now test a write share + share = dataprovider.Share{ + Name: "test share max sessions write", + Scope: dataprovider.ShareScopeWrite, + Paths: []string{"/"}, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file.txt"), bytes.NewBuffer([]byte("content"))) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part1, err := writer.CreateFormFile("filenames", "file1.txt") + assert.NoError(t, err) + _, err = part1.Write([]byte("file1 content")) + assert.NoError(t, err) + err = writer.Close() + 
assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + share = dataprovider.Share{ + Name: "test share max sessions read&write", + Scope: dataprovider.ShareScopeReadWrite, + Paths: []string{"/"}, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodPost, webClientPubSharesPath+"/"+objectID+"/browse/exist", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusTooManyRequests, rr) + assert.Contains(t, rr.Body.String(), "too many open sessions") + + common.Connections.Remove(connection.GetID()) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestShareUploadSingle(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: "test share", + Scope: dataprovider.ShareScopeWrite, + Paths: []string{"/"}, + Password: defaultPassword, + MaxTokens: 0, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, 
userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + content := []byte("shared file content") + modTime := time.Now().Add(-12 * time.Hour) + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file.txt"), bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + req.Header.Set("X-SFTPGO-MTIME", strconv.FormatInt(util.GetTimeAsMsSinceEpoch(modTime), 10)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + info, err := os.Stat(filepath.Join(user.GetHomeDir(), "file.txt")) + if assert.NoError(t, err) { + assert.InDelta(t, util.GetTimeAsMsSinceEpoch(modTime), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(1000)) + } + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "upload"), nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "file.txt"), bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + info, err = os.Stat(filepath.Join(user.GetHomeDir(), "file.txt")) + if assert.NoError(t, err) { + assert.InDelta(t, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(3000)) + } + + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "dir", "file.dat"), bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = 
http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "%2F"), bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "operation unsupported") + + err = os.MkdirAll(filepath.Join(user.GetHomeDir(), "dir"), os.ModePerm) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "dir"), bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "operation unsupported") + + // only uploads to the share root dir are allowed + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "dir", "file.dat"), bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + share, err = dataprovider.ShareExists(objectID, user.Username) + assert.NoError(t, err) + assert.Equal(t, 2, share.UsedTokens) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, "file1.txt"), bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestShareReadWrite(t *testing.T) { + u := getTestUser() + u.Filters.StartDirectory = path.Join("/start", "dir") + u.Permissions["/start/dir/limited"] = []string{dataprovider.PermListItems} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + 
assert.NoError(t, err) + testFileName := "test.txt" + testSubDirs := "/sub/dir" + + share := dataprovider.Share{ + Name: "test share rw", + Scope: dataprovider.ShareScopeReadWrite, + Paths: []string{user.Filters.StartDirectory}, + Password: defaultPassword, + MaxTokens: 0, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + filesToCheck := make(map[string]any) + filesToCheck["files"] = []string{testFileName} + asJSON, err = json.Marshal(filesToCheck) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/browse/exist?path=%2F"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var fileList []any + err = json.Unmarshal(rr.Body.Bytes(), &fileList) + assert.NoError(t, err) + assert.Len(t, fileList, 0) + + content := []byte("shared rw content") + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID, testFileName), bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + assert.FileExists(t, filepath.Join(user.GetHomeDir(), user.Filters.StartDirectory, testFileName)) + + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID)+"/"+url.PathEscape(path.Join(testSubDirs, testFileName)), bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, 
path.Join(sharesPath, objectID)+"/"+url.PathEscape(path.Join(testSubDirs, testFileName))+"?mkdir_parents=true", + bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + assert.FileExists(t, filepath.Join(user.GetHomeDir(), user.Filters.StartDirectory, testSubDirs, testFileName)) + + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID)+"/"+url.PathEscape(path.Join("limited", "sub", testFileName))+"?mkdir_parents=true", + bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/browse/exist?path=%2F"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + fileList = nil + err = json.Unmarshal(rr.Body.Bytes(), &fileList) + assert.NoError(t, err) + assert.Len(t, fileList, 1) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+testFileName), nil) //nolint:goconst + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contentDisposition := rr.Header().Get("Content-Disposition") + assert.NotEmpty(t, contentDisposition) + + form := make(url.Values) + form.Set("files", fmt.Sprintf(`["%s"]`, testFileName)) + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, 
"partial"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contentDisposition = rr.Header().Get("Content-Disposition") + assert.NotEmpty(t, contentDisposition) + assert.Equal(t, "application/zip", rr.Header().Get("Content-Type")) + // parse form error + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial?path=p%C3%AO%GK"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + // invalid files list + form.Set("files", fmt.Sprintf(`[%s]`, testFileName)) + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nError400Message) + // missing directory + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial?path=missing"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nError400Message) + + req, err = 
http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID)+"/"+url.PathEscape("../"+testFileName), + bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Uploading outside the share is not allowed") + + req, err = http.NewRequest(http.MethodPost, path.Join(sharesPath, objectID)+"/"+url.PathEscape("/../../"+testFileName), + bytes.NewBuffer(content)) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Uploading outside the share is not allowed") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestShareUncompressed(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + testFileName := "testfile.dat" + testFileSize := int64(65536) + testFilePath := filepath.Join(user.GetHomeDir(), testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: "test share", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + Password: defaultPassword, + MaxTokens: 0, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + s, err := dataprovider.ShareExists(objectID, defaultUsername) + assert.NoError(t, err) + assert.Equal(t, 
int64(0), s.ExpiresAt) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID, nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "application/zip", rr.Header().Get("Content-Type")) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil) //nolint:goconst + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "application/zip", rr.Header().Get("Content-Type")) + + share = dataprovider.Share{ + Name: "test share1", + Scope: dataprovider.ShareScopeRead, + Paths: []string{testFileName}, + Password: defaultPassword, + MaxTokens: 0, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID, nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "application/zip", rr.Header().Get("Content-Type")) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "application/octet-stream", rr.Header().Get("Content-Type")) + + share, err = dataprovider.ShareExists(objectID, user.Username) + assert.NoError(t, err) + assert.Equal(t, 2, share.UsedTokens) + + user.Permissions["/"] = 
[]string{dataprovider.PermListItems, dataprovider.PermUpload} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + share, err = dataprovider.ShareExists(objectID, user.Username) + assert.NoError(t, err) + assert.Equal(t, 2, share.UsedTokens) + + user.Permissions["/"] = []string{dataprovider.PermAny} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestDownloadFromShareError(t *testing.T) { + u := getTestUser() + u.DownloadDataTransfer = 1 + u.Filters.DefaultSharesExpiration = 10 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user.UsedDownloadDataTransfer = 1024*1024 - 32768 + _, err = httpdtest.UpdateTransferQuotaUsage(user, "add", http.StatusOK) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(1024*1024-32768), user.UsedDownloadDataTransfer) + testFileName := "test_share_file.dat" + testFileSize := int64(524288) + testFilePath := filepath.Join(user.GetHomeDir(), testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, 
defaultPassword) + assert.NoError(t, err) + share := dataprovider.Share{ + Name: "test share root browse", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + MaxTokens: 2, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + s, err := dataprovider.ShareExists(objectID, defaultUsername) + assert.NoError(t, err) + assert.Greater(t, s.ExpiresAt, int64(0)) + + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + + share, err = dataprovider.ShareExists(objectID, user.Username) + assert.NoError(t, err) + assert.Equal(t, 0, share.UsedTokens) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + }() + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+testFileName), nil) + assert.NoError(t, err) + executeRequest(req) +} + +func TestBrowseShares(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + testFileName := "testsharefile.dat" + testFileNameLink := "testsharefile.link" + shareDir := "share" + subDir := "sub" + testFileSize := int64(65536) + testFilePath := filepath.Join(user.GetHomeDir(), shareDir, testFileName) + testLinkPath := filepath.Join(user.GetHomeDir(), shareDir, testFileNameLink) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = createTestFile(filepath.Join(user.GetHomeDir(), shareDir, subDir, testFileName), testFileSize) + assert.NoError(t, err) + err = os.Symlink(testFilePath, testLinkPath) + assert.NoError(t, err) + + token, err := 
getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: "test share browse", + Scope: dataprovider.ShareScopeRead, + Paths: []string{shareDir}, + MaxTokens: 0, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "upload"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "invalid share scope") + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Please set the path to a valid file") + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents := make([]map[string]any, 0) + err = json.Unmarshal(rr.Body.Bytes(), &contents) + assert.NoError(t, err) + assert.Len(t, contents, 2) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = make([]map[string]any, 0) + err = json.Unmarshal(rr.Body.Bytes(), &contents) 
+ assert.NoError(t, err) + assert.Len(t, contents, 2) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"+subDir), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = make([]map[string]any, 0) + err = json.Unmarshal(rr.Body.Bytes(), &contents) + assert.NoError(t, err) + assert.Len(t, contents, 1) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F.."), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorPathInvalid) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=%2F.."), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F.."), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path=%2F..%2F"+testFileName), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+testFileName), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contentDisposition := rr.Header().Get("Content-Disposition") + assert.NotEmpty(t, contentDisposition) + + form := make(url.Values) + form.Set("files", `[]`) + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial?path=%2F.."), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 
+ rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorPathInvalid) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path="+testFileName), nil) //nolint:goconst + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contentDisposition = rr.Header().Get("Content-Disposition") + assert.NotEmpty(t, contentDisposition) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+subDir+"%2F"+testFileName), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contentDisposition = rr.Header().Get("Content-Disposition") + assert.NotEmpty(t, contentDisposition) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=missing"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorFsGeneric) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path=missing"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=missing"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=missing"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path="+testFileNameLink), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), 
util.I18nErrorFsGeneric) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path="+testFileNameLink), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "non regular files are not supported for shares") + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path="+testFileName), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=missing"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorFsGeneric) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=%2F.."), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorPathInvalid) + + fakePDF := []byte(`%PDF-1.6`) + for i := 0; i < 128; i++ { + fakePDF = append(fakePDF, []byte(fmt.Sprintf("%d", i))...) 
+ } + pdfPath := filepath.Join(user.GetHomeDir(), shareDir, "test.pdf") + pdfLinkPath := filepath.Join(user.GetHomeDir(), shareDir, "link.pdf") + err = os.WriteFile(pdfPath, fakePDF, 0666) + assert.NoError(t, err) + err = os.Symlink(pdfPath, pdfLinkPath) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "viewpdf?path=test.pdf"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=test.pdf"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + s, err := dataprovider.ShareExists(objectID, defaultUsername) + assert.NoError(t, err) + usedTokens := s.UsedTokens + assert.Greater(t, usedTokens, 0) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=link.pdf"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // downloading a symlink will fail, usage should not change + s, err = dataprovider.ShareExists(objectID, defaultUsername) + assert.NoError(t, err) + assert.Equal(t, usedTokens, s.UsedTokens) + + // share a symlink + share = dataprovider.Share{ + Name: "test share browse", + Scope: dataprovider.ShareScopeRead, + Paths: []string{path.Join(shareDir, testFileNameLink)}, + MaxTokens: 0, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + // uncompressed download should not work + req, err = http.NewRequest(http.MethodGet, webClientPubSharesPath+"/"+objectID+"?compress=false", nil) + assert.NoError(t, err) + 
rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "application/zip", rr.Header().Get("Content-Type")) + // this share is not browsable, it does not contains a directory + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + form = make(url.Values) + form.Set("files", `[]`) + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareBrowseNoDir) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path="+testFileName), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareBrowseNoDir) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path="+testFileName), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "the shared object is not a directory and so it is not browsable") + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareBrowseNoDir) 
+ + // now test a missing shareID + objectID = "123456" + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "files?path="+testFileName), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "dirs?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + form = make(url.Values) + form.Set("files", `[]`) + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "partial?path=%2F"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "viewpdf?path=p"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "getpdf?path=p"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // share a missing base path + share = dataprovider.Share{ + Name: "test share", + Scope: dataprovider.ShareScopeRead, + Paths: []string{path.Join(shareDir, "missingdir")}, + MaxTokens: 0, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err 
= http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "unable to check the share directory") + // share multiple paths + share = dataprovider.Share{ + Name: "test share", + Scope: dataprovider.ShareScopeRead, + Paths: []string{shareDir, "/anotherdir"}, + MaxTokens: 0, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareBrowsePaths) + + share = dataprovider.Share{ + Name: "test share rw", + Scope: dataprovider.ShareScopeReadWrite, + Paths: []string{"/missingdir"}, + MaxTokens: 0, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/browse/exist?path=%2F"), nil) 
+ assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "unable to check the share directory") + + share = dataprovider.Share{ + Name: "test share rw", + Scope: dataprovider.ShareScopeReadWrite, + Paths: []string{shareDir}, + MaxTokens: 0, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/browse/exist?path=%2F.."), nil) + assert.NoError(t, err) + req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Invalid path") + // share the root path + share = dataprovider.Share{ + Name: "test share root", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + MaxTokens: 0, + } + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + 
contents = make([]map[string]any, 0) + err = json.Unmarshal(rr.Body.Bytes(), &contents) + assert.NoError(t, err) + assert.Len(t, contents, 1) + // if we require two-factor auth for HTTP protocol the share should not work anymore + user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH, common.ProtocolHTTP} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, path.Join(sharesPath, objectID, "dirs?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "two-factor authentication requirements not met") + user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + // share read/write + share.Scope = dataprovider.ShareScopeReadWrite + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "browse?path=%2F"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // on upload we should be redirected + req, err = http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, objectID, "upload"), nil) + assert.NoError(t, err) + 
req.SetBasicAuth(defaultUsername, defaultPassword) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + location := rr.Header().Get("Location") + assert.Equal(t, path.Join(webClientPubSharesPath, objectID, "browse"), location) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUserAPIShareErrors(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + share := dataprovider.Share{ + Scope: 1000, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "invalid scope") + // invalid json + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer([]byte("{"))) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + share.Scope = dataprovider.ShareScopeWrite + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "at least a shared path is required") + + share.Paths = []string{"path1", "../path1", "/path2"} + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, 
http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "the write share scope requires exactly one path") + + share.Paths = []string{"", ""} + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "at least a shared path is required") + + share.Paths = []string{"path1", "../path1", "/path1"} + share.Password = redactedSecret + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "cannot save a share with a redacted password") + + share.Password = "newpass" + share.AllowFrom = []string{"not valid"} + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "could not parse allow from entry") + + share.AllowFrom = []string{"127.0.0.1/8"} + share.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-12 * time.Hour)) + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "expiration must be in the future") + + share.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(12 * time.Hour)) + asJSON, err = json.Marshal(share) + 
assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + location := rr.Header().Get("Location") + + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "name is mandatory") + // invalid json + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer([]byte("}"))) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, userSharesPath+"?limit=a", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUserAPIShares(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + u := getTestUser() + u.Username = altAdminUsername + user1, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + token1, err := getJWTAPIUserTokenFromTestServer(user1.Username, defaultPassword) + assert.NoError(t, err) + + // the share username will be set from the claims + share := dataprovider.Share{ + Name: "share1", + Description: "description1", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + CreatedAt: 1, + UpdatedAt: 2, + LastUseAt: 3, + ExpiresAt: 
util.GetTimeAsMsSinceEpoch(time.Now().Add(2 * time.Hour)), + Password: defaultPassword, + MaxTokens: 10, + UsedTokens: 2, + AllowFrom: []string{"192.168.1.0/24"}, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + location := rr.Header().Get("Location") + assert.NotEmpty(t, location) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + assert.Equal(t, fmt.Sprintf("%v/%v", userSharesPath, objectID), location) + + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var shareGet dataprovider.Share + err = json.Unmarshal(rr.Body.Bytes(), &shareGet) + assert.NoError(t, err) + assert.Equal(t, objectID, shareGet.ShareID) + assert.Equal(t, share.Name, shareGet.Name) + assert.Equal(t, share.Description, shareGet.Description) + assert.Equal(t, share.Scope, shareGet.Scope) + assert.Equal(t, share.Paths, shareGet.Paths) + assert.Equal(t, int64(0), shareGet.LastUseAt) + assert.Greater(t, shareGet.CreatedAt, share.CreatedAt) + assert.Greater(t, shareGet.UpdatedAt, share.UpdatedAt) + assert.Equal(t, share.ExpiresAt, shareGet.ExpiresAt) + assert.Equal(t, share.MaxTokens, shareGet.MaxTokens) + assert.Equal(t, 0, shareGet.UsedTokens) + assert.Equal(t, share.Paths, shareGet.Paths) + assert.Equal(t, redactedSecret, shareGet.Password) + + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token1) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + s, err := dataprovider.ShareExists(objectID, defaultUsername) + assert.NoError(t, err) + match, err := s.CheckCredentials(defaultPassword) + assert.True(t, match) + 
assert.NoError(t, err) + match, err = s.CheckCredentials(defaultPassword + "mod") + assert.False(t, match) + assert.Error(t, err) + + shareGet.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(3 * time.Hour)) + asJSON, err = json.Marshal(shareGet) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + s, err = dataprovider.ShareExists(objectID, defaultUsername) + assert.NoError(t, err) + match, err = s.CheckCredentials(defaultPassword) + assert.True(t, match) + assert.NoError(t, err) + match, err = s.CheckCredentials(defaultPassword + "mod") + assert.False(t, match) + assert.Error(t, err) + + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var shareGetNew dataprovider.Share + err = json.Unmarshal(rr.Body.Bytes(), &shareGetNew) + assert.NoError(t, err) + assert.NotEqual(t, shareGet.UpdatedAt, shareGetNew.UpdatedAt) + shareGet.UpdatedAt = shareGetNew.UpdatedAt + assert.Equal(t, shareGet, shareGetNew) + + req, err = http.NewRequest(http.MethodGet, userSharesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var shares []dataprovider.Share + err = json.Unmarshal(rr.Body.Bytes(), &shares) + assert.NoError(t, err) + if assert.Len(t, shares, 1) { + assert.Equal(t, shareGetNew, shares[0]) + } + + err = dataprovider.UpdateShareLastUse(&shareGetNew, 2) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + shareGetNew = dataprovider.Share{} + err = json.Unmarshal(rr.Body.Bytes(), &shareGetNew) + assert.NoError(t, err) + 
assert.Equal(t, 2, shareGetNew.UsedTokens, "share: %v", shareGetNew) + assert.Greater(t, shareGetNew.LastUseAt, int64(0), "share: %v", shareGetNew) + + req, err = http.NewRequest(http.MethodGet, userSharesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token1) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + shares = nil + err = json.Unmarshal(rr.Body.Bytes(), &shares) + assert.NoError(t, err) + assert.Len(t, shares, 0) + + // set an empty password + shareGet.Password = "" + asJSON, err = json.Marshal(shareGet) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + shareGetNew = dataprovider.Share{} + err = json.Unmarshal(rr.Body.Bytes(), &shareGetNew) + assert.NoError(t, err) + assert.Empty(t, shareGetNew.Password) + + req, err = http.NewRequest(http.MethodDelete, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + share.Name = "" + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + location = rr.Header().Get("Location") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + // the share should be deleted with the associated user + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodDelete, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUsersAPISharesNoPasswordDisabled(t *testing.T) { + u := getTestUser() + u.Filters.WebClient = []string{sdk.WebClientShareNoPasswordDisabled} + u.Filters.PasswordStrength = 70 + u.Password = "ahpoo8baa6EeZieshies" + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPIUserTokenFromTestServer(defaultUsername, u.Password) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: "s", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + } + asJSON, err := json.Marshal(share) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "You are not authorized to share files/folders without a password") + + share.Password = defaultPassword + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + share.Password = "vi5eiJoovee5ya9yahpi" + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, 
bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + location := rr.Header().Get("Location") + assert.NotEmpty(t, location) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + assert.Equal(t, fmt.Sprintf("%v/%v", userSharesPath, objectID), location) + + share.Password = "" + asJSON, err = json.Marshal(share) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "You are not authorized to share files/folders without a password") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUserAPIKey(t *testing.T) { + u := getTestUser() + u.Filters.AllowAPIKeyAuth = true + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + apiKey := dataprovider.APIKey{ + Name: "testkey", + User: user.Username + "1", + Scope: dataprovider.APIKeyScopeUser, + } + _, _, err = httpdtest.AddAPIKey(apiKey, http.StatusBadRequest) + assert.NoError(t, err) + apiKey.User = user.Username + apiKey, _, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated) + assert.NoError(t, err) + + adminAPIKey := dataprovider.APIKey{ + Name: "testadminkey", + Scope: dataprovider.APIKeyScopeAdmin, + } + adminAPIKey, _, err = httpdtest.AddAPIKey(adminAPIKey, http.StatusCreated) + assert.NoError(t, err) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", "filenametest") + assert.NoError(t, err) + _, err = part.Write([]byte("test file content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + _, err = 
reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setAPIKeyForReq(req, apiKey.Key, "") + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var dirEntries []map[string]any + err = json.Unmarshal(rr.Body.Bytes(), &dirEntries) + assert.NoError(t, err) + assert.Len(t, dirEntries, 1) + + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, adminAPIKey.Key, user.Username) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + user.Status = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + user.Status = 1 + user.Filters.DeniedProtocols = []string{common.ProtocolHTTP} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + user.Filters.DeniedProtocols = []string{common.ProtocolFTP} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + apiKeyNew := dataprovider.APIKey{ + Name: apiKey.Name, + 
Scope: dataprovider.APIKeyScopeUser, + } + + apiKeyNew, _, err = httpdtest.AddAPIKey(apiKeyNew, http.StatusCreated) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKeyNew.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + // now associate a user + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKeyNew.Key, user.Username) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // now with a missing user + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKeyNew.Key, user.Username+"1") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + // empty user and key not associated to any user + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKeyNew.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + apiKeyNew.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)) + _, _, err = httpdtest.UpdateAPIKey(apiKeyNew, http.StatusOK) + assert.NoError(t, err) + // expired API key + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKeyNew.Key, user.Username) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + _, err = httpdtest.RemoveAPIKey(apiKeyNew, http.StatusOK) + assert.NoError(t, err) + + _, err = httpdtest.RemoveAPIKey(adminAPIKey, http.StatusOK) + assert.NoError(t, err) +} + +func TestWebClientExistenceCheck(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + 
webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, webClientExistPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) // no CSRF header + + req, err = http.NewRequest(http.MethodPost, webClientExistPath, bytes.NewBuffer([]byte(`[]`))) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + filesToCheck := make(map[string]any) + filesToCheck["files"] = nil + asJSON, err := json.Marshal(filesToCheck) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webClientExistPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "files to be checked are mandatory") + + testFileName := "file.dat" + testDirName := "adirname" + filesToCheck["files"] = []string{testFileName} + asJSON, err = json.Marshal(filesToCheck) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2Fmissingdir", bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2F", bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var fileList []any + err = 
json.Unmarshal(rr.Body.Bytes(), &fileList) + assert.NoError(t, err) + assert.Len(t, fileList, 0) + + err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName), 100) + assert.NoError(t, err) + err = os.Mkdir(filepath.Join(user.GetHomeDir(), testDirName), 0755) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2F", bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + fileList = nil + err = json.Unmarshal(rr.Body.Bytes(), &fileList) + assert.NoError(t, err) + assert.Len(t, fileList, 1) + + filesToCheck["files"] = []string{testFileName, testDirName} + asJSON, err = json.Marshal(filesToCheck) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2F", bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + fileList = nil + err = json.Unmarshal(rr.Body.Bytes(), &fileList) + assert.NoError(t, err) + assert.Len(t, fileList, 2) + + req, err = http.NewRequest(http.MethodPost, webClientExistPath+"?path=%2F"+testDirName, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + fileList = nil + err = json.Unmarshal(rr.Body.Bytes(), &fileList) + assert.NoError(t, err) + assert.Len(t, fileList, 0) + + user.Filters.DeniedProtocols = []string{common.ProtocolHTTP} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webClientExistPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusForbidden, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebClientViewPDF(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webClientViewPDFPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, webClientViewPDFPath+"?path=test.pdf", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=test.pdf", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorFsGeneric) + + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2F", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) + + err = os.WriteFile(filepath.Join(user.GetHomeDir(), "test.pdf"), []byte("some text data"), 0666) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil) //nolint:goconst + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = 
executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) + + err = createTestFile(filepath.Join(user.GetHomeDir(), "test.pdf"), 1024) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorPDFMessage) + + fakePDF := []byte(`%PDF-1.6`) + for i := 0; i < 128; i++ { + fakePDF = append(fakePDF, []byte(fmt.Sprintf("%d", i))...) + } + err = os.WriteFile(filepath.Join(user.GetHomeDir(), "test.pdf"), fakePDF, 0666) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + user.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/", + DeniedPatterns: []string{"*.pdf"}, + }, + } + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nError403Message) + + user.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/", + DeniedPatterns: []string{"*.txt"}, + }, + } + user.Filters.DeniedProtocols = []string{common.ProtocolHTTP} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err 
= httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientGetPDFPath+"?path=%2Ftest.pdf", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestWebEditFile(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + testFile1 := "testfile1.txt" + testFile2 := "testfile2" + file1Size := int64(65536) + file2Size := int64(1048576 * 5) + err = createTestFile(filepath.Join(user.GetHomeDir(), testFile1), file1Size) + assert.NoError(t, err) + err = createTestFile(filepath.Join(user.GetHomeDir(), testFile2), file2Size) + assert.NoError(t, err) + + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile1, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile2, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorEditSize) + + req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=missing", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorFsGeneric) + + req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path=%2F", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + 
assert.Contains(t, rr.Body.String(), util.I18nErrorEditDir) + + user.Filters.DeniedProtocols = []string{common.ProtocolHTTP} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile1, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + user.Filters.DeniedProtocols = []string{common.ProtocolFTP} + user.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/", + DeniedPatterns: []string{"*.txt"}, + }, + } + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile1, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nError403Message) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFile1, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestWebGetFiles(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + testFileName := "testfile" + testDir := "testdir" + testFileContents := []byte("file contents") + err = os.MkdirAll(filepath.Join(user.GetHomeDir(), testDir), os.ModePerm) + assert.NoError(t, err) + extensions := []string{"", ".doc", ".ppt", ".xls", ".pdf", ".mkv", ".png", ".go", ".zip", ".txt"} + for _, ext := range extensions { + err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName+ext), testFileContents, os.ModePerm) + assert.NoError(t, err) + } + err = 
os.Symlink(filepath.Join(user.GetHomeDir(), testFileName+".doc"), filepath.Join(user.GetHomeDir(), testDir, testFileName+".link")) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testDir, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?path="+testDir, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var dirContents []map[string]any + err = json.Unmarshal(rr.Body.Bytes(), &dirContents) + assert.NoError(t, err) + assert.Len(t, dirContents, 1) + + req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?dirtree=1&path="+testDir, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + dirContents = make([]map[string]any, 0) + err = json.Unmarshal(rr.Body.Bytes(), &dirContents) + assert.NoError(t, err) + assert.Len(t, dirContents, 0) + + req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path="+testDir, nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var dirEntries []map[string]any + err = json.Unmarshal(rr.Body.Bytes(), &dirEntries) + assert.NoError(t, err) + assert.Len(t, dirEntries, 1) + + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) + form := make(url.Values) + form.Set("files", fmt.Sprintf(`["%s","%s","%s"]`, testFileName, testDir, 
testFileName+extensions[2])) + req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path="+url.QueryEscape("/"), + bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + // add csrf token + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path="+url.QueryEscape("/"), + bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // parse form error + req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path=p%C3%AO%GK", + bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + filesList := []string{testFileName, testDir, testFileName + extensions[2]} + asJSON, err := json.Marshal(filesList) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPost, userStreamZipPath, bytes.NewBuffer(asJSON)) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPost, userStreamZipPath, bytes.NewBuffer([]byte(`file`))) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("files", fmt.Sprintf(`["%v"]`, testDir)) + req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path="+url.QueryEscape("/"), + 
bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("files", "notalist") + req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path="+url.QueryEscape("/"), + bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nError400Message) + + req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?path=/", nil) //nolint:goconst + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + dirContents = nil + err = json.Unmarshal(rr.Body.Bytes(), &dirContents) + assert.NoError(t, err) + assert.Len(t, dirContents, len(extensions)+1) + + req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path=/", nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + dirEntries = nil + err = json.Unmarshal(rr.Body.Bytes(), &dirEntries) + assert.NoError(t, err) + assert.Len(t, dirEntries, len(extensions)+1) + + req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?path=/missing", nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorDirListGeneric) + + req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path=missing", nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Unable to get directory lister") + + req, _ = 
http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, testFileContents, rr.Body.Bytes()) + + req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, testFileContents, rr.Body.Bytes()) + + req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path=", nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Please set the path to a valid file") + + req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testDir, nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "is a directory") + + req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path=notafile", nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Unable to stat the requested file") + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("Range", "bytes=2-") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusPartialContent, rr) + assert.Equal(t, testFileContents[2:], rr.Body.Bytes()) + lastModified, err := http.ParseTime(rr.Header().Get("Last-Modified")) + assert.NoError(t, err) + + req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + req.Header.Set("Range", "bytes=2-") + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusPartialContent, rr) + assert.Equal(t, testFileContents[2:], rr.Body.Bytes()) + + req, _ = 
http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("Range", "bytes=-2") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusPartialContent, rr) + assert.Equal(t, testFileContents[11:], rr.Body.Bytes()) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("Range", "bytes=-2,") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestedRangeNotSatisfiable, rr) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("Range", "bytes=1a-") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestedRangeNotSatisfiable, rr) + + req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + req.Header.Set("Range", "bytes=2b-") + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestedRangeNotSatisfiable, rr) + + req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("Range", "bytes=2-") + req.Header.Set("If-Range", lastModified.UTC().Format(http.TimeFormat)) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusPartialContent, rr) + + req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("Range", "bytes=2-") + req.Header.Set("If-Range", lastModified.UTC().Add(-120*time.Second).Format(http.TimeFormat)) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("If-Modified-Since", lastModified.UTC().Add(-120*time.Second).Format(http.TimeFormat)) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("If-Modified-Since", lastModified.UTC().Add(120*time.Second).Format(http.TimeFormat)) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotModified, rr) + + req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("If-Unmodified-Since", lastModified.UTC().Add(-120*time.Second).Format(http.TimeFormat)) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusPreconditionFailed, rr) + + req, _ = http.NewRequest(http.MethodHead, userFilesPath+"?path="+testFileName, nil) + req.Header.Set("If-Unmodified-Since", lastModified.UTC().Add(-120*time.Second).Format(http.TimeFormat)) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusPreconditionFailed, rr) + + req, _ = http.NewRequest(http.MethodHead, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("If-Unmodified-Since", lastModified.UTC().Add(120*time.Second).Format(http.TimeFormat)) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + user.Filters.DeniedProtocols = []string{common.ProtocolHTTP} + _, resp, err := httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(resp)) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, _ = http.NewRequest(http.MethodGet, webClientDirsPath+"?path=/", nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, _ = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusForbidden, rr) + + req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path="+testDir, nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + filesList = []string{testDir} + asJSON, err = json.Marshal(filesList) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPost, userStreamZipPath, bytes.NewBuffer(asJSON)) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + user.Filters.DeniedProtocols = []string{common.ProtocolFTP} + user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} + _, resp, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err, string(resp)) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + form = make(url.Values) + form.Set("files", `[]`) + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webClientDownloadZipPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, _ = http.NewRequest(http.MethodGet, userDirsPath+"?path="+testDir, nil) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestRenameDifferentResource(t *testing.T) { + folderName := "foldercryptfs" + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), "folderName"), + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + 
CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret("super secret"), + }, + }, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + u := getTestUser() + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: "/folderPath", + }) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + testFileName := "file.txt" + + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) + + getStatusResponse := func(taskID string) int { + req, _ := http.NewRequest(http.MethodGet, webClientTasksPath+"/"+url.PathEscape(taskID), nil) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + if rr.Code != http.StatusOK { + return -1 + } + resp := make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &resp) + if err != nil { + return -1 + } + return int(resp["status"].(float64)) + } + + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Cannot perform copy step") + + req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + 
req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + taskResp := make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &taskResp) + assert.NoError(t, err) + taskID := taskResp["message"].(string) + assert.NotEmpty(t, taskID) + + assert.Eventually(t, func() bool { + status := getStatusResponse(taskID) + return status == http.StatusNotFound + }, 1000*time.Millisecond, 100*time.Millisecond) + + err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName), []byte("just a test"), os.ModePerm) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + // recreate the file and remove the delete permission + err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName), []byte("just another test"), os.ModePerm) + assert.NoError(t, err) + + u.Permissions = map[string][]string{ + "/": {dataprovider.PermUpload, dataprovider.PermListItems, dataprovider.PermCreateDirs, + dataprovider.PermDownload, dataprovider.PermOverwrite}, + } + _, resp, err := httpdtest.UpdateUser(u, http.StatusOK, "") + assert.NoError(t, err, string(resp)) + webAPIToken, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Cannot perform copy step") + + u.Permissions = map[string][]string{ + "/": {dataprovider.PermUpload, 
dataprovider.PermListItems, dataprovider.PermCreateDirs, + dataprovider.PermDownload, dataprovider.PermOverwrite, dataprovider.PermCopy}, + } + _, resp, err = httpdtest.UpdateUser(u, http.StatusOK, "") + assert.NoError(t, err, string(resp)) + webAPIToken, err = getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Cannot perform remove step") + + req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+testFileName+"&target="+url.QueryEscape(path.Join("/", "folderPath", testFileName)), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + taskResp = make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &taskResp) + assert.NoError(t, err) + taskID = taskResp["message"].(string) + assert.NotEmpty(t, taskID) + + assert.Eventually(t, func() bool { + status := getStatusResponse(taskID) + return status == http.StatusForbidden + }, 1000*time.Millisecond, 100*time.Millisecond) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) +} + +func TestWebDirsAPI(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + testDir := "testdir" + 
+ req, err := http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var contents []map[string]any + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 0) + + // rename a missing folder + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target="+testDir+"new", nil) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // copy a missing folder + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/copy?path="+testDir+"%2F&target="+testDir+"new%2F", nil) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // delete a missing folder + req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path="+testDir, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // create a dir + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+testDir, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + // check the dir was created + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = nil + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + if assert.Len(t, contents, 1) { + assert.Equal(t, testDir, contents[0]["name"]) + } + // rename a dir with the same source and target name + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target="+testDir, nil) + 
assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "operation unsupported") + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target=%2F"+testDir+"%2F", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "operation unsupported") + // copy a dir with the same source and target name + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/copy?path="+testDir+"&target="+testDir, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "operation unsupported") + // create a dir with missing parents + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+url.QueryEscape(path.Join("/sub/dir", testDir)), nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // setting the mkdir_parents param will work + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?mkdir_parents=true&path="+url.QueryEscape(path.Join("/sub/dir", testDir)), nil) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + // copy the dir + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/copy?path="+testDir+"&target="+testDir+"copy", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // rename the dir + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target="+testDir+"new", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + 
rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // delete the dir + req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path="+testDir+"new", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path="+testDir+"copy", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // the root dir cannot be created + req, err = http.NewRequest(http.MethodPost, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + user.Permissions["/"] = []string{dataprovider.PermListItems} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + // the user has no more the permission to create the directory + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+testDir, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + // the user is deleted, any API call should fail + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+testDir, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path="+testDir+"&target="+testDir+"new", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, 
userFileActionsPath+"/copy?path="+testDir+"&target="+testDir+"new", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path="+testDir+"new", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestWebUploadSingleFile(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + content := []byte("test content") + + req, err := http.NewRequest(http.MethodPost, userUploadFilePath, bytes.NewBuffer(content)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "please set a file path") + + modTime := time.Now().Add(-24 * time.Hour) + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.txt", bytes.NewBuffer(content)) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + req.Header.Set("X-SFTPGO-MTIME", strconv.FormatInt(util.GetTimeAsMsSinceEpoch(modTime), 10)) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + info, err := os.Stat(filepath.Join(user.GetHomeDir(), "file.txt")) + if assert.NoError(t, err) { + assert.InDelta(t, util.GetTimeAsMsSinceEpoch(modTime), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(1000)) + } + // invalid modification time will be ignored + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.txt", bytes.NewBuffer(content)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + req.Header.Set("X-SFTPGO-MTIME", "123abc") + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + 
info, err = os.Stat(filepath.Join(user.GetHomeDir(), "file.txt")) + if assert.NoError(t, err) { + assert.InDelta(t, util.GetTimeAsMsSinceEpoch(time.Now()), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(3000)) + } + // upload to a missing dir will fail without the mkdir_parents param + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path="+url.QueryEscape("/subdir/file.txt"), bytes.NewBuffer(content)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?mkdir_parents=true&path="+url.QueryEscape("/subdir/file.txt"), bytes.NewBuffer(content)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + metadataReq := make(map[string]int64) + metadataReq["modification_time"] = util.GetTimeAsMsSinceEpoch(modTime) + asJSON, err := json.Marshal(metadataReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file.txt", bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + info, err = os.Stat(filepath.Join(user.GetHomeDir(), "file.txt")) + if assert.NoError(t, err) { + assert.InDelta(t, util.GetTimeAsMsSinceEpoch(modTime), util.GetTimeAsMsSinceEpoch(info.ModTime()), float64(1000)) + } + // missing file + req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file2.txt", bytes.NewBuffer(asJSON)) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Unable to set metadata for path") + // invalid JSON + req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file.txt", bytes.NewBuffer(content)) + 
assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + // missing mandatory parameter + req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "please set a modification_time and a path") + + metadataReq = make(map[string]int64) + asJSON, err = json.Marshal(metadataReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file.txt", bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "please set a modification_time and a path") + + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=%2Fdir%2Ffile.txt", bytes.NewBuffer(content)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Unable to write file") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.txt", bytes.NewBuffer(content)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Unable to retrieve your user") + + metadataReq["modification_time"] = util.GetTimeAsMsSinceEpoch(modTime) + asJSON, err = json.Marshal(metadataReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPatch, userFilesDirsMetadataPath+"?path=file.txt", bytes.NewBuffer(asJSON)) + assert.NoError(t, 
err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Unable to retrieve your user") +} + +func TestWebFilesAPI(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part1, err := writer.CreateFormFile("filenames", "file1.txt") + assert.NoError(t, err) + _, err = part1.Write([]byte("file1 content")) + assert.NoError(t, err) + part2, err := writer.CreateFormFile("filenames", "file2.txt") + assert.NoError(t, err) + _, err = part2.Write([]byte("file2 content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Unable to parse multipart form") + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), user.FirstUpload) + assert.Equal(t, int64(0), user.FirstDownload) + // set the proper content type + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.FirstUpload, int64(0)) + assert.Equal(t, int64(0), user.FirstDownload) + // check we have 2 files + req, err 
= http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var contents []map[string]any + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 2) + // download a file + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=file1.txt", nil) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "file1 content", rr.Body.String()) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.FirstUpload, int64(0)) + assert.Greater(t, user.FirstDownload, int64(0)) + // overwrite the existing files + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = nil + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 2) + // now create a dir and upload to that dir + testDir := "tdir" + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path="+testDir, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path="+testDir, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", 
writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + // upload to a missing subdir will fail without the mkdir_parents param + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path="+url.QueryEscape("/sub/"+testDir), reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?mkdir_parents=true&path="+url.QueryEscape("/sub/"+testDir), reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = nil + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 4) + req, err = http.NewRequest(http.MethodGet, userDirsPath+"?path="+testDir, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = nil + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 2) + // copy a file + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/copy?path=file1.txt&target=%2Ftdir%2Ffile_copy.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // rename a file + req, err = http.NewRequest(http.MethodPost, 
userFileActionsPath+"/move?target=%2Ftdir%2Ffile3.txt&path=file1.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // rename a missing file + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=file1.txt&target=%2Ftdir%2Ffile3.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // rename a file with target name equal to source name + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=file1.txt&target=file1.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "operation unsupported") + // delete a file + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file2.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // delete a missing file + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file2.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // delete a directory + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=tdir", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + // make a symlink outside the home dir and then try to delete it + extPath := filepath.Join(os.TempDir(), "file") + err = os.WriteFile(extPath, []byte("contents"), os.ModePerm) + assert.NoError(t, err) + err = os.Symlink(extPath, filepath.Join(user.GetHomeDir(), "file")) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file", nil) + assert.NoError(t, err) + 
setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + err = os.Remove(extPath) + assert.NoError(t, err) + // remove delete and overwrite permissions + user.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=tdir", reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=%2Ftdir%2Ffile1.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + // the user is deleted, any API call should fail + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=file1.txt&target=%2Ftdir%2Ffile3.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file2.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, 
http.StatusNotFound, rr) +} + +func TestBufferedWebFilesAPI(t *testing.T) { + u := getTestUser() + u.FsConfig.OSConfig = sdk.OSFsConfig{ + ReadBufferSize: 1, + WriteBufferSize: 1, + } + vdirPath := "/crypted" + mappedPath := filepath.Join(os.TempDir(), util.GenerateUniqueID()) + folderName := filepath.Base(mappedPath) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + OSFsConfig: sdk.OSFsConfig{ + WriteBufferSize: 3, + ReadBufferSize: 2, + }, + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part1, err := writer.CreateFormFile("filenames", "file1.txt") + assert.NoError(t, err) + _, err = part1.Write([]byte("file1 content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path="+url.QueryEscape(vdirPath), reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, 
webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=file1.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "file1 content", rr.Body.String()) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+url.QueryEscape(vdirPath+"/file1.txt"), nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, "file1 content", rr.Body.String()) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=file1.txt", nil) + assert.NoError(t, err) + req.Header.Set("Range", "bytes=2-") + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusPartialContent, rr) + assert.Equal(t, "le1 content", rr.Body.String()) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+url.QueryEscape(vdirPath+"/file1.txt"), nil) + assert.NoError(t, err) + req.Header.Set("Range", "bytes=3-6") + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusPartialContent, rr) + assert.Equal(t, "e1 c", rr.Body.String()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestWebClientTasksAPI(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + u1 := getTestUser() + u1.Username = xid.New().String() + user1, _, err := httpdtest.AddUser(u1, http.StatusCreated) + assert.NoError(t, err) + + testDir := "subdir" + testFileData := []byte("data") + testFilePath := 
filepath.Join(user.GetHomeDir(), testDir, "file.txt") + testFileName := filepath.Base(testFilePath) + err = os.MkdirAll(filepath.Dir(testFilePath), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testFilePath, testFileData, 0666) + assert.NoError(t, err) + + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) + webToken1, err := getJWTWebClientTokenFromTestServer(user1.Username, defaultPassword) + assert.NoError(t, err) + + getStatusResponse := func(taskID string) int { + req, _ := http.NewRequest(http.MethodGet, webClientTasksPath+"/"+url.PathEscape(taskID), nil) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + if rr.Code != http.StatusOK { + return -1 + } + resp := make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &resp) + if err != nil { + return -1 + } + return int(resp["status"].(float64)) + } + // missing task + assert.Equal(t, -1, getStatusResponse("missing")) + + req, err := http.NewRequest(http.MethodPost, webClientFileCopyPath+"?path="+ + url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + resp := make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &resp) + assert.NoError(t, err) + taskID := resp["message"].(string) + assert.NotEmpty(t, taskID) + + assert.Eventually(t, func() bool { + status := getStatusResponse(taskID) + return status == http.StatusOK + }, 1000*time.Millisecond, 100*time.Millisecond) + + // cannot get the task with a different user + req, err = http.NewRequest(http.MethodGet, 
webClientTasksPath+"/"+url.PathEscape(taskID), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken1) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+ + url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + resp = make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &resp) + assert.NoError(t, err) + taskID = resp["message"].(string) + assert.NotEmpty(t, taskID) + + assert.Eventually(t, func() bool { + status := getStatusResponse(taskID) + return status == http.StatusOK + }, 1000*time.Millisecond, 100*time.Millisecond) + + req, err = http.NewRequest(http.MethodDelete, webClientDirsPath+"?path="+ + url.QueryEscape(testDir), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + resp = make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &resp) + assert.NoError(t, err) + taskID = resp["message"].(string) + assert.NotEmpty(t, taskID) + + assert.Eventually(t, func() bool { + status := getStatusResponse(taskID) + return status == http.StatusOK + }, 1000*time.Millisecond, 100*time.Millisecond) + + req, err = http.NewRequest(http.MethodDelete, webClientDirsPath+"?path="+ + url.QueryEscape(testDir), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + resp = 
make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &resp) + assert.NoError(t, err) + taskID = resp["message"].(string) + assert.NotEmpty(t, taskID) + + assert.Eventually(t, func() bool { + status := getStatusResponse(taskID) + return status == http.StatusNotFound + }, 1000*time.Millisecond, 100*time.Millisecond) + + req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+ + url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + resp = make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &resp) + assert.NoError(t, err) + taskID = resp["message"].(string) + assert.NotEmpty(t, taskID) + + assert.Eventually(t, func() bool { + status := getStatusResponse(taskID) + return status == http.StatusNotFound + }, 1000*time.Millisecond, 100*time.Millisecond) + + req, err = http.NewRequest(http.MethodPost, webClientFileCopyPath+"?path="+ + url.QueryEscape(path.Join(testDir, testFileName)+"/")+"&target="+url.QueryEscape(testFileName+"/"), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusAccepted, rr) + resp = make(map[string]any) + err = json.Unmarshal(rr.Body.Bytes(), &resp) + assert.NoError(t, err) + taskID = resp["message"].(string) + assert.NotEmpty(t, taskID) + + assert.Eventually(t, func() bool { + status := getStatusResponse(taskID) + return status == http.StatusNotFound + }, 1000*time.Millisecond, 100*time.Millisecond) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + 
assert.NoError(t, err) + err = os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) + // user deleted + req, err = http.NewRequest(http.MethodDelete, webClientDirsPath+"?path="+ + url.QueryEscape(testDir), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, webClientFileMovePath+"?path="+ + url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, webClientFileCopyPath+"?path="+ + url.QueryEscape(path.Join(testDir, testFileName))+"&target="+url.QueryEscape(testFileName), nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("X-CSRF-TOKEN", csrfToken) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestStartDirectory(t *testing.T) { + u := getTestUser() + u.Filters.StartDirectory = "/start/dir" + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + filename := "file1.txt" + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part1, err := writer.CreateFormFile("filenames", filename) + assert.NoError(t, err) + _, err = part1.Write([]byte("test content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + req, err 
:= http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + // check we have 2 files in the defined start dir + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var contents []map[string]any + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + if assert.Len(t, contents, 1) { + assert.Equal(t, filename, contents[0]["name"].(string)) + } + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file2.txt", + bytes.NewBuffer([]byte("single upload content"))) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path=testdir", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=testdir&target=testdir1", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path=%2Ftestdirroot", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, err = http.NewRequest(http.MethodGet, userDirsPath+"?path="+url.QueryEscape(u.Filters.StartDirectory), nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = nil + err = 
json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 3) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+filename, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=%2F"+filename, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPatch, userFilesPath+"?path="+filename+"&target="+filename+"_rename", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path=testdir1", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = nil + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 2) + + req, err = http.NewRequest(http.MethodGet, webClientDirsPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = nil + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 2) + + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path="+filename+"_rename", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, userDirsPath+"?path="+url.QueryEscape(u.Filters.StartDirectory), nil) + 
assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + contents = nil + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 1) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebFilesTransferQuotaLimits(t *testing.T) { + u := getTestUser() + u.UploadDataTransfer = 1 + u.DownloadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + testFileName := "file.data" + testFileSize := 550000 + testFileContents := make([]byte, testFileSize) + n, err := io.ReadFull(rand.Reader, testFileContents) + assert.NoError(t, err) + assert.Equal(t, testFileSize, n) + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", testFileName) + assert.NoError(t, err) + _, err = part.Write(testFileContents) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, testFileContents, rr.Body.Bytes()) + // error while download is active + downloadFunc := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + + req, err = http.NewRequest(http.MethodGet, 
userFilesPath+"?path="+testFileName, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + } + downloadFunc() + // error before starting the download + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path="+testFileName, nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + // error while upload is active + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestEntityTooLarge, rr) + // error before starting the upload + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestEntityTooLarge, rr) + // now test upload/download to/from shares + share1 := dataprovider.Share{ + Name: "share1", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + } + asJSON, err := json.Marshal(share1) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + req, err = http.NewRequest(http.MethodGet, sharesPath+"/"+objectID, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + form := make(url.Values) + form.Set("files", `[]`) + req, err = 
http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, objectID, "/partial"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorQuotaRead) + + share2 := dataprovider.Share{ + Name: "share2", + Scope: dataprovider.ShareScopeWrite, + Paths: []string{"/"}, + } + asJSON, err = json.Marshal(share2) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userSharesPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + objectID = rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, sharesPath+"/"+objectID, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestEntityTooLarge, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebUploadErrors(t *testing.T) { + u := getTestUser() + u.QuotaSize = 65535 + subDir1 := "sub1" + subDir2 := "sub2" + u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems} + u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload, + dataprovider.PermDelete} + u.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/sub2", + AllowedPatterns: []string{}, + DeniedPatterns: []string{"*.zip"}, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + 
part, err := writer.CreateFormFile("filenames", "file.zip") + assert.NoError(t, err) + _, err = part.Write([]byte("file content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + // zip file are not allowed within sub2 + req, err := http.NewRequest(http.MethodPost, userFilesPath+"?path=sub2", reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + // we have no upload permissions within sub1 + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=sub1", reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + // we cannot create dirs in sub2 + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?mkdir_parents=true&path="+url.QueryEscape("/sub2/dir"), reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "unable to check/create missing parent dir") + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?mkdir_parents=true&path="+url.QueryEscape("/sub2/dir/test"), nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Error checking parent directories") + req, err = http.NewRequest(http.MethodPost, 
userUploadFilePath+"?mkdir_parents=true&path="+url.QueryEscape("/sub2/dir1/file.txt"), bytes.NewBuffer([]byte(""))) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Error checking parent directories") + // create a dir and try to overwrite it with a file + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path=file.zip", nil) //nolint:goconst + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "operation unsupported") + // try to upload to a missing parent directory + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=missingdir", reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path=file.zip", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // upload will work now + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + // 
overwrite the file + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + vfs.SetTempPath(filepath.Join(os.TempDir(), "missingpath")) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + if runtime.GOOS != osWindows { + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file.zip", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + vfs.SetTempPath(filepath.Clean(os.TempDir())) + err = os.Chmod(user.GetHomeDir(), 0555) + assert.NoError(t, err) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Error closing file") + + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.zip", bytes.NewBuffer(nil)) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Error closing file") + + err = os.Chmod(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + } + + vfs.SetTempPath("") + + // upload a multipart form with no files + body = new(bytes.Buffer) + 
writer = multipart.NewWriter(body) + err = writer.Close() + assert.NoError(t, err) + reader = bytes.NewReader(body.Bytes()) + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=sub2", reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "No files uploaded!") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebAPIVFolder(t *testing.T) { + u := getTestUser() + u.QuotaSize = 65535 + vdir := "/vdir" + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdir, + QuotaSize: -1, + QuotaFiles: -1, + }) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + webAPIToken, err := getJWTAPIUserTokenFromTestServer(user.Username, defaultPassword) + assert.NoError(t, err) + + fileContents := []byte("test contents") + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", "file.txt") + assert.NoError(t, err) + _, err = part.Write(fileContents) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + + req, err := http.NewRequest(http.MethodPost, userFilesPath+"?path=vdir", reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr 
:= executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(len(fileContents)), user.UsedQuotaSize) + + folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath+"?path=vdir", reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(len(fileContents)), user.UsedQuotaSize) + + folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestWebAPIWritePermission(t *testing.T) { + u := getTestUser() + u.Filters.WebClient = append(u.Filters.WebClient, sdk.WebClientWriteDisabled) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", "file.txt") + assert.NoError(t, err) + _, err = part.Write([]byte("")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := 
bytes.NewReader(body.Bytes()) + + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=a&target=b", nil) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=a", nil) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodGet, userFilesPath+"?path=a.txt", nil) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodPost, userDirsPath+"?path=dir", nil) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodPost, userFileActionsPath+"/move?path=dir&target=dir1", nil) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, 
rr) + + req, err = http.NewRequest(http.MethodDelete, userDirsPath+"?path=dir", nil) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebAPICryptFs(t *testing.T) { + u := getTestUser() + u.QuotaSize = 65535 + u.FsConfig.Provider = sdk.CryptedFilesystemProvider + u.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(defaultPassword) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", "file.txt") + assert.NoError(t, err) + _, err = part.Write([]byte("content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebUploadSFTP(t *testing.T) { + u := getTestUser() + localUser, _, err := 
httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaFiles = 100 + u.FsConfig.SFTPConfig.BufferSize = 2 + u.HomeDir = filepath.Join(os.TempDir(), u.Username) + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + webAPIToken, err := getJWTAPIUserTokenFromTestServer(sftpUser.Username, defaultPassword) + assert.NoError(t, err) + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", "file.txt") + assert.NoError(t, err) + _, err = part.Write([]byte("test file content")) + assert.NoError(t, err) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + expectedQuotaSize := int64(17) + expectedQuotaFiles := 1 + user, _, err := httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + + user.QuotaSize = 10 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + // we are now overquota on overwrite + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestEntityTooLarge, rr) + assert.Contains(t, rr.Body.String(), "denying write due to space limit") + assert.Contains(t, rr.Body.String(), "Unable to write file") + + // delete the file + req, err = http.NewRequest(http.MethodDelete, 
userFilesPath+"?path=file.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + user, _, err = httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + + req, err = http.NewRequest(http.MethodPost, userUploadFilePath+"?path=file.txt", + bytes.NewBuffer([]byte("test upload single file content"))) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestEntityTooLarge, rr) + assert.Contains(t, rr.Body.String(), "denying write due to space limit") + assert.Contains(t, rr.Body.String(), "Error saving file") + + // delete the file + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + _, err = reader.Seek(0, io.SeekStart) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusRequestEntityTooLarge, rr) + assert.Contains(t, rr.Body.String(), "denying write due to space limit") + assert.Contains(t, rr.Body.String(), "Error saving file") + + // delete the file + req, err = http.NewRequest(http.MethodDelete, userFilesPath+"?path=file.txt", nil) + assert.NoError(t, err) + setBearerForReq(req, webAPIToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + user, _, err = httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), 
user.UsedQuotaSize) + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(sftpUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebAPISFTPPasswordProtectedPrivateKey(t *testing.T) { + u := getTestUser() + u.Password = "" + u.PublicKeys = []string{testPubKeyPwd} + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.FsConfig.SFTPConfig.Password = kms.NewEmptySecret() + u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(testPrivateKeyPwd) + u.FsConfig.SFTPConfig.KeyPassphrase = kms.NewPlainSecret(privateKeyPwd) + u.HomeDir = filepath.Join(os.TempDir(), u.Username) + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + webToken, err := getJWTWebClientTokenFromTestServer(sftpUser.Username, defaultPassword) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, webClientFilesPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // update the user, the key must be preserved + assert.Equal(t, sdkkms.SecretStatusSecretBox, sftpUser.FsConfig.SFTPConfig.KeyPassphrase.GetStatus()) + _, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // using a wrong passphrase or no passphrase should fail + sftpUser.FsConfig.SFTPConfig.KeyPassphrase = kms.NewPlainSecret("wrong") + _, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "") + assert.NoError(t, err) + _, err = getJWTWebClientTokenFromTestServer(sftpUser.Username, defaultPassword) 
+ assert.Error(t, err) + sftpUser.FsConfig.SFTPConfig.KeyPassphrase = kms.NewEmptySecret() + _, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "") + assert.NoError(t, err) + _, err = getJWTWebClientTokenFromTestServer(sftpUser.Username, defaultPassword) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(sftpUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebUploadMultipartFormReadError(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, userFilesPath, nil) + assert.NoError(t, err) + + mpartForm := &multipart.Form{ + File: make(map[string][]*multipart.FileHeader), + } + mpartForm.File["filenames"] = append(mpartForm.File["filenames"], &multipart.FileHeader{Filename: "missing"}) + req.MultipartForm = mpartForm + req.Header.Add("Content-Type", "multipart/form-data") + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + assert.Contains(t, rr.Body.String(), "Unable to read uploaded file") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestCompressionErrorMock(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + _, err := httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + }() + + webToken, err := 
getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, webToken) + assert.NoError(t, err) + + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("files", `["missing"]`) + req, _ := http.NewRequest(http.MethodPost, webClientDownloadZipPath+"?path=%2F", + bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + executeRequest(req) +} + +func TestGetFilesSFTPBackend(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + u := getTestSFTPUser() + u.HomeDir = filepath.Clean(os.TempDir()) + u.FsConfig.SFTPConfig.BufferSize = 2 + u.Permissions["/adir"] = nil + u.Permissions["/adir1"] = []string{dataprovider.PermListItems} + u.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/adir2", + DeniedPatterns: []string{"*.txt"}, + }, + } + sftpUserBuffered, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u.Username += "_unbuffered" + u.FsConfig.SFTPConfig.BufferSize = 0 + sftpUserUnbuffered, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + testFileName := "testsftpfile" + testDir := "testsftpdir" + testFileContents := []byte("sftp file contents") + err = os.MkdirAll(filepath.Join(user.GetHomeDir(), testDir, "sub"), os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(filepath.Join(user.GetHomeDir(), "adir1"), os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(filepath.Join(user.GetHomeDir(), "adir2"), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(user.GetHomeDir(), testFileName), testFileContents, os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(user.GetHomeDir(), "adir1", "afile"), testFileContents, os.ModePerm) + assert.NoError(t, 
err) + err = os.WriteFile(filepath.Join(user.GetHomeDir(), "adir2", "afile.txt"), testFileContents, os.ModePerm) + assert.NoError(t, err) + for _, sftpUser := range []dataprovider.User{sftpUserBuffered, sftpUserUnbuffered} { + webToken, err := getJWTWebClientTokenFromTestServer(sftpUser.Username, defaultPassword) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+path.Join(testDir, "sub"), nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + htmlErrFrag := `div id="errorMsg" class="rounded border-warning border border-dashed bg-light-warning` + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+path.Join(testDir, "missing"), nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), htmlErrFrag) + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=adir/sub", nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), htmlErrFrag) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=adir1/afile", nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), htmlErrFrag) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path=adir2/afile.txt", nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), htmlErrFrag) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, testFileContents, rr.Body.Bytes()) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) + req.Header.Set("Range", "bytes=2-") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusPartialContent, rr) + assert.Equal(t, testFileContents[2:], rr.Body.Bytes()) + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestClientUserClose(t *testing.T) { + u := getTestUser() + u.UploadBandwidth = 32 + u.DownloadBandwidth = 32 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFileName := "file.dat" + testFileSize := int64(524288) + testFilePath := filepath.Join(user.GetHomeDir(), testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + uploadContent := make([]byte, testFileSize) + _, err = rand.Read(uploadContent) + assert.NoError(t, err) + webToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + webAPIToken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + req, _ := http.NewRequest(http.MethodGet, webClientFilesPath+"?path="+testFileName, nil) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + }() + wg.Add(1) + go func() { + defer wg.Done() + req, _ := http.NewRequest(http.MethodGet, webClientEditFilePath+"?path="+testFileName, nil) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + 
assert.Contains(t, rr.Body.String(), util.I18nError500Message) + }() + wg.Add(1) + go func() { + defer wg.Done() + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("filenames", "upload.dat") + assert.NoError(t, err) + n, err := part.Write(uploadContent) + assert.NoError(t, err) + assert.Equal(t, testFileSize, int64(n)) + err = writer.Close() + assert.NoError(t, err) + reader := bytes.NewReader(body.Bytes()) + req, err := http.NewRequest(http.MethodPost, userFilesPath, reader) + assert.NoError(t, err) + req.Header.Add("Content-Type", writer.FormDataContentType()) + setBearerForReq(req, webAPIToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "transfer aborted") + }() + // wait for the transfers + assert.Eventually(t, func() bool { + stats := common.Connections.GetStats("") + if len(stats) == 3 { + if len(stats[0].Transfers) > 0 && len(stats[1].Transfers) > 0 { + return true + } + } + return false + }, 1*time.Second, 50*time.Millisecond) + + for _, stat := range common.Connections.GetStats("") { + // close all the active transfers + common.Connections.Close(stat.ConnectionID, "") + } + wg.Wait() + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWebAdminSetupMock(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, webAdminSetupPath, nil) + assert.NoError(t, err) + rr := executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) + // now delete all the admins + admins, err := dataprovider.GetAdmins(100, 0, dataprovider.OrderASC) + assert.NoError(t, err) + for _, admin := 
range admins { + err = dataprovider.DeleteAdmin(admin.Username, "", "", "") + assert.NoError(t, err) + } + // close the provider and initializes it without creating the default admin + os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "0") + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + // now the setup page must be rendered + req, err = http.NewRequest(http.MethodGet, webAdminSetupPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // check redirects to the setup page + req, err = http.NewRequest(http.MethodGet, "/", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) + req, err = http.NewRequest(http.MethodGet, webBasePath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) + req, err = http.NewRequest(http.MethodGet, webBasePathAdmin, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) + req, err = http.NewRequest(http.MethodGet, webLoginPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) + req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) + + loginCookie, csrfToken, err := getCSRFTokenMock(webAdminSetupPath, defaultRemoteAddr) + assert.NoError(t, err) + form := 
make(url.Values) + req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("username", defaultTokenAuthUser) + req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("password", defaultTokenAuthPass) + req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdNoMatch) + form.Set("confirm_password", defaultTokenAuthPass) + // test a parse form error + req, err = http.NewRequest(http.MethodPost, webAdminSetupPath+"?param=p%C3%AO%GH", bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // test a dataprovider error + err = dataprovider.Close() + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // finally initialize the provider and create the default admin + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + providerConf.BackupsPath = backupsPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webAdminMFAPath, rr.Header().Get("Location")) + // if we resubmit the form we get a bad request, an admin already exists + req, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nError400Message) + os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") +} + +func TestAllowList(t *testing.T) { + configCopy := common.Config + + entries := []dataprovider.IPListEntry{ + { + IPOrNet: "172.120.1.1/32", + Type: dataprovider.IPListTypeAllowList, + Mode: 
dataprovider.ListModeAllow, + Protocols: 0, + }, + { + IPOrNet: "172.120.1.2/32", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 0, + }, + { + IPOrNet: "192.8.7.0/22", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 8, + }, + } + + for _, e := range entries { + _, _, err := httpdtest.AddIPListEntry(e, http.StatusCreated) + assert.NoError(t, err) + } + + common.Config.MaxTotalConnections = 1 + common.Config.AllowListStatus = 1 + err := common.Initialize(common.Config, 0) + assert.NoError(t, err) + + req, _ := http.NewRequest(http.MethodGet, webLoginPath, nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error()) + + req.RemoteAddr = "172.120.1.1" + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + testIP := "172.120.1.3" + req.RemoteAddr = testIP + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error()) + + req.RemoteAddr = "192.8.7.1" + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + entry := dataprovider.IPListEntry{ + IPOrNet: "172.120.1.3/32", + Type: dataprovider.IPListTypeAllowList, + Mode: dataprovider.ListModeAllow, + Protocols: 8, + } + err = dataprovider.AddIPListEntry(&entry, "", "", "") + assert.NoError(t, err) + + req.RemoteAddr = testIP + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + err = dataprovider.DeleteIPListEntry(entry.IPOrNet, entry.Type, "", "", "") + assert.NoError(t, err) + + req.RemoteAddr = testIP + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), common.ErrConnectionDenied.Error()) + + common.Config = configCopy + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) + + for _, e := range entries { + _, err := 
httpdtest.RemoveIPListEntry(e, http.StatusOK) + assert.NoError(t, err) + } +} + +func TestWebAdminLoginMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + + req, _ := http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, webStatusPath+"notfound", nil) + req.RequestURI = webStatusPath + "notfound" + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, webLogoutPath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + cookie := rr.Header().Get("Cookie") + assert.Empty(t, cookie) + + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + + req, _ = http.NewRequest(http.MethodGet, logoutPath, nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodGet, serverStatusPath, nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "Your token is no longer valid") + + req, _ = http.NewRequest(http.MethodGet, webStatusPath, nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, 
http.StatusFound, rr) + + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + // now try using wrong password + form := getLoginForm(defaultTokenAuthUser, "wrong pwd", csrfToken) + req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + // wrong username + form = getLoginForm("wrong username", defaultTokenAuthPass, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + // try from an ip not allowed + a := getTestAdmin() + a.Username = altAdminUsername + a.Password = altAdminPassword + a.Filters.AllowList = []string{"10.0.0.0/8"} + + _, _, err = httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + + rAddr := "127.1.1.1:1234" + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie) + form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.RemoteAddr = rAddr + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + rAddr = "10.9.9.9:1234" + loginCookie, csrfToken, err = 
getCSRFTokenMock(webLoginPath, rAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie) + form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.RemoteAddr = rAddr + setLoginCookie(req, loginCookie) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + + rAddr = "127.0.1.1:4567" + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, rAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie) + form = getLoginForm(altAdminUsername, altAdminPassword, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.RemoteAddr = rAddr + req.Header.Set("X-Forwarded-For", "10.9.9.9") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + // invalid csrf token + form = getLoginForm(altAdminUsername, altAdminPassword, "invalid csrf") + req, _ = http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.RemoteAddr = "10.9.9.8:1234" + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + req, _ = http.NewRequest(http.MethodGet, webLoginPath, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + _, err = httpdtest.RemoveAdmin(a, http.StatusOK) + assert.NoError(t, err) +} + +func TestAdminNoToken(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, webAdminProfilePath, nil) + rr := executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, 
rr.Header().Get("Location")) + + req, _ = http.NewRequest(http.MethodGet, webUserPath, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) + + req, _ = http.NewRequest(http.MethodGet, userPath, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + req, _ = http.NewRequest(http.MethodGet, activeConnectionsPath, nil) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) +} + +func TestWebUserShare(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + token, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) + userAPItoken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + + share := dataprovider.Share{ + Name: "test share", + Description: "test share desc", + Scope: dataprovider.ShareScopeRead, + Paths: []string{"/"}, + ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour)), + MaxTokens: 100, + AllowFrom: []string{"127.0.0.0/8", "172.16.0.0/16"}, + Password: defaultPassword, + } + form := make(url.Values) + form.Set("name", share.Name) + form.Set("scope", strconv.Itoa(int(share.Scope))) + form.Set("paths[0][path]", "/") + form.Set("max_tokens", strconv.Itoa(share.MaxTokens)) + form.Set("allowed_ip", strings.Join(share.AllowFrom, ",")) + form.Set("description", share.Description) + form.Set("password", share.Password) + form.Set("expiration_date", "123") + // invalid expiration date + req, err := http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, 
token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareExpiration) + form.Set("expiration_date", util.GetTimeFromMsecSinceEpoch(share.ExpiresAt).UTC().Format("2006-01-02 15:04:05")) + form.Set("scope", "") + // invalid scope + req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareScope) + form.Set("scope", strconv.Itoa(int(share.Scope))) + // invalid max tokens + form.Set("max_tokens", "t") + req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareMaxTokens) + form.Set("max_tokens", strconv.Itoa(share.MaxTokens)) + // no csrf token + req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + form.Set("scope", "100") + req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareScope) + + form.Set("scope", strconv.Itoa(int(share.Scope))) + req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + req, err = http.NewRequest(http.MethodGet, userSharesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, userAPItoken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var shares []dataprovider.Share + err = json.Unmarshal(rr.Body.Bytes(), &shares) + assert.NoError(t, err) + if assert.Len(t, shares, 1) { + s := shares[0] + assert.Equal(t, share.Name, s.Name) + assert.Equal(t, share.Description, s.Description) + assert.Equal(t, share.Scope, s.Scope) + assert.Equal(t, share.Paths, s.Paths) + assert.InDelta(t, share.ExpiresAt, s.ExpiresAt, 999) + assert.Equal(t, share.MaxTokens, s.MaxTokens) + assert.Equal(t, share.AllowFrom, s.AllowFrom) + assert.Equal(t, redactedSecret, s.Password) + share.ShareID = s.ShareID + } + form.Set("password", redactedSecret) + form.Set("expiration_date", "123") + req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/unknowid", bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorShareExpiration) + + form.Set("expiration_date", "") + form.Set(csrfFormToken, "") + req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + form.Set("allowed_ip", "1.1.1") + req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidIPMask) + + form.Set("allowed_ip", "") + req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + req, err = http.NewRequest(http.MethodGet, userSharesPath, nil) + assert.NoError(t, err) + setBearerForReq(req, userAPItoken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + shares = nil + err = json.Unmarshal(rr.Body.Bytes(), &shares) + assert.NoError(t, err) + if assert.Len(t, shares, 1) { + s := shares[0] + assert.Equal(t, share.Name, 
s.Name) + assert.Equal(t, share.Description, s.Description) + assert.Equal(t, share.Scope, s.Scope) + assert.Equal(t, share.Paths, s.Paths) + assert.Equal(t, int64(0), s.ExpiresAt) + assert.Equal(t, share.MaxTokens, s.MaxTokens) + assert.Empty(t, s.AllowFrom) + } + // check the password + s, err := dataprovider.ShareExists(share.ShareID, user.Username) + assert.NoError(t, err) + match, err := s.CheckCredentials(defaultPassword) + assert.NoError(t, err) + assert.True(t, match) + + req, err = http.NewRequest(http.MethodGet, webClientSharePath+"?path=%2F&files=a", nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nError400Message) + + req, err = http.NewRequest(http.MethodGet, webClientSharePath+"?path=%2F&files=%5B\"adir\"%5D", nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webClientSharePath+"/unknown", nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, webClientSharePath+"/"+share.ShareID, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webClientSharesPath, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webClientSharesPath+jsonAPISuffix, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, 
token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestWebUserShareNoPasswordDisabled verifies that a user with the
// WebClientShareNoPasswordDisabled restriction cannot create or update a
// share without a password, and that the share page renders with both a
// default and a maximum share expiration configured.
func TestWebUserShareNoPasswordDisabled(t *testing.T) {
	u := getTestUser()
	u.Filters.WebClient = []string{sdk.WebClientShareNoPasswordDisabled}
	u.Filters.DefaultSharesExpiration = 15
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	user.Filters.DefaultSharesExpiration = 30
	// update the user returned by AddUser: passing the stale "u" here would
	// silently discard the DefaultSharesExpiration change made above
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	token, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientSharePath, token)
	assert.NoError(t, err)
	userAPItoken, err := getJWTAPIUserTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)

	share := dataprovider.Share{
		Name:  "s",
		Scope: dataprovider.ShareScopeRead,
		Paths: []string{"/"},
	}
	form := make(url.Values)
	form.Set("name", share.Name)
	form.Set("scope", strconv.Itoa(int(share.Scope)))
	form.Set("paths[0][path]", "/")
	form.Set("max_tokens", "0")
	form.Set(csrfFormToken, csrfToken)
	// creating a share without a password must be rejected for this user
	req, err := http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorShareNoPwd)

	// with a password the share creation must succeed
	form.Set("password", defaultPassword)
	req, err = http.NewRequest(http.MethodPost, webClientSharePath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	req, err = http.NewRequest(http.MethodGet, webClientSharePath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	req, err = http.NewRequest(http.MethodGet, userSharesPath, nil)
	assert.NoError(t, err)
	setBearerForReq(req, userAPItoken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var shares []dataprovider.Share
	err = json.Unmarshal(rr.Body.Bytes(), &shares)
	assert.NoError(t, err)
	if assert.Len(t, shares, 1) {
		s := shares[0]
		assert.Equal(t, share.Name, s.Name)
		assert.Equal(t, share.Scope, s.Scope)
		assert.Equal(t, share.Paths, s.Paths)
		share.ShareID = s.ShareID
	}
	assert.NotEmpty(t, share.ShareID)
	// removing the password on update must be rejected too
	form.Set("password", "")
	req, err = http.NewRequest(http.MethodPost, webClientSharePath+"/"+share.ShareID, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorShareNoPwd)

	// the share page must still render with no default and a max expiration
	user.Filters.DefaultSharesExpiration = 0
	user.Filters.MaxSharesExpiration = 30
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webClientSharePath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

func TestInvalidCSRF(t *testing.T) {
	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	for _, loginURL := range []string{webClientLoginPath, webLoginPath}
{ + // try using an invalid CSRF token + loginCookie1, csrfToken1, err := getCSRFTokenMock(loginURL, defaultRemoteAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie1) + assert.NotEmpty(t, csrfToken1) + loginCookie2, csrfToken2, err := getCSRFTokenMock(loginURL, defaultRemoteAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie2) + assert.NotEmpty(t, csrfToken2) + rAddr := "1.2.3.4" + loginCookie3, csrfToken3, err := getCSRFTokenMock(loginURL, rAddr) + assert.NoError(t, err) + assert.NotEmpty(t, loginCookie3) + assert.NotEmpty(t, csrfToken3) + + form := getLoginForm(defaultUsername, defaultPassword, csrfToken1) + req, err := http.NewRequest(http.MethodPost, loginURL, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = loginURL + setLoginCookie(req, loginCookie2) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + // use a CSRF token as login cookie (invalid audience) + form = getLoginForm(defaultUsername, defaultPassword, csrfToken1) + req, err = http.NewRequest(http.MethodPost, loginURL, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = loginURL + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", csrfToken1)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + // invalid IP + form = getLoginForm(defaultUsername, defaultPassword, csrfToken3) + req, err = http.NewRequest(http.MethodPost, loginURL, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.RequestURI = loginURL + setLoginCookie(req, loginCookie3) + req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + } + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestWebUserProfile(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + token, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) + + email := "user@user.com" + description := "User" + + form := make(url.Values) + form.Set("allow_api_key_auth", "1") + form.Set("email", email) + form.Set("description", description) + form.Set("public_keys[0][public_key]", testPubKey) + form.Set("public_keys[1][public_key]", testPubKey1) + form.Set("tls_certs[0][tls_cert]", httpsCert) + form.Set("additional_emails[0][additional_email]", "email1@user.com") + // no csrf token + req, err := http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated) + + user, _, err = 
httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.AllowAPIKeyAuth) + assert.Len(t, user.PublicKeys, 2) + assert.Len(t, user.Filters.TLSCerts, 1) + assert.Equal(t, email, user.Email) + assert.Equal(t, description, user.Description) + if assert.Len(t, user.Filters.AdditionalEmails, 1) { + assert.Equal(t, "email1@user.com", user.Filters.AdditionalEmails[0]) + } + + // set an invalid email + form.Set("email", "not an email") + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidEmail) + // invalid tls cert + form.Set("email", email) + form.Set("tls_certs[0][tls_cert]", "not a TLS cert") + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidTLSCert) + // invalid public key + form.Set("tls_certs[0][tls_cert]", httpsCert) + form.Set("public_keys[0][public_key]", "invalid") + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorPubKeyInvalid) + // now remove permissions + form.Set("public_keys[0][public_key]", testPubKey) + form.Del("public_keys[1][public_key]") + 
user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) + + form.Set("allow_api_key_auth", "0") + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.AllowAPIKeyAuth) + assert.Len(t, user.PublicKeys, 1) + assert.Len(t, user.Filters.TLSCerts, 1) + assert.Equal(t, email, user.Email) + assert.Equal(t, description, user.Description) + + user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, + sdk.WebClientPubKeyChangeDisabled, sdk.WebClientTLSCertChangeDisabled} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) + form.Set("public_keys[0][public_key]", testPubKey) + form.Set("public_keys[1][public_key]", testPubKey1) + form.Set("tls_certs[0][tls_cert]", "") + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + 
setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.AllowAPIKeyAuth) + assert.Len(t, user.PublicKeys, 1) + assert.Len(t, user.Filters.TLSCerts, 1) + assert.Equal(t, email, user.Email) + assert.Equal(t, description, user.Description) + + user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webClientProfilePath, token) + assert.NoError(t, err) + form.Set("email", "newemail@user.com") + form.Set("description", "new description") + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.AllowAPIKeyAuth) + assert.Len(t, user.PublicKeys, 2) + assert.Len(t, user.Filters.TLSCerts, 0) + assert.Equal(t, email, user.Email) + assert.Equal(t, description, user.Description) + // finally disable all profile permissions + user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled, sdk.WebClientInfoChangeDisabled, + sdk.WebClientPubKeyChangeDisabled, sdk.WebClientTLSCertChangeDisabled} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + 
assert.NoError(t, err) + token, err = getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword) + assert.NoError(t, err) + csrfToken, err = getCSRFTokenFromInternalPageMock(webChangeClientPwdPath, token) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) +} + +func TestWebAdminProfile(t *testing.T) { + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Password = altAdminPassword + admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTWebTokenFromTestServer(admin.Username, altAdminPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminProfilePath, token) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, webAdminProfilePath, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + form := make(url.Values) + form.Set("allow_api_key_auth", "1") + form.Set("email", "admin@example.com") + form.Set("description", "admin desc") + // no csrf token + req, 
err = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated) + + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, admin.Filters.AllowAPIKeyAuth) + assert.Equal(t, "admin@example.com", admin.Email) + assert.Equal(t, "admin desc", admin.Description) + + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nProfileUpdated) + + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + assert.False(t, admin.Filters.AllowAPIKeyAuth) + assert.Empty(t, admin.Email) + assert.Empty(t, admin.Description) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + 
req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) +} + +func TestWebAdminPwdChange(t *testing.T) { + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Password = altAdminPassword + admin.Filters.Preferences.HideUserPageSections = 16 + 32 + admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) + assert.NoError(t, err) + + token, err := getJWTWebTokenFromTestServer(admin.Username, altAdminPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webChangeAdminPwdPath, token) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, webChangeAdminPwdPath, nil) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setJWTCookieForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + form := make(url.Values) + form.Set("current_password", altAdminPassword) + form.Set("new_password1", altAdminPassword) + form.Set("new_password2", altAdminPassword) + // no csrf token + req, _ = http.NewRequest(http.MethodPost, webChangeAdminPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + req, _ = http.NewRequest(http.MethodPost, webChangeAdminPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdNoDifferent) + + 
form.Set("new_password1", altAdminPassword+"1") + form.Set("new_password2", altAdminPassword+"1") + req, _ = http.NewRequest(http.MethodPost, webChangeAdminPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) +} + +func TestAPIKeysManagement(t *testing.T) { + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Password = altAdminPassword + admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated) + assert.NoError(t, err) + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiKey := dataprovider.APIKey{ + Name: "test key", + Scope: dataprovider.APIKeyScopeAdmin, + } + asJSON, err := json.Marshal(apiKey) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + location := rr.Header().Get("Location") + assert.NotEmpty(t, location) + objectID := rr.Header().Get("X-Object-ID") + assert.NotEmpty(t, objectID) + assert.Equal(t, fmt.Sprintf("%v/%v", apiKeysPath, objectID), location) + apiKey.KeyID = objectID + response := make(map[string]string) + err = json.Unmarshal(rr.Body.Bytes(), &response) + assert.NoError(t, err) + key := response["key"] + assert.NotEmpty(t, key) + assert.True(t, strings.HasPrefix(key, apiKey.KeyID+".")) + + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var keyGet dataprovider.APIKey + err = 
json.Unmarshal(rr.Body.Bytes(), &keyGet) + assert.NoError(t, err) + assert.Empty(t, keyGet.Key) + assert.Equal(t, apiKey.KeyID, keyGet.KeyID) + assert.Equal(t, apiKey.Scope, keyGet.Scope) + assert.Equal(t, apiKey.Name, keyGet.Name) + assert.Equal(t, int64(0), keyGet.ExpiresAt) + assert.Equal(t, int64(0), keyGet.LastUseAt) + assert.Greater(t, keyGet.CreatedAt, int64(0)) + assert.Greater(t, keyGet.UpdatedAt, int64(0)) + assert.Empty(t, keyGet.Description) + assert.Empty(t, keyGet.User) + assert.Empty(t, keyGet.Admin) + + // API key is not enabled for the admin user so this request should fail + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, key, admin.Username) + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "the admin associated with the provided api key cannot be authenticated") + + admin.Filters.AllowAPIKeyAuth = true + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, key, admin.Username) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, key, admin.Username+"1") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + // now associate the key directly to the admin + apiKey.Admin = admin.Username + apiKey.Description = "test description" + asJSON, err = json.Marshal(apiKey) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = 
executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, apiKeysPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var keys []dataprovider.APIKey + err = json.Unmarshal(rr.Body.Bytes(), &keys) + assert.NoError(t, err) + if assert.GreaterOrEqual(t, len(keys), 1) { + found := false + for _, k := range keys { + if k.KeyID == apiKey.KeyID { + found = true + assert.Empty(t, k.Key) + assert.Equal(t, apiKey.Scope, k.Scope) + assert.Equal(t, apiKey.Name, k.Name) + assert.Equal(t, int64(0), k.ExpiresAt) + assert.Greater(t, k.LastUseAt, int64(0)) + assert.Equal(t, k.CreatedAt, keyGet.CreatedAt) + assert.Greater(t, k.UpdatedAt, keyGet.UpdatedAt) + assert.Equal(t, apiKey.Description, k.Description) + assert.Empty(t, k.User) + assert.Equal(t, admin.Username, k.Admin) + } + } + assert.True(t, found) + } + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // invalid API keys + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, key+"invalid", "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + assert.Contains(t, rr.Body.String(), "the provided api key cannot be authenticated") + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, "invalid", "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + // using an API key we cannot modify/get API keys + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setAPIKeyForReq(req, key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, 
err) + setAPIKeyForReq(req, key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + + admin.Filters.AllowList = []string{"172.16.18.0/24"} + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + req, err = http.NewRequest(http.MethodDelete, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "the provided api key is not valid") + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) +} + +func TestAPIKeySearch(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiKey := dataprovider.APIKey{ + Scope: dataprovider.APIKeyScopeAdmin, + } + for i := 1; i < 5; i++ { + apiKey.Name = fmt.Sprintf("testapikey%v", i) + asJSON, err := json.Marshal(apiKey) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + } + + req, err := http.NewRequest(http.MethodGet, apiKeysPath+"?limit=1&order=ASC", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var keys []dataprovider.APIKey + err = json.Unmarshal(rr.Body.Bytes(), &keys) + assert.NoError(t, err) + assert.Len(t, keys, 1) + firstKey := keys[0] + + req, err = 
http.NewRequest(http.MethodGet, apiKeysPath+"?limit=1&order=DESC", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + keys = nil + err = json.Unmarshal(rr.Body.Bytes(), &keys) + assert.NoError(t, err) + if assert.Len(t, keys, 1) { + assert.NotEqual(t, firstKey.KeyID, keys[0].KeyID) + } + + req, err = http.NewRequest(http.MethodGet, apiKeysPath+"?limit=1&offset=100", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + keys = nil + err = json.Unmarshal(rr.Body.Bytes(), &keys) + assert.NoError(t, err) + assert.Len(t, keys, 0) + + req, err = http.NewRequest(http.MethodGet, apiKeysPath+"?limit=f", nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("%v/%v", apiKeysPath, "missingid"), nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, apiKeysPath, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + keys = nil + err = json.Unmarshal(rr.Body.Bytes(), &keys) + assert.NoError(t, err) + counter := 0 + for _, k := range keys { + if strings.HasPrefix(k.Name, "testapikey") { + req, err = http.NewRequest(http.MethodDelete, fmt.Sprintf("%v/%v", apiKeysPath, k.KeyID), nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + counter++ + } + } + assert.Equal(t, 4, counter) +} + +func TestAPIKeyErrors(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiKey := dataprovider.APIKey{ + Name: "testkey", + Scope: 
dataprovider.APIKeyScopeUser, + } + asJSON, err := json.Marshal(apiKey) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + location := rr.Header().Get("Location") + assert.NotEmpty(t, location) + + // invalid API scope + apiKey.Scope = 1000 + asJSON, err = json.Marshal(apiKey) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + // invalid JSON + req, err = http.NewRequest(http.MethodPost, apiKeysPath, bytes.NewBuffer([]byte(`invalid JSON`))) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer([]byte(`invalid JSON`))) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodDelete, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodDelete, location, nil) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodPut, location, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setBearerForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, 
rr) +} + +func TestAPIKeyOnDeleteCascade(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + admin := getTestAdmin() + admin.Username = altAdminUsername + admin.Password = altAdminPassword + admin, _, err = httpdtest.AddAdmin(admin, http.StatusCreated) + assert.NoError(t, err) + + apiKey := dataprovider.APIKey{ + Name: "user api key", + Scope: dataprovider.APIKeyScopeUser, + User: user.Username, + } + + apiKey, _, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr := executeRequest(req) + checkResponseCode(t, http.StatusUnauthorized, rr) + + user.Filters.AllowAPIKeyAuth = true + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodGet, userDirsPath, nil) + assert.NoError(t, err) + setAPIKeyForReq(req, apiKey.Key, "") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var contents []map[string]any + err = json.NewDecoder(rr.Body).Decode(&contents) + assert.NoError(t, err) + assert.Len(t, contents, 0) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + _, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusNotFound) + assert.NoError(t, err) + + apiKey.User = "" + apiKey.Admin = admin.Username + apiKey.Scope = dataprovider.APIKeyScopeAdmin + + apiKey, _, err = httpdtest.AddAPIKey(apiKey, http.StatusCreated) + assert.NoError(t, err) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + + _, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusNotFound) + assert.NoError(t, err) +} + +func TestBasicWebUsersMock(t *testing.T) { + token, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + 
assert.NoError(t, err)
	// create two users through the REST API
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	user1 := getTestUser()
	user1.Username += "1"
	user1AsJSON := getUserAsJSON(t, user1)
	req, _ = http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(user1AsJSON))
	setBearerForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user1)
	assert.NoError(t, err)
	// browse the web pages with a web (cookie) token
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ = http.NewRequest(http.MethodGet, webUsersPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webUsersPath+jsonAPISuffix, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webUserPath, nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, path.Join(webUserPath, user.Username), nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// unknown username: the detail page returns 404
	req, _ = http.NewRequest(http.MethodGet, webUserPath+"/0", nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// save the user through the multipart web form (CSRF token required)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("username", user.Username)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost,
webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// updating a non-existent user must return 404
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath+"/0", &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	req, _ = http.NewRequest(http.MethodPost, webUserPath+"/aaa", &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)
	// delete without the CSRF header is rejected
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webUserPath, user.Username), nil)
	setJWTCookieForReq(req, webToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Invalid token")
	// delete both users with a valid CSRF header
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webUserPath, user.Username), nil)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webUserPath, user1.Username), nil)
	setJWTCookieForReq(req, webToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestRenderDefenderPageMock checks that the defender web page renders for an
// authenticated admin.
func TestRenderDefenderPageMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webDefenderPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr :=
executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nDefenderTitle)
}

// TestWebAdminBasicMock exercises the WebAdmin admin-management pages:
// CSRF enforcement, form validation errors, TOTP config save, listing,
// update and delete of a secondary admin.
func TestWebAdminBasicMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminPath, token)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("username", admin.Username)
	form.Set("password", "")
	form.Set("status", "1")
	form.Set("permissions", "*")
	form.Set("description", admin.Description)
	// hide every user-page section (bitmask asserted below as 1+2+...+64)
	form.Set("user_page_hidden_sections", "1")
	form.Add("user_page_hidden_sections", "2")
	form.Add("user_page_hidden_sections", "3")
	form.Add("user_page_hidden_sections", "4")
	form.Add("user_page_hidden_sections", "5")
	form.Add("user_page_hidden_sections", "6")
	form.Add("user_page_hidden_sections", "7")
	form.Set("default_users_expiration", "10")
	// no CSRF form token yet: save must fail
	req, _ := http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)

	// invalid status value: the form is re-rendered (200) instead of saved
	form.Set(csrfFormToken, csrfToken)
	form.Set("status", "a")
	req, _ = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// invalid default_users_expiration
	form.Set("status", "1")
	form.Set("default_users_expiration", "a")
	req, _ = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	// valid form: admin is created and the browser is redirected
	form.Set("default_users_expiration", "10")
	form.Set("password", admin.Password)
	req, _ = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	// add TOTP config
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], altAdminUsername)
	assert.NoError(t, err)
	altToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	adminTOTPConfig := dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
	}
	asJSON, err := json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	// no CSRF token
	req, err = http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Invalid token")

	req, err = http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, altToken)
	setCSRFHeaderForReq(req, csrfToken) // invalid CSRF token
	req.RemoteAddr = defaultRemoteAddr
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "the token is not valid")

	// a CSRF token obtained with altToken is valid for the alt admin
	csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminPath, altToken)
	assert.NoError(t, err)
	req, err =
http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, altToken)
	setCSRFHeaderForReq(req, csrfToken)
	req.RemoteAddr = defaultRemoteAddr
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// the saved TOTP config and the form preferences must be persisted
	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	secretPayload := admin.Filters.TOTPConfig.Secret.GetPayload()
	assert.NotEmpty(t, secretPayload)
	assert.Equal(t, 1+2+4+8+16+32+64, admin.Filters.Preferences.HideUserPageSections)
	assert.Equal(t, 10, admin.Filters.Preferences.DefaultUsersExpiration)

	// saving with an empty secret must preserve the existing one
	adminTOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewEmptySecret(),
	}
	asJSON, err = json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, altToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	assert.Equal(t, secretPayload, admin.Filters.TOTPConfig.Secret.GetPayload())

	// a nil secret is accepted as well
	adminTOTPConfig = dataprovider.AdminTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     nil,
	}
	asJSON, err = json.Marshal(adminTOTPConfig)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webAdminTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, altToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// admin list pages render for the default admin
	req, _ = http.NewRequest(http.MethodGet, webAdminsPath+jsonAPISuffix, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodGet, webAdminsPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	req, _ = http.NewRequest(http.MethodGet, webAdminPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// update with an empty password keeps the stored one; redirect on success
	form.Set("password", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	form.Set(csrfFormToken, csrfToken) // associated to altToken
	req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)

	// re-fetch a CSRF token bound to the default admin's session
	csrfToken, err = getCSRFTokenFromInternalPageMock(webAdminPath, token)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	form.Set("email", "not-an-email")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// invalid status on update: form re-rendered
	form.Set("email", "")
	form.Set("status", "b")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// valid update: email set, admin disabled
	form.Set("email", "admin@example.com")
	form.Set("status", "0")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	admin, _, err = httpdtest.GetAdminByUsername(altAdminUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, admin.Filters.TOTPConfig.Enabled)
	assert.Equal(t, "admin@example.com", admin.Email)
	assert.Equal(t, 0, admin.Status)

	// update/detail of an unknown admin returns 404
	req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, altAdminUsername+"1"), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)

	req, _ = http.NewRequest(http.MethodGet, path.Join(webAdminPath, altAdminUsername), nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	req, _ = http.NewRequest(http.MethodGet, path.Join(webAdminPath, altAdminUsername+"1"), nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)

	// the alt admin token still works while the admin exists
	req, _ = http.NewRequest(http.MethodGet, webUserPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	req, _ =
http.NewRequest(http.MethodDelete, path.Join(webAdminPath, altAdminUsername), nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// after deletion the alt admin token must no longer resolve
	req, _ = http.NewRequest(http.MethodGet, webUserPath, nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, altToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusNotFound)
	assert.NoError(t, err)

	// an admin cannot delete itself
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webAdminPath, defaultTokenAuthUser), nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), "you cannot delete yourself")

	// delete without CSRF header is rejected
	req, _ = http.NewRequest(http.MethodDelete, path.Join(webAdminPath, defaultTokenAuthUser), nil)
	req.RemoteAddr = defaultRemoteAddr
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Invalid token")
}

// TestWebAdminGroupsMock verifies that groups submitted through the WebAdmin
// admin form are stored with the expected add-to-users mode
// (primary/secondary/membership).
func TestWebAdminGroupsMock(t *testing.T) {
	group1 := getTestGroup()
	group1.Name += "_1"
	group1, _, err := httpdtest.AddGroup(group1, http.StatusCreated)
	assert.NoError(t, err)
	group2 := getTestGroup()
	group2.Name += "_2"
	group2, _, err = httpdtest.AddGroup(group2, http.StatusCreated)
	assert.NoError(t, err)
	group3 := getTestGroup()
	group3.Name += "_3"
	group3, _, err = httpdtest.AddGroup(group3, http.StatusCreated)
	assert.NoError(t, err)
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminPath, token)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", admin.Username)
	form.Set("password", "")
	form.Set("status", "1")
	form.Set("permissions", "*")
	form.Set("description", admin.Description)
	form.Set("password", admin.Password)
	// group_type: 1 = primary, 2 = secondary, unset = membership
	form.Set("groups[0][group]", group1.Name)
	form.Set("groups[0][group_type]", "1")
	form.Set("groups[1][group]", group2.Name)
	form.Set("groups[1][group_type]", "2")
	form.Set("groups[2][group]", group3.Name)
	req, err := http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode())))
	assert.NoError(t, err)
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	if assert.Len(t, admin.Groups, 3) {
		for _, g := range admin.Groups {
			switch g.Name {
			case group1.Name:
				assert.Equal(t, dataprovider.GroupAddToUsersAsPrimary, g.Options.AddToUsersAs)
			case group2.Name:
				assert.Equal(t, dataprovider.GroupAddToUsersAsSecondary, g.Options.AddToUsersAs)
			case group3.Name:
				assert.Equal(t, dataprovider.GroupAddToUsersAsMembership, g.Options.AddToUsersAs)
			default:
				t.Errorf("unexpected group %q", g.Name)
			}
		}
	}
	adminToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webUserPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, adminToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// token of a removed admin is rejected with a 500 on page render
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodGet, webUserPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, adminToken)
	rr =
executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)

	// cleanup
	_, err = httpdtest.RemoveGroup(group1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group3, http.StatusOK)
	assert.NoError(t, err)
}

// TestWebAdminPermissions checks that an admin limited to the "add users"
// permission can open the user-add page but is denied on every other
// WebAdmin page (real HTTP requests against the running server).
func TestWebAdminPermissions(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin.Permissions = []string{dataprovider.PermAdminAddUsers}
	_, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)

	token, err := getJWTWebToken(altAdminUsername, altAdminPassword)
	require.NoError(t, err)

	// allowed: the add-user page
	req, err := http.NewRequest(http.MethodGet, httpBaseURL+webUserPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	resp, err := httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusOK, resp.StatusCode)

	// editing an existing user requires a different permission
	req, err = http.NewRequest(http.MethodGet, httpBaseURL+path.Join(webUserPath, "auser"), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusForbidden, resp.StatusCode)

	req, err = http.NewRequest(http.MethodGet, httpBaseURL+webFolderPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusForbidden, resp.StatusCode)

	req, err = http.NewRequest(http.MethodGet, httpBaseURL+webStatusPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusForbidden, resp.StatusCode)

	req, err = http.NewRequest(http.MethodGet, httpBaseURL+webConnectionsPath, nil)
assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusForbidden, resp.StatusCode)

	req, err = http.NewRequest(http.MethodGet, httpBaseURL+webAdminPath, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusForbidden, resp.StatusCode)

	req, err = http.NewRequest(http.MethodGet, httpBaseURL+path.Join(webAdminPath, "a"), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	resp, err = httpclient.GetHTTPClient().Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	assert.Equal(t, http.StatusForbidden, resp.StatusCode)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestAdminUpdateSelfMock verifies the self-update guardrails on the WebAdmin
// admin form: an admin cannot drop its own permissions, disable itself or
// assign itself a role.
func TestAdminUpdateSelfMock(t *testing.T) {
	admin, _, err := httpdtest.GetAdminByUsername(defaultTokenAuthUser, http.StatusOK)
	assert.NoError(t, err)
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminPath, token)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("username", admin.Username)
	form.Set("password", admin.Password)
	form.Set("status", "0")
	form.Set("permissions", dataprovider.PermAdminAddUsers)
	// NOTE(review): this second Set overwrites the previous one, so only
	// "close connections" is submitted — presumably Add was intended if both
	// permissions should be sent; the asserted error is the same either way.
	form.Set("permissions", dataprovider.PermAdminCloseConnections)
	form.Set(csrfFormToken, csrfToken)
	req, _ := http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorAdminSelfPerms)
form.Set("permissions", dataprovider.PermAdminAny)
	// restoring full permissions but status=0: self-disable is rejected
	req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorAdminSelfDisable)

	// valid self-update; the two "require_*" flags are ignored for self
	form.Set("status", "1")
	form.Set("require_two_factor", "1")
	form.Set("require_password_change", "1")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	admin, _, err = httpdtest.GetAdminByUsername(defaultTokenAuthUser, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, admin.Filters.RequirePasswordChange)
	assert.False(t, admin.Filters.RequireTwoFactor)

	// an admin cannot assign a role to itself
	form.Set("role", "my role")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), bytes.NewBuffer([]byte(form.Encode())))
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorAdminSelfRole)
}

// TestWebMaintenanceMock exercises the WebAdmin maintenance (backup/restore)
// page: CSRF enforcement, invalid form values, invalid backup files and a
// successful restore of users, admins and API keys.
func TestWebMaintenanceMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, _ := http.NewRequest(http.MethodGet, webMaintenancePath, nil)
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webMaintenancePath, token)
assert.NoError(t, err)
	form := make(url.Values)
	form.Set("mode", "a")
	// missing CSRF token: restore is rejected
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)

	// invalid mode value: form re-rendered with an error (200)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// invalid quota value
	form.Set("mode", "0")
	form.Set("quota", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// no backup file attached
	form.Set("quota", "0")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// invalid URL escape in the query string
	req, _ = http.NewRequest(http.MethodPost, webRestorePath+"?a=%3", &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// empty backup file
	backupFilePath := filepath.Join(os.TempDir(), "backup.json")
	err = createTestFile(backupFilePath, 0)
	assert.NoError(t, err)

	b, contentType, _ = getMultipartFormData(form, "backup_file", backupFilePath)
	req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// non-JSON backup file
	err = createTestFile(backupFilePath, 10)
	assert.NoError(t, err)

	b, contentType, _ = getMultipartFormData(form, "backup_file", backupFilePath)
	req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// build a valid backup with one user, one admin and one API key
	user := getTestUser()
	user.ID = 1
	user.Username = "test_user_web_restore"
	admin := getTestAdmin()
	admin.ID = 1
	admin.Username = "test_admin_web_restore"

	apiKey := dataprovider.APIKey{
		Name:  "key name",
		KeyID: util.GenerateUniqueID(),
		Key:   fmt.Sprintf("%v.%v", util.GenerateUniqueID(), util.GenerateUniqueID()),
		Scope: dataprovider.APIKeyScopeAdmin,
	}
	backupData := dataprovider.BackupData{
		Version: dataprovider.DumpVersion,
	}
	backupData.Users = append(backupData.Users, user)
	backupData.Admins = append(backupData.Admins, admin)
	backupData.APIKeys = append(backupData.APIKeys, apiKey)
	backupContent, err := json.Marshal(backupData)
	assert.NoError(t, err)
	err = os.WriteFile(backupFilePath, backupContent, os.ModePerm)
	assert.NoError(t, err)

	// successful restore
	b, contentType, _ = getMultipartFormData(form, "backup_file", backupFilePath)
	req, _ = http.NewRequest(http.MethodPost, webRestorePath, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nBackupOK)

	// all restored objects must exist; remove them afterwards
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)

	admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)

	_, _, err = httpdtest.GetAPIKeyByID(apiKey.KeyID, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveAPIKey(apiKey,
http.StatusOK) + assert.NoError(t, err) + + err = os.Remove(backupFilePath) + assert.NoError(t, err) +} + +func TestWebUserAddMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) + assert.NoError(t, err) + group1 := getTestGroup() + group1.Name += "_1" + group1, _, err = httpdtest.AddGroup(group1, http.StatusCreated) + assert.NoError(t, err) + group2 := getTestGroup() + group2.Name += "_2" + group2, _, err = httpdtest.AddGroup(group2, http.StatusCreated) + assert.NoError(t, err) + group3 := getTestGroup() + group3.Name += "_3" + group3, _, err = httpdtest.AddGroup(group3, http.StatusCreated) + assert.NoError(t, err) + user := getTestUser() + user.UploadBandwidth = 32 + user.DownloadBandwidth = 64 + user.UploadDataTransfer = 1000 + user.DownloadDataTransfer = 2000 + user.UID = 1000 + user.AdditionalInfo = "info" + user.Description = "user dsc" + user.Email = "test@test.com" + user.Filters.AdditionalEmails = []string{"example1@test.com", "example2@test.com"} + mappedDir := filepath.Join(os.TempDir(), "mapped") + folderName := filepath.Base(mappedDir) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedDir, + } + folderAsJSON, err := json.Marshal(f) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON)) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user.Username) + form.Set("email", user.Email) + form.Set("additional_emails[0][additional_email]", user.Filters.AdditionalEmails[0]) + form.Set("additional_emails[1][additional_email]", user.Filters.AdditionalEmails[1]) + 
	// Populate the remaining add-user form fields: filesystem buffers, groups,
	// permissions, virtual folders, file patterns and access-time restrictions.
	form.Set("home_dir", user.HomeDir)
	form.Set("osfs_read_buffer_size", "2")
	form.Set("osfs_write_buffer_size", "3")
	form.Set("password", user.Password)
	form.Set("primary_group", group1.Name)
	form.Set("secondary_groups", group2.Name)
	form.Set("membership_groups", group3.Name)
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "")
	form.Set("permissions", "*")
	// per-directory permissions for "/subdir": list + download
	form.Set("directory_permissions[0][sub_perm_path]", "/subdir")
	form.Set("directory_permissions[0][sub_perm_permissions][]", "list")
	form.Add("directory_permissions[0][sub_perm_permissions][]", "download")
	// virtual folder mapping; note the leading space in the path, checked below
	// against the cleaned "/vdir" value
	form.Set("virtual_folders[0][vfolder_path]", " /vdir")
	form.Set("virtual_folders[0][vfolder_name]", folderName)
	form.Set("virtual_folders[0][vfolder_quota_files]", "2")
	form.Set("virtual_folders[0][vfolder_quota_size]", "1024")
	// file patterns: two entries for "/dir2" and two for "/dir1" that the
	// handler is expected to merge (asserted later: 3 resulting patterns)
	form.Set("directory_patterns[0][pattern_path]", "/dir2")
	form.Set("directory_patterns[0][patterns]", "*.jpg,*.png")
	form.Set("directory_patterns[0][pattern_type]", "allowed")
	form.Set("directory_patterns[0][pattern_policy]", "1")
	form.Set("directory_patterns[1][pattern_path]", "/dir1")
	form.Set("directory_patterns[1][patterns]", "*.png")
	form.Set("directory_patterns[1][pattern_type]", "allowed")
	form.Set("directory_patterns[2][pattern_path]", "/dir1")
	form.Set("directory_patterns[2][patterns]", "*.zip")
	form.Set("directory_patterns[2][pattern_type]", "denied")
	form.Set("directory_patterns[3][pattern_path]", "/dir3")
	form.Set("directory_patterns[3][patterns]", "*.rar")
	form.Set("directory_patterns[3][pattern_type]", "denied")
	form.Set("directory_patterns[4][pattern_path]", "/dir2")
	form.Set("directory_patterns[4][patterns]", "*.mkv")
	form.Set("directory_patterns[4][pattern_type]", "denied")
	form.Set("access_time_restrictions[0][access_time_day_of_week]", "2")
	form.Set("access_time_restrictions[0][access_time_start]", "10") // invalid and no end, ignored
	form.Set("access_time_restrictions[1][access_time_day_of_week]", "3")
	form.Set("access_time_restrictions[1][access_time_start]", "12:00")
	form.Set("access_time_restrictions[1][access_time_end]", "14:09")
	form.Set("additional_info", user.AdditionalInfo)
	form.Set("description", user.Description)
	form.Add("hooks", "external_auth_disabled")
	form.Set("disable_fs_checks", "checked")
	form.Set("total_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("start_directory", "start/dir")
	form.Set("require_password_change", "1")
	b, contentType, _ := getMultipartFormData(form, "", "")
	// test invalid url escape
	// NOTE(review): on validation failures the handler appears to re-render
	// the form page, hence the expected 200 below rather than a 4xx.
	req, _ = http.NewRequest(http.MethodPost, webUserPath+"?a=%2", &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("public_keys", testPubKey)
	form.Add("public_keys", testPubKey1)
	form.Set("uid", strconv.FormatInt(int64(user.UID), 10))
	form.Set("gid", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid gid
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("gid", "0")
	form.Set("max_sessions", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid max sessions
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("max_sessions", "0")
	form.Set("quota_size", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid quota size
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("quota_size", "0")
	form.Set("quota_files", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid quota files
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("quota_files", "0")
	form.Set("upload_bandwidth", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid upload bandwidth
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("upload_bandwidth", strconv.FormatInt(user.UploadBandwidth, 10))
	form.Set("download_bandwidth", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid download bandwidth
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("download_bandwidth", strconv.FormatInt(user.DownloadBandwidth, 10))
	form.Set("upload_data_transfer", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid upload data transfer
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("upload_data_transfer", strconv.FormatInt(user.UploadDataTransfer, 10))
	form.Set("download_data_transfer", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid download data transfer
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("download_data_transfer", strconv.FormatInt(user.DownloadDataTransfer, 10))
	form.Set("total_data_transfer", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid total data transfer
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("total_data_transfer", strconv.FormatInt(user.TotalDataTransfer, 10))
	form.Set("status", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid status
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "123")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid expiration date
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("expiration_date", "")
	form.Set("allowed_ip", "invalid,ip")
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid allowed_ip
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("allowed_ip", "")
	form.Set("denied_ip", "192.168.1.2") // it should be 192.168.1.2/32
	b, contentType, _ = getMultipartFormData(form, "", "")
	// test invalid denied_ip
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("denied_ip", "")
	// test invalid max file upload size
	form.Set("max_upload_file_size", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// "1 KB" exercises the human-readable size parser; asserted later as 1000 bytes
	form.Set("max_upload_file_size", "1 KB")
	// test invalid default shares expiration
	form.Set("default_shares_expiration", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)
	form.Set("default_shares_expiration", "10")
	// test invalid max shares expiration
	form.Set("max_shares_expiration", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)
	form.Set("max_shares_expiration", "30")
	// test invalid password expiration
	form.Set("password_expiration", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)
	form.Set("password_expiration", "90")
	// test invalid password strength
	form.Set("password_strength", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)
	form.Set("password_strength", "60")
	// test invalid tls username
	form.Set("tls_username", "username")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)
	form.Set("tls_username", string(sdk.TLSUsernameNone))
	// invalid upload bandwidth source
	form.Set("src_bandwidth_limits[0][bandwidth_limit_sources]", "192.168.1.0/24, 192.168.2.0/25")
	form.Set("src_bandwidth_limits[0][upload_bandwidth_source]", "a")
	form.Set("src_bandwidth_limits[0][download_bandwidth_source]", "0")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)
	// invalid download bandwidth source
	form.Set("src_bandwidth_limits[0][upload_bandwidth_source]", "256")
	form.Set("src_bandwidth_limits[0][download_bandwidth_source]", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)
	form.Set("src_bandwidth_limits[0][download_bandwidth_source]", "512")
	// second bandwidth-limit entry: an invalid source ("1.1.1" is not a valid
	// CIDR) must be rejected with a dedicated i18n error
	form.Set("src_bandwidth_limits[1][download_bandwidth_source]", "1024")
	form.Set("src_bandwidth_limits[1][bandwidth_limit_sources]", "1.1.1")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorSourceBWLimitInvalid)
	form.Set("src_bandwidth_limits[1][bandwidth_limit_sources]", "127.0.0.1/32")
	// negative upload limit: asserted later to be stored as 0
	form.Set("src_bandwidth_limits[1][upload_bandwidth_source]", "-1")
	// invalid external auth cache size
	form.Set("external_auth_cache_time", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	form.Set("external_auth_cache_time", "0")
	// a wrong CSRF token must be rejected with 403 before any validation
	form.Set(csrfFormToken, "invalid form token")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)

	// with a valid CSRF token the form is finally accepted (303 See Other)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	dbUser, err := dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, dbUser.Password)
	assert.True(t, dbUser.IsPasswordHashed())
	// the user already exists, was created with the above request
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// fetch the stored user via the REST API and verify every submitted field
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	newUser := dataprovider.User{}
	err = render.DecodeJSON(rr.Body, &newUser)
	assert.NoError(t, err)
	assert.Equal(t, user.UID, newUser.UID)
	assert.Equal(t, 2, newUser.FsConfig.OSConfig.ReadBufferSize)
	assert.Equal(t, 3, newUser.FsConfig.OSConfig.WriteBufferSize)
	assert.Equal(t, user.UploadBandwidth, newUser.UploadBandwidth)
	assert.Equal(t, user.DownloadBandwidth, newUser.DownloadBandwidth)
	assert.Equal(t, user.UploadDataTransfer, newUser.UploadDataTransfer)
	assert.Equal(t, user.DownloadDataTransfer, newUser.DownloadDataTransfer)
	assert.Equal(t, user.TotalDataTransfer, newUser.TotalDataTransfer)
	// "1 KB" in the form is parsed to 1000 bytes
	assert.Equal(t, int64(1000), newUser.Filters.MaxUploadFileSize)
	assert.Equal(t, user.AdditionalInfo, newUser.AdditionalInfo)
	assert.Equal(t, user.Description, newUser.Description)
	assert.True(t, newUser.Filters.Hooks.ExternalAuthDisabled)
	assert.False(t, newUser.Filters.Hooks.PreLoginDisabled)
	assert.False(t, newUser.Filters.Hooks.CheckPasswordDisabled)
	assert.True(t, newUser.Filters.DisableFsChecks)
	assert.False(t, newUser.Filters.AllowAPIKeyAuth)
	assert.Equal(t, user.Email, newUser.Email)
	assert.Equal(t, len(user.Filters.AdditionalEmails), len(newUser.Filters.AdditionalEmails))
	// the relative "start/dir" form value is normalized to an absolute path
	assert.Equal(t, "/start/dir", newUser.Filters.StartDirectory)
	assert.Equal(t, 0, newUser.Filters.FTPSecurity)
	assert.Equal(t, 10, newUser.Filters.DefaultSharesExpiration)
	assert.Equal(t, 30, newUser.Filters.MaxSharesExpiration)
	assert.Equal(t, 90, newUser.Filters.PasswordExpiration)
	assert.Equal(t, 60, newUser.Filters.PasswordStrength)
	assert.Greater(t, newUser.LastPasswordChange, int64(0))
	assert.True(t, newUser.Filters.RequirePasswordChange)
	assert.True(t, slices.Contains(newUser.PublicKeys, testPubKey))
	if val, ok := newUser.Permissions["/subdir"]; ok {
		assert.True(t, slices.Contains(val, dataprovider.PermListItems))
		assert.True(t, slices.Contains(val, dataprovider.PermDownload))
	} else {
		assert.Fail(t, "user permissions must contain /somedir", "actual: %v", newUser.Permissions)
	}
	assert.Len(t, newUser.PublicKeys, 2)
	assert.Len(t, newUser.VirtualFolders, 1)
	for _, v := range newUser.VirtualFolders {
		// the submitted " /vdir" (leading space) must have been cleaned
		assert.Equal(t, v.VirtualPath, "/vdir")
		assert.Equal(t, v.Name, folderName)
		assert.Equal(t, v.MappedPath, mappedDir)
		assert.Equal(t, v.QuotaFiles, 2)
		assert.Equal(t, v.QuotaSize, int64(1024))
	}
	// 5 submitted pattern rows must be merged into 3 per-path filters
	assert.Len(t, newUser.Filters.FilePatterns, 3)
	for _, filter := range newUser.Filters.FilePatterns {
		switch filter.Path {
		case "/dir1":
			assert.Len(t, filter.DeniedPatterns, 1)
			assert.Len(t, filter.AllowedPatterns, 1)
			assert.True(t, slices.Contains(filter.AllowedPatterns, "*.png"))
			assert.True(t, slices.Contains(filter.DeniedPatterns, "*.zip"))
			assert.Equal(t, sdk.DenyPolicyDefault, filter.DenyPolicy)
		case "/dir2":
			assert.Len(t, filter.DeniedPatterns, 1)
			assert.Len(t, filter.AllowedPatterns, 2)
			assert.True(t, slices.Contains(filter.AllowedPatterns, "*.jpg"))
			assert.True(t, slices.Contains(filter.AllowedPatterns, "*.png"))
			assert.True(t, slices.Contains(filter.DeniedPatterns, "*.mkv"))
			assert.Equal(t, sdk.DenyPolicyHide, filter.DenyPolicy)
		case "/dir3":
			assert.Len(t, filter.DeniedPatterns, 1)
			assert.Len(t, filter.AllowedPatterns, 0)
			assert.True(t, slices.Contains(filter.DeniedPatterns, "*.rar"))
			assert.Equal(t, sdk.DenyPolicyDefault, filter.DenyPolicy)
		}
	}
	if assert.Len(t, newUser.Filters.BandwidthLimits, 2) {
		for _, bwLimit := range newUser.Filters.BandwidthLimits {
			if len(bwLimit.Sources) == 2 {
				assert.Equal(t, "192.168.1.0/24", bwLimit.Sources[0])
				assert.Equal(t, "192.168.2.0/25", bwLimit.Sources[1])
				assert.Equal(t, int64(256), bwLimit.UploadBandwidth)
				assert.Equal(t, int64(512), bwLimit.DownloadBandwidth)
			} else {
				assert.Equal(t, []string{"127.0.0.1/32"}, bwLimit.Sources)
				// the "-1" submitted above is normalized to 0
				assert.Equal(t, int64(0), bwLimit.UploadBandwidth)
				assert.Equal(t, int64(1024), bwLimit.DownloadBandwidth)
			}
		}
	}
	// only the second access-time row was valid; the first had no end time
	if assert.Len(t, newUser.Filters.AccessTime, 1) {
		assert.Equal(t, 3, newUser.Filters.AccessTime[0].DayOfWeek)
		assert.Equal(t, "12:00", newUser.Filters.AccessTime[0].From)
		assert.Equal(t, "14:09", newUser.Filters.AccessTime[0].To)
	}
	assert.Len(t, newUser.Groups, 3)
	assert.Equal(t, sdk.TLSUsernameNone, newUser.Filters.TLSUsername)
	// cleanup: user, folder and the three groups
	req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, newUser.Username), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	_, err = httpdtest.RemoveGroup(group1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group3, http.StatusOK)
	assert.NoError(t, err)
}

// TestWebUserUpdateMock creates a user via the REST API, enables a TOTP
// configuration for it, then updates it through the web admin form and
// verifies CSRF enforcement, password handling (empty, plain, redacted)
// and that every updated field is persisted as submitted.
func TestWebUserUpdateMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	user := getTestUser()
	user.Filters.BandwidthLimits = []sdk.BandwidthLimit{
		{
			Sources:           []string{"10.8.0.0/16", "192.168.1.0/25"},
			UploadBandwidth:   256,
			DownloadBandwidth: 512,
		},
	}
	user.TotalDataTransfer = 4000
	userAsJSON := getUserAsJSON(t, user)
	// create the user via the REST API
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	lastPwdChange := user.LastPasswordChange
	assert.Greater(t, lastPwdChange, int64(0))
	// add TOTP config
	configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username)
	assert.NoError(t, err)
	userToken, err := getJWTWebClientTokenFromTestServer(defaultUsername, defaultPassword)
	assert.NoError(t, err)
	userTOTPConfig := dataprovider.UserTOTPConfig{
		Enabled:    true,
		ConfigName: configName,
		Secret:     kms.NewPlainSecret(key.Secret()),
		Protocols:  []string{common.ProtocolSSH, common.ProtocolFTP},
	}
	asJSON, err := json.Marshal(userTOTPConfig)
	assert.NoError(t, err)
	// saving the TOTP config without a CSRF header must fail
	req, err = http.NewRequest(http.MethodPost, webClientTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, userToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), "Invalid token")

	csrfToken, err := getCSRFTokenFromInternalPageMock(webClientProfilePath, userToken)
	assert.NoError(t, err)
	req, err = http.NewRequest(http.MethodPost, webClientTOTPSavePath, bytes.NewBuffer(asJSON))
	assert.NoError(t, err)
	setJWTCookieForReq(req, userToken)
	setCSRFHeaderForReq(req, csrfToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, user.Filters.TOTPConfig.Enabled)
	assert.Equal(t, int64(4000), user.TotalDataTransfer)
	assert.Equal(t, lastPwdChange, user.LastPasswordChange)
	if assert.Len(t, user.Filters.BandwidthLimits, 1) {
		if assert.Len(t, user.Filters.BandwidthLimits[0].Sources, 2) {
			assert.Equal(t, "10.8.0.0/16", user.Filters.BandwidthLimits[0].Sources[0])
			assert.Equal(t, "192.168.1.0/25", user.Filters.BandwidthLimits[0].Sources[1])
		}
		assert.Equal(t, int64(256), user.Filters.BandwidthLimits[0].UploadBandwidth)
		assert.Equal(t, int64(512), user.Filters.BandwidthLimits[0].DownloadBandwidth)
	}

	dbUser, err := dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, dbUser.Password)
	assert.True(t, dbUser.IsPasswordHashed())
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	// mutate the in-memory user with the values we are about to submit,
	// so it can be compared against the stored one after the update
	user.MaxSessions = 1
	user.QuotaFiles = 2
	user.QuotaSize = 1000 * 1000 * 1000
	user.GID = 1000
	user.Filters.AllowAPIKeyAuth = true
	user.AdditionalInfo = "new additional info"
	user.Email = "user@example.com"
	// build the web update form
	form := make(url.Values)
	form.Set("username", user.Username)
	form.Set("email", user.Email)
	form.Set("password", "")
	form.Set("public_keys[0][public_key]", testPubKey)
	form.Set("tls_certs[0][tls_cert]", httpsCert)
	form.Set("home_dir", user.HomeDir)
	form.Set("uid", "0")
	form.Set("gid", strconv.FormatInt(int64(user.GID), 10))
	form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10))
	// human-readable size, parsed server side (matches QuotaSize above)
	form.Set("quota_size", "1 GB")
	form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("permissions", "*")
	form.Set("directory_permissions[0][sub_perm_path]", "/otherdir")
	form.Set("directory_permissions[0][sub_perm_permissions][]", "list")
	form.Add("directory_permissions[0][sub_perm_permissions][]", "upload")
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "2020-01-01 00:00:00")
	// surrounding spaces must be trimmed by the handler (asserted below)
	form.Set("allowed_ip", " 192.168.1.3/32, 192.168.2.0/24 ")
	form.Set("denied_ip", " 10.0.0.2/32 ")
	form.Set("directory_patterns[0][pattern_path]", "/dir1")
	form.Set("directory_patterns[0][patterns]", "*.zip")
	form.Set("directory_patterns[0][pattern_type]", "denied")
	form.Set("denied_login_methods", dataprovider.SSHLoginMethodKeyboardInteractive)
	form.Set("denied_protocols", common.ProtocolFTP)
	form.Set("max_upload_file_size", "100")
	form.Set("default_shares_expiration", "30")
	form.Set("max_shares_expiration", "60")
	form.Set("password_expiration", "60")
	form.Set("password_strength", "40")
	form.Set("disconnect", "1")
	form.Set("additional_info", user.AdditionalInfo)
	form.Set("description", user.Description)
	form.Set("tls_username", string(sdk.TLSUsernameCN))
	form.Set("allow_api_key_auth", "1")
	form.Set("require_password_change", "1")
	form.Set("external_auth_cache_time", "120")
	// without the CSRF form token the update must be rejected
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)

	csrfToken, err = getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	// an empty password in the update form clears the stored password
	dbUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.Empty(t, dbUser.Password)
	assert.False(t, dbUser.IsPasswordHashed())

	// a plain-text password gets hashed on save
	form.Set("password", defaultPassword)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	dbUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, dbUser.Password)
	assert.True(t, dbUser.IsPasswordHashed())
	prevPwd := dbUser.Password

	// submitting the redacted placeholder must keep the previous password
	form.Set("password", redactedSecret)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	dbUser, err = dataprovider.UserExists(user.Username, "")
	assert.NoError(t, err)
	assert.NotEmpty(t, dbUser.Password)
	assert.True(t, dbUser.IsPasswordHashed())
	assert.Equal(t, prevPwd, dbUser.Password)
	assert.True(t, dbUser.Filters.TOTPConfig.Enabled)

	// re-read the user via the REST API and verify the persisted fields
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var updateUser dataprovider.User
	err = render.DecodeJSON(rr.Body, &updateUser)
	assert.NoError(t, err)
	assert.Equal(t, user.Email, updateUser.Email)
	assert.Equal(t, user.HomeDir, updateUser.HomeDir)
	assert.Equal(t, user.MaxSessions, updateUser.MaxSessions)
	assert.Equal(t, user.QuotaFiles, updateUser.QuotaFiles)
	assert.Equal(t, user.QuotaSize, updateUser.QuotaSize)
	assert.Equal(t, user.UID, updateUser.UID)
	assert.Equal(t, user.GID, updateUser.GID)
	assert.Equal(t, user.AdditionalInfo, updateUser.AdditionalInfo)
	assert.Equal(t, user.Description, updateUser.Description)
	assert.Equal(t, int64(100), updateUser.Filters.MaxUploadFileSize)
	assert.Equal(t, sdk.TLSUsernameCN, updateUser.Filters.TLSUsername)
	assert.True(t, updateUser.Filters.AllowAPIKeyAuth)
	assert.True(t, updateUser.Filters.TOTPConfig.Enabled)
	assert.Equal(t, int64(0), updateUser.TotalDataTransfer)
	assert.Equal(t, int64(0), updateUser.DownloadDataTransfer)
	assert.Equal(t, int64(0), updateUser.UploadDataTransfer)
	// NOTE(review): the form sent external_auth_cache_time=120 but 0 is
	// expected here — presumably the field is not updatable from this form;
	// worth confirming against the handler.
	assert.Equal(t, int64(0), updateUser.Filters.ExternalAuthCacheTime)
	assert.Equal(t, 30, updateUser.Filters.DefaultSharesExpiration)
	assert.Equal(t, 60, updateUser.Filters.MaxSharesExpiration)
	assert.Equal(t, 60, updateUser.Filters.PasswordExpiration)
	assert.Equal(t, 40, updateUser.Filters.PasswordStrength)
	assert.True(t, updateUser.Filters.RequirePasswordChange)
	if val, ok := updateUser.Permissions["/otherdir"]; ok {
		assert.True(t, slices.Contains(val, dataprovider.PermListItems))
		assert.True(t, slices.Contains(val, dataprovider.PermUpload))
	} else {
		assert.Fail(t, "user permissions must contains /otherdir", "actual: %v", updateUser.Permissions)
	}
	assert.True(t, slices.Contains(updateUser.Filters.AllowedIP, "192.168.1.3/32"))
	assert.True(t, slices.Contains(updateUser.Filters.DeniedIP, "10.0.0.2/32"))
	assert.True(t, slices.Contains(updateUser.Filters.DeniedLoginMethods, dataprovider.SSHLoginMethodKeyboardInteractive))
	assert.True(t, slices.Contains(updateUser.Filters.DeniedProtocols, common.ProtocolFTP))
	assert.True(t, slices.Contains(updateUser.Filters.FilePatterns[0].DeniedPatterns, "*.zip"))
	// the update form carried no bandwidth limits, so they are cleared
	assert.Len(t, updateUser.Filters.BandwidthLimits, 0)
	assert.Len(t, updateUser.Filters.TLSCerts, 1)
	// cleanup
	req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestRenderFolderTemplateMock verifies that the folder template page renders
// both empty and pre-filled from an existing folder, and returns 404 for an
// unknown source folder.
func TestRenderFolderTemplateMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webTemplateFolder, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	folder := vfs.BaseVirtualFolder{
		Name:        "templatefolder",
		MappedPath:  filepath.Join(os.TempDir(), "mapped"),
		Description: "template folder desc",
	}
	folder, _, err = httpdtest.AddFolder(folder, http.StatusCreated)
	assert.NoError(t, err)

	// render pre-filled from the existing folder
	req, err = http.NewRequest(http.MethodGet, webTemplateFolder+fmt.Sprintf("?from=%v", folder.Name), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// unknown source folder -> 404
	req, err = http.NewRequest(http.MethodGet, webTemplateFolder+"?from=unknown-folder", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)

	_, err = httpdtest.RemoveFolder(folder, http.StatusOK)
	assert.NoError(t, err)
}

// TestRenderUserTemplateMock verifies that the user template page renders
// both empty and pre-filled from an existing user, and returns 404 for an
// unknown source user.
func TestRenderUserTemplateMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	req, err := http.NewRequest(http.MethodGet, webTemplateUser, nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)

	// render pre-filled from the existing user
	req, err = http.NewRequest(http.MethodGet, webTemplateUser+fmt.Sprintf("?from=%v", user.Username), nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	// unknown source user -> 404
	req, err = http.NewRequest(http.MethodGet, webTemplateUser+"?from=unknown", nil)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusNotFound, rr)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserSaveFromTemplateMock creates multiple users at once from the user
// template form ("%username%" placeholders expanded per template user) and
// checks CSRF enforcement, the tpl_require_password_change flag, and the 500
// returned when the data provider is closed.
func TestUserSaveFromTemplateMock(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token)
	assert.NoError(t, err)
	user1 := "u1"
	user2 := "u2"
	form := make(url.Values)
	form.Set("username", "")
	// "%username%" is expanded with each template user's name
	form.Set("home_dir", filepath.Join(os.TempDir(), "%username%"))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("uid", "0")
	form.Set("gid", "0")
	form.Set("max_sessions", "0")
	form.Set("quota_size", "0")
	form.Set("quota_files", "0")
	form.Set("permissions", "*")
	form.Set("status", "1")
	form.Set("expiration_date", "")
	form.Set("fs_provider", "0")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	form.Set("external_auth_cache_time", "0")
	// two template users: one with a password, one with a public key only
	form.Add("template_users[0][tpl_username]", user1)
	form.Add("template_users[0][tpl_password]", "password1")
	form.Add("template_users[0][tpl_public_keys]", " ")
	form.Add("template_users[1][tpl_username]", user2)
	form.Add("template_users[1][tpl_public_keys]", testPubKey)
	// missing CSRF token -> 403
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ := http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)

	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	// both users were created; password change not required by default
	u1, _, err := httpdtest.GetUserByUsername(user1, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, u1.Filters.RequirePasswordChange)
	u2, _, err := httpdtest.GetUserByUsername(user2, http.StatusOK)
	assert.NoError(t, err)
	assert.False(t, u2.Filters.RequirePasswordChange)

	// Tail of the preceding template test (its opening is above this chunk):
	// remove the two users created without the password-change flag.
	_, err = httpdtest.RemoveUser(u1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(u2, http.StatusOK)
	assert.NoError(t, err)

	// Re-submit the same template form with tpl_require_password_change set:
	// both generated users must now have RequirePasswordChange enabled.
	form.Add("tpl_require_password_change", "checked")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	u1, _, err = httpdtest.GetUserByUsername(user1, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, u1.Filters.RequirePasswordChange)
	u2, _, err = httpdtest.GetUserByUsername(user2, http.StatusOK)
	assert.NoError(t, err)
	assert.True(t, u2.Filters.RequirePasswordChange)

	_, err = httpdtest.RemoveUser(u1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(u2, http.StatusOK)
	assert.NoError(t, err)

	// With the data provider closed the same request must fail with a 500
	// and the localized generic error message.
	err = dataprovider.Close()
	assert.NoError(t, err)

	b, contentType, _ = getMultipartFormData(form, "", "")
	req, err = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	// Restore the provider so subsequent tests run against a working backend.
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestUserTemplateErrors exercises the admin web "user from template"
// endpoint with invalid inputs: malformed numeric fields, a missing user
// definition, an invalid public key and a blank template username. Every
// case must be rejected with 400 and the matching i18n error message.
func TestUserTemplateErrors(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token)
	assert.NoError(t, err)
	user := getTestUser()
	user.FsConfig.Provider = sdk.S3FilesystemProvider
	user.FsConfig.S3Config.Bucket = "test"
	user.FsConfig.S3Config.Region = "eu-central-1"
	// %username%/%password% placeholders should be expanded per generated user.
	user.FsConfig.S3Config.AccessKey = "%username%"
	user.FsConfig.S3Config.KeyPrefix = "somedir/subdir/"
	user.FsConfig.S3Config.UploadPartSize = 5
	user.FsConfig.S3Config.UploadConcurrency = 4
	user.FsConfig.S3Config.DownloadPartSize = 6
	user.FsConfig.S3Config.DownloadConcurrency = 3
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	form.Set("home_dir", filepath.Join(os.TempDir(), "%username%"))
	form.Set("uid", "0")
	form.Set("gid", strconv.FormatInt(int64(user.GID), 10))
	form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10))
	form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10))
	form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("permissions", "*")
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "2020-01-01 00:00:00")
	form.Set("allowed_ip", "")
	form.Set("denied_ip", "")
	form.Set("fs_provider", "1")
	form.Set("s3_bucket", user.FsConfig.S3Config.Bucket)
	form.Set("s3_region", user.FsConfig.S3Config.Region)
	form.Set("s3_access_key", "%username%")
	form.Set("s3_access_secret", "%password%")
	form.Set("s3_sse_customer_key", "%password%")
	form.Set("s3_key_prefix", "base/%username%")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	form.Add("hooks", "external_auth_disabled")
	form.Add("hooks", "check_password_disabled")
	form.Set("disable_fs_checks", "checked")
	form.Set("s3_download_part_max_time", "0")
	form.Set("s3_upload_part_max_time", "0")
	// test invalid s3_upload_part_size
	form.Set("s3_upload_part_size", "a")
	form.Set("form_action", "export_from_template")
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ := http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	form.Set("s3_upload_part_size", strconv.FormatInt(user.FsConfig.S3Config.UploadPartSize, 10))
	form.Set("s3_upload_concurrency", strconv.Itoa(user.FsConfig.S3Config.UploadConcurrency))
	form.Set("s3_download_part_size", strconv.FormatInt(user.FsConfig.S3Config.DownloadPartSize, 10))
	form.Set("s3_download_concurrency", strconv.Itoa(user.FsConfig.S3Config.DownloadConcurrency))
	// no user defined
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorUserTemplate)

	// a syntactically invalid public key must be rejected
	form.Set("template_users[0][tpl_username]", "user1")
	form.Set("template_users[0][tpl_password]", "password1")
	form.Set("template_users[0][tpl_public_keys]", "invalid-pkey")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	require.Contains(t, rr.Body.String(), util.I18nErrorPubKeyInvalid)

	// a blank (whitespace-only) template username is not a valid user
	form.Set("template_users[0][tpl_username]", " ")
	form.Set("template_users[0][tpl_password]", "pwd")
	form.Set("template_users[0][tpl_public_keys]", testPubKey)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost,
		webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	require.Contains(t, rr.Body.String(), util.I18nErrorUserTemplate)
}

// TestUserTemplateRoleAndPermissions verifies that a role-restricted admin
// cannot use the "user from template" page without the add-users permission,
// and that users it generates inherit the admin's role (even when the form
// requests a different or empty role).
func TestUserTemplateRoleAndPermissions(t *testing.T) {
	r1 := getTestRole()
	r2 := getTestRole()
	r2.Name += "_mod"
	role1, resp, err := httpdtest.AddRole(r1, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	role2, resp, err := httpdtest.AddRole(r2, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	admin.Role = role1.Name
	// deliberately missing PermAdminAddUsers: template page must be forbidden
	admin.Permissions = []string{dataprovider.PermAdminManageFolders, dataprovider.PermAdminChangeUsers,
		dataprovider.PermAdminViewUsers}
	admin, _, err = httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)

	token, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)

	req, _ := http.NewRequest(http.MethodGet, webTemplateUser, nil)
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)

	csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token)
	assert.NoError(t, err)
	user1 := "u1"
	user2 := "u2"
	form := make(url.Values)
	form.Set("username", "")
	// request role2: it must be ignored, generated users get the admin's role
	form.Set("role", role2.Name)
	form.Set("home_dir", filepath.Join(os.TempDir(), "%username%"))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("uid", "0")
	form.Set("gid", "0")
	form.Set("max_sessions", "0")
	form.Set("quota_size", "0")
	form.Set("quota_files", "0")
	form.Set("permissions", "*")
	form.Set("status", "1")
	form.Set("expiration_date", "")
	form.Set("fs_provider", "0")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	form.Set("external_auth_cache_time", "0")
	form.Add("template_users[0][tpl_username]", user1)
	form.Add("template_users[0][tpl_password]", "password1")
	form.Add("template_users[0][tpl_public_keys]", " ")
	form.Add("template_users[1][tpl_username]", user2)
	form.Add("template_users[1][tpl_public_keys]", testPubKey)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	// Add the required permissions
	admin.Permissions = append(admin.Permissions, dataprovider.PermAdminAddUsers)
	_, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)

	// a fresh token is needed to pick up the updated permissions
	token, err = getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)

	req, _ = http.NewRequest(http.MethodGet, webTemplateUser, nil)
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	csrfToken, err = getCSRFTokenFromInternalPageMock(webTemplateUser, token)
	assert.NoError(t, err)
	form.Set(csrfFormToken, csrfToken)

	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	// both generated users must carry the admin's role, not role2
	u1, _, err := httpdtest.GetUserByUsername(user1, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, admin.Role, u1.Role)
	u2, _, err := httpdtest.GetUserByUsername(user2, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, admin.Role, u2.Role)

	_, err = httpdtest.RemoveUser(u1,
		http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(u2, http.StatusOK)
	assert.NoError(t, err)
	// Set an empty role
	form.Set("role", "")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateUser, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	// even with an empty role in the form the admin's role is applied
	u1, _, err = httpdtest.GetUserByUsername(user1, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, admin.Role, u1.Role)
	u2, _, err = httpdtest.GetUserByUsername(user2, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, admin.Role, u2.Role)

	// cleanup: users, admin and both roles
	_, err = httpdtest.RemoveUser(u1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(u2, http.StatusOK)
	assert.NoError(t, err)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role1, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveRole(role2, http.StatusOK)
	assert.NoError(t, err)
}

// TestUserPlaceholders verifies %username%/%password% placeholder expansion
// in the user home directory when adding a user via the web form, and that
// updating with the redacted password placeholder neither re-hashes the
// stored password nor expands %password% in the home dir.
func TestUserPlaceholders(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, token)
	assert.NoError(t, err)
	u := getTestUser()
	u.HomeDir = filepath.Join(os.TempDir(), "%username%_%password%")
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", u.Username)
	form.Set("home_dir", u.HomeDir)
	form.Set("password", u.Password)
	form.Set("status", strconv.Itoa(u.Status))
	form.Set("expiration_date", "")
	form.Set("permissions", "*")
	form.Set("public_keys[0][public_key]", testPubKey)
	form.Set("public_keys[1][public_key]", testPubKey1)
	form.Set("uid", "0")
	form.Set("gid", "0")
	form.Set("max_sessions", "0")
	form.Set("quota_size", "0")
	form.Set("quota_files", "0")
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("total_data_transfer", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ := http.NewRequest(http.MethodPost, webUserPath, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	// on add, both placeholders are expanded with the real credentials
	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, filepath.Join(os.TempDir(), fmt.Sprintf("%v_%v", defaultUsername, defaultPassword)), user.HomeDir)

	dbUser, err := dataprovider.UserExists(defaultUsername, "")
	assert.NoError(t, err)
	assert.True(t, dbUser.IsPasswordHashed())
	hashedPwd := dbUser.Password

	// update using the redacted placeholder instead of a real password
	form.Set("password", redactedSecret)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, err = http.NewRequest(http.MethodPost, path.Join(webUserPath, defaultUsername), &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	// %password% must stay literal on update since no real password was sent
	user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, filepath.Join(os.TempDir(), defaultUsername+"_%password%"), user.HomeDir)
	// check that the password was unchanged
	dbUser, err = dataprovider.UserExists(defaultUsername, "")
	assert.NoError(t, err)
	assert.True(t, dbUser.IsPasswordHashed())
	assert.Equal(t, hashedPwd, dbUser.Password)

	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestFolderPlaceholders verifies %name% placeholder expansion in the mapped
// path and description when adding and updating a folder via the web form.
func TestFolderPlaceholders(t *testing.T) {
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, token)
	assert.NoError(t, err)
	folderName := "folderName"
	form := make(url.Values)
	form.Set("name", folderName)
	form.Set("mapped_path", filepath.Join(os.TempDir(), "%name%"))
	form.Set("description", "desc folder %name%")
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, err := http.NewRequest(http.MethodPost, webFolderPath, &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	folderGet, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, filepath.Join(os.TempDir(), folderName), folderGet.MappedPath)
	assert.Equal(t, fmt.Sprintf("desc folder %v", folderName), folderGet.Description)

	// every occurrence of %name% is expanded, not just the first
	form.Set("mapped_path", filepath.Join(os.TempDir(), "%name%_%name%"))
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	folderGet, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, filepath.Join(os.TempDir(), fmt.Sprintf("%v_%v", folderName, folderName)), folderGet.MappedPath)
	assert.Equal(t, fmt.Sprintf("desc folder %v", folderName), folderGet.Description)

	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
}

// TestFolderSaveFromTemplateMock creates two folders from a single template
// submission and checks that the request fails with a 500 when the data
// provider is closed; the provider is restored afterwards.
func TestFolderSaveFromTemplateMock(t *testing.T) {
	folder1 := "f1"
	folder2 := "f2"
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("name", "name")
	form.Set("mapped_path", filepath.Join(os.TempDir(), "%name%"))
	form.Set("description", "desc folder %name%")
	form.Set("template_folders[0][tpl_foldername]", folder1)
	form.Set("template_folders[1][tpl_foldername]", folder2)
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, err := http.NewRequest(http.MethodPost, webTemplateFolder, &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)

	// both folders must have been created from the single template
	_, _, err = httpdtest.GetFolderByName(folder1, http.StatusOK)
	assert.NoError(t, err)
	_, _, err = httpdtest.GetFolderByName(folder2, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folder1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folder2}, http.StatusOK)
	assert.NoError(t, err)

	// with the provider closed the same request must fail with a 500
	err = dataprovider.Close()
	assert.NoError(t, err)

	b, contentType, _ = getMultipartFormData(form, "", "")
	req, err = http.NewRequest(http.MethodPost, webTemplateFolder, &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusInternalServerError, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	// restore the provider for the tests that follow
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	providerConf.BackupsPath = backupsPath
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
}

// TestFolderTemplateErrors exercises the "folder from template" endpoint
// with a missing CSRF token, a malformed query string, incomplete S3
// settings, blank folder names and a relative mapped path, asserting the
// expected status codes and localized error messages.
func TestFolderTemplateErrors(t *testing.T) {
	folderName :=
		"vfolder-template"
	mappedPath := filepath.Join(os.TempDir(), "%name%mapped%name%path")
	token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webTemplateFolder, token)
	assert.NoError(t, err)
	form := make(url.Values)
	form.Set("name", folderName)
	form.Set("mapped_path", mappedPath)
	form.Set("description", "desc folder %name%")
	form.Set("template_folders[0][tpl_foldername]", "folder1")
	form.Set("template_folders[1][tpl_foldername]", "folder2")
	form.Set("template_folders[2][tpl_foldername]", "folder3")
	form.Set("template_folders[3][tpl_foldername]", "folder1 ")
	form.Add("template_folders[3][tpl_foldername]", " ")
	// no CSRF token yet: the request must be rejected
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ := http.NewRequest(http.MethodPost, webTemplateFolder, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF)

	// invalid percent-encoding in the query string -> form parsing error
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateFolder+"?param=p%C3%AO%GG", &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm)

	// S3 filesystem selected but part size/concurrency fields are missing
	form.Set("fs_provider", "1")
	form.Set("s3_bucket", "bucket")
	form.Set("s3_region", "us-east-1")
	form.Set("s3_access_key", "%name%")
	form.Set("s3_access_secret", "pwd%name%")
	form.Set("s3_sse_customer_key", "key%name%")
	form.Set("s3_key_prefix", "base/%name%")

	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nError500Message)

	form.Set("s3_upload_part_size", "5")
	form.Set("s3_upload_concurrency", "4")
	form.Set("s3_download_part_max_time", "0")
	form.Set("s3_upload_part_max_time", "0")
	form.Set("s3_download_part_size", "6")
	form.Set("s3_download_concurrency", "2")

	// all template folder names blank -> no folder to generate
	form.Set("template_folders[0][tpl_foldername]", " ")
	form.Set("template_folders[1][tpl_foldername]", "")
	form.Set("template_folders[2][tpl_foldername]", "")
	form.Set("template_folders[3][tpl_foldername]", " ")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorFolderTemplate)

	// a relative mapped path is not a valid home/mapped directory
	form.Set("template_folders[0][tpl_foldername]", "name")
	form.Set("mapped_path", "relative-path")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, webTemplateFolder, &b)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusBadRequest, rr)
	assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidHomeDir)
}

// TestFolderTemplatePermission verifies that the "folder from template" page
// and its POST handler are forbidden for an admin without the manage-folders
// permission and become accessible once that permission is granted.
func TestFolderTemplatePermission(t *testing.T) {
	admin := getTestAdmin()
	admin.Username = altAdminUsername
	admin.Password = altAdminPassword
	// deliberately missing PermAdminManageFolders
	admin.Permissions = []string{dataprovider.PermAdminChangeUsers, dataprovider.PermAdminAddUsers, dataprovider.PermAdminViewUsers}
	admin, _, err := httpdtest.AddAdmin(admin, http.StatusCreated)
	assert.NoError(t, err)
	// no permission to view or add folders from templates
	token, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	csrfToken, err :=
		getCSRFTokenFromInternalPageMock(webTemplateUser, token)
	assert.NoError(t, err)

	req, err := http.NewRequest(http.MethodGet, webTemplateFolder, nil)
	assert.NoError(t, err)
	req.RequestURI = webTemplateFolder
	setJWTCookieForReq(req, token)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)

	form := make(url.Values)
	form.Set("name", "name")
	form.Set("mapped_path", filepath.Join(os.TempDir(), "%name%"))
	form.Set("description", "desc folder %name%")
	form.Set("template_folders[0][tpl_foldername]", "folder1")
	form.Set("template_folders[1][tpl_foldername]", "folder2")
	form.Set(csrfFormToken, csrfToken)
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, err = http.NewRequest(http.MethodPost, webTemplateFolder, &b)
	assert.NoError(t, err)
	setJWTCookieForReq(req, token)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusForbidden, rr)

	// grant the folder-management permission and retry with a fresh token
	admin.Permissions = append(admin.Permissions, dataprovider.PermAdminManageFolders)
	_, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK)
	assert.NoError(t, err)

	token, err = getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword)
	assert.NoError(t, err)
	_, err = getCSRFTokenFromInternalPageMock(webTemplateUser, token)
	assert.NoError(t, err)

	req, err = http.NewRequest(http.MethodGet, webTemplateFolder, nil)
	assert.NoError(t, err)
	req.RequestURI = webTemplateFolder
	setJWTCookieForReq(req, token)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	_, err = httpdtest.RemoveAdmin(admin, http.StatusOK)
	assert.NoError(t, err)
}

// TestWebUserS3Mock updates a user to the S3 filesystem through the web
// form, checking: rejection of malformed numeric S3 fields, round-trip of
// every S3 setting, secret-box encryption of secrets, that redacted secrets
// are not re-saved, and that empty credentials clear the stored secrets.
func TestWebUserS3Mock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath,
		webToken)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	// create the base user via the REST API first
	req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	lastPwdChange := user.LastPasswordChange
	assert.Greater(t, lastPwdChange, int64(0))
	user.FsConfig.Provider = sdk.S3FilesystemProvider
	user.FsConfig.S3Config.Bucket = "test"
	user.FsConfig.S3Config.Region = "eu-west-1"
	user.FsConfig.S3Config.AccessKey = "access-key"
	user.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("access-secret")
	user.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("enc-key")
	user.FsConfig.S3Config.RoleARN = "arn:aws:iam::123456789012:user/Development/product_1234/*"
	user.FsConfig.S3Config.Endpoint = "http://127.0.0.1:9000/path?a=b"
	user.FsConfig.S3Config.StorageClass = "Standard"
	user.FsConfig.S3Config.KeyPrefix = "somedir/subdir/"
	user.FsConfig.S3Config.UploadPartSize = 5
	user.FsConfig.S3Config.UploadConcurrency = 4
	user.FsConfig.S3Config.DownloadPartMaxTime = 60
	user.FsConfig.S3Config.UploadPartMaxTime = 120
	user.FsConfig.S3Config.DownloadPartSize = 6
	user.FsConfig.S3Config.DownloadConcurrency = 3
	user.FsConfig.S3Config.ForcePathStyle = true
	user.FsConfig.S3Config.SkipTLSVerify = true
	user.FsConfig.S3Config.ACL = "public-read"
	user.Description = "s3 tèst user"
	form := make(url.Values)
	form.Set(csrfFormToken, csrfToken)
	form.Set("username", user.Username)
	form.Set("password", redactedSecret)
	form.Set("home_dir", user.HomeDir)
	form.Set("uid", "0")
	form.Set("gid", strconv.FormatInt(int64(user.GID), 10))
	form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10))
	form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10))
	form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10))
	form.Set("upload_bandwidth", "0")
	form.Set("download_bandwidth", "0")
	form.Set("upload_data_transfer", "0")
	form.Set("download_data_transfer", "0")
	form.Set("total_data_transfer", "0")
	form.Set("external_auth_cache_time", "0")
	form.Set("permissions", "*")
	form.Set("status", strconv.Itoa(user.Status))
	form.Set("expiration_date", "2020-01-01 00:00:00")
	form.Set("allowed_ip", "")
	form.Set("denied_ip", "")
	form.Set("fs_provider", "1")
	form.Set("s3_bucket", user.FsConfig.S3Config.Bucket)
	form.Set("s3_region", user.FsConfig.S3Config.Region)
	form.Set("s3_access_key", user.FsConfig.S3Config.AccessKey)
	form.Set("s3_access_secret", user.FsConfig.S3Config.AccessSecret.GetPayload())
	form.Set("s3_sse_customer_key", user.FsConfig.S3Config.SSECustomerKey.GetPayload())
	form.Set("s3_role_arn", user.FsConfig.S3Config.RoleARN)
	form.Set("s3_storage_class", user.FsConfig.S3Config.StorageClass)
	form.Set("s3_acl", user.FsConfig.S3Config.ACL)
	form.Set("s3_endpoint", user.FsConfig.S3Config.Endpoint)
	form.Set("s3_key_prefix", user.FsConfig.S3Config.KeyPrefix)
	// two file patterns: policy 0 (default) and 1 (hide), asserted below
	form.Set("directory_patterns[0][pattern_path]", "/dir1")
	form.Set("directory_patterns[0][patterns]", "*.jpg,*.png")
	form.Set("directory_patterns[0][pattern_type]", "allowed")
	form.Set("directory_patterns[0][pattern_policy]", "0")
	form.Set("directory_patterns[1][pattern_path]", "/dir2")
	form.Set("directory_patterns[1][patterns]", "*.zip")
	form.Set("directory_patterns[1][pattern_type]", "denied")
	form.Set("directory_patterns[1][pattern_policy]", "1")
	form.Set("max_upload_file_size", "0")
	form.Set("default_shares_expiration", "0")
	form.Set("max_shares_expiration", "0")
	form.Set("password_expiration", "0")
	form.Set("password_strength", "0")
	form.Set("ftp_security", "1")
	form.Set("s3_force_path_style", "checked")
	form.Set("s3_skip_tls_verify", "checked")
	form.Set("description", user.Description)
	form.Add("hooks", "pre_login_disabled")
	form.Add("allow_api_key_auth", "1")
	// test invalid s3_upload_part_size
	form.Set("s3_upload_part_size", "a")
	b, contentType, _ := getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	// 200 here means the edit page is re-rendered with the validation error
	checkResponseCode(t, http.StatusOK, rr)
	// test invalid s3_upload_concurrency
	form.Set("s3_upload_part_size", strconv.FormatInt(user.FsConfig.S3Config.UploadPartSize, 10))
	form.Set("s3_upload_concurrency", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// test invalid s3_download_part_size
	form.Set("s3_upload_concurrency", strconv.Itoa(user.FsConfig.S3Config.UploadConcurrency))
	form.Set("s3_download_part_size", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// test invalid s3_download_concurrency
	form.Set("s3_download_part_size", strconv.FormatInt(user.FsConfig.S3Config.DownloadPartSize, 10))
	form.Set("s3_download_concurrency", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// test invalid s3_download_part_max_time
	form.Set("s3_download_concurrency", strconv.Itoa(user.FsConfig.S3Config.DownloadConcurrency))
	form.Set("s3_download_part_max_time", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// test invalid s3_upload_part_max_time
	form.Set("s3_download_part_max_time", strconv.Itoa(user.FsConfig.S3Config.DownloadPartMaxTime))
	form.Set("s3_upload_part_max_time", "a")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	// now add the user
	form.Set("s3_upload_part_max_time", strconv.Itoa(user.FsConfig.S3Config.UploadPartMaxTime))
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	// fetch the updated user via the REST API and verify every S3 field
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var updateUser dataprovider.User
	err = render.DecodeJSON(rr.Body, &updateUser)
	assert.NoError(t, err)
	// 1577836800000 ms = 2020-01-01 00:00:00 UTC, matching expiration_date
	assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate)
	assert.Equal(t, updateUser.FsConfig.S3Config.Bucket, user.FsConfig.S3Config.Bucket)
	assert.Equal(t, updateUser.FsConfig.S3Config.Region, user.FsConfig.S3Config.Region)
	assert.Equal(t, updateUser.FsConfig.S3Config.AccessKey, user.FsConfig.S3Config.AccessKey)
	assert.Equal(t, updateUser.FsConfig.S3Config.RoleARN, user.FsConfig.S3Config.RoleARN)
	assert.Equal(t,
		updateUser.FsConfig.S3Config.StorageClass, user.FsConfig.S3Config.StorageClass)
	assert.Equal(t, updateUser.FsConfig.S3Config.ACL, user.FsConfig.S3Config.ACL)
	assert.Equal(t, updateUser.FsConfig.S3Config.Endpoint, user.FsConfig.S3Config.Endpoint)
	assert.Equal(t, updateUser.FsConfig.S3Config.KeyPrefix, user.FsConfig.S3Config.KeyPrefix)
	assert.Equal(t, updateUser.FsConfig.S3Config.UploadPartSize, user.FsConfig.S3Config.UploadPartSize)
	assert.Equal(t, updateUser.FsConfig.S3Config.UploadConcurrency, user.FsConfig.S3Config.UploadConcurrency)
	assert.Equal(t, updateUser.FsConfig.S3Config.DownloadPartMaxTime, user.FsConfig.S3Config.DownloadPartMaxTime)
	assert.Equal(t, updateUser.FsConfig.S3Config.UploadPartMaxTime, user.FsConfig.S3Config.UploadPartMaxTime)
	assert.Equal(t, updateUser.FsConfig.S3Config.DownloadPartSize, user.FsConfig.S3Config.DownloadPartSize)
	assert.Equal(t, updateUser.FsConfig.S3Config.DownloadConcurrency, user.FsConfig.S3Config.DownloadConcurrency)
	assert.Equal(t, lastPwdChange, updateUser.LastPasswordChange)
	assert.True(t, updateUser.FsConfig.S3Config.ForcePathStyle)
	assert.True(t, updateUser.FsConfig.S3Config.SkipTLSVerify)
	if assert.Equal(t, 2, len(updateUser.Filters.FilePatterns)) {
		for _, filter := range updateUser.Filters.FilePatterns {
			switch filter.Path {
			case "/dir1":
				assert.Equal(t, sdk.DenyPolicyDefault, filter.DenyPolicy)
			case "/dir2":
				assert.Equal(t, sdk.DenyPolicyHide, filter.DenyPolicy)
			}
		}
	}
	// secrets must come back secret-box encrypted, with no key material exposed
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.S3Config.AccessSecret.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.S3Config.AccessSecret.GetPayload())
	assert.Empty(t, updateUser.FsConfig.S3Config.AccessSecret.GetKey())
	assert.Empty(t, updateUser.FsConfig.S3Config.AccessSecret.GetAdditionalData())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.S3Config.SSECustomerKey.GetStatus())
	assert.NotEmpty(t, updateUser.FsConfig.S3Config.SSECustomerKey.GetPayload())
	assert.Empty(t, updateUser.FsConfig.S3Config.SSECustomerKey.GetKey())
	assert.Empty(t, updateUser.FsConfig.S3Config.SSECustomerKey.GetAdditionalData())
	assert.Equal(t, user.Description, updateUser.Description)
	assert.True(t, updateUser.Filters.Hooks.PreLoginDisabled)
	assert.False(t, updateUser.Filters.Hooks.ExternalAuthDisabled)
	assert.False(t, updateUser.Filters.Hooks.CheckPasswordDisabled)
	assert.False(t, updateUser.Filters.DisableFsChecks)
	assert.True(t, updateUser.Filters.AllowAPIKeyAuth)
	assert.Equal(t, 1, updateUser.Filters.FTPSecurity)
	// now check that a redacted password is not saved
	form.Set("s3_access_secret", redactedSecret)
	form.Set("s3_sse_customer_key", redactedSecret)
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)

	var lastUpdatedUser dataprovider.User
	err = render.DecodeJSON(rr.Body, &lastUpdatedUser)
	assert.NoError(t, err)
	// encrypted payloads are unchanged: the redacted values were ignored
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.S3Config.AccessSecret.GetStatus())
	assert.Equal(t, updateUser.FsConfig.S3Config.AccessSecret.GetPayload(), lastUpdatedUser.FsConfig.S3Config.AccessSecret.GetPayload())
	assert.Empty(t, lastUpdatedUser.FsConfig.S3Config.AccessSecret.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.S3Config.AccessSecret.GetAdditionalData())
	assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.S3Config.SSECustomerKey.GetStatus())
	assert.Equal(t, updateUser.FsConfig.S3Config.SSECustomerKey.GetPayload(), lastUpdatedUser.FsConfig.S3Config.SSECustomerKey.GetPayload())
	assert.Empty(t,
		lastUpdatedUser.FsConfig.S3Config.SSECustomerKey.GetKey())
	assert.Empty(t, lastUpdatedUser.FsConfig.S3Config.SSECustomerKey.GetAdditionalData())
	assert.Equal(t, lastPwdChange, lastUpdatedUser.LastPasswordChange)
	// now clear credentials
	form.Set("s3_access_key", "")
	form.Set("s3_access_secret", "")
	form.Set("s3_sse_customer_key", "")
	b, contentType, _ = getMultipartFormData(form, "", "")
	req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b)
	setJWTCookieForReq(req, webToken)
	req.Header.Set("Content-Type", contentType)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusSeeOther, rr)
	req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
	var userGet dataprovider.User
	err = render.DecodeJSON(rr.Body, &userGet)
	assert.NoError(t, err)
	// empty form values must remove the stored secrets entirely
	assert.Nil(t, userGet.FsConfig.S3Config.AccessSecret)
	assert.Nil(t, userGet.FsConfig.S3Config.SSECustomerKey)

	req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil)
	setBearerForReq(req, apiToken)
	rr = executeRequest(req)
	checkResponseCode(t, http.StatusOK, rr)
}

// TestWebUserGCSMock updates a user to the GCS filesystem through the web
// form, including upload of a credentials file (continues past this chunk).
func TestWebUserGCSMock(t *testing.T) {
	webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass)
	assert.NoError(t, err)
	csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken)
	assert.NoError(t, err)
	user := getTestUser()
	userAsJSON := getUserAsJSON(t, user)
	req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON))
	assert.NoError(t, err)
	setBearerForReq(req, apiToken)
	rr := executeRequest(req)
	checkResponseCode(t, http.StatusCreated, rr)
	err = render.DecodeJSON(rr.Body, &user)
	assert.NoError(t, err)
	credentialsFilePath :=
filepath.Join(os.TempDir(), "gcs.json") + err = createTestFile(credentialsFilePath, 0) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.GCSFilesystemProvider + user.FsConfig.GCSConfig.Bucket = "test" + user.FsConfig.GCSConfig.KeyPrefix = "somedir/subdir/" + user.FsConfig.GCSConfig.StorageClass = "standard" + user.FsConfig.GCSConfig.ACL = "publicReadWrite" + user.FsConfig.GCSConfig.UploadPartSize = 16 + user.FsConfig.GCSConfig.UploadPartMaxTime = 32 + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user.Username) + form.Set("password", redactedSecret) + form.Set("home_dir", user.HomeDir) + form.Set("uid", "0") + form.Set("gid", strconv.FormatInt(int64(user.GID), 10)) + form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10)) + form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10)) + form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) + form.Set("upload_bandwidth", "0") + form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") + form.Set("external_auth_cache_time", "0") + form.Set("permissions", "*") + form.Set("status", strconv.Itoa(user.Status)) + form.Set("expiration_date", "2020-01-01 00:00:00") + form.Set("allowed_ip", "") + form.Set("denied_ip", "") + form.Set("fs_provider", "2") + form.Set("gcs_bucket", user.FsConfig.GCSConfig.Bucket) + form.Set("gcs_storage_class", user.FsConfig.GCSConfig.StorageClass) + form.Set("gcs_acl", user.FsConfig.GCSConfig.ACL) + form.Set("gcs_key_prefix", user.FsConfig.GCSConfig.KeyPrefix) + form.Set("gcs_upload_part_size", strconv.FormatInt(user.FsConfig.GCSConfig.UploadPartSize, 10)) + form.Set("gcs_upload_part_max_time", strconv.FormatInt(int64(user.FsConfig.GCSConfig.UploadPartMaxTime), 10)) + form.Set("directory_patterns[0][pattern_path]", "/dir1") + form.Set("directory_patterns[0][patterns]", "*.jpg,*.png") + 
form.Set("directory_patterns[0][pattern_type]", "allowed") + form.Set("max_upload_file_size", "0") + form.Set("default_shares_expiration", "0") + form.Set("max_shares_expiration", "0") + form.Set("password_expiration", "0") + form.Set("password_strength", "0") + form.Set("ftp_security", "1") + b, contentType, _ := getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + b, contentType, _ = getMultipartFormData(form, "gcs_credential_file", credentialsFilePath) + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = createTestFile(credentialsFilePath, 4096) + assert.NoError(t, err) + b, contentType, _ = getMultipartFormData(form, "gcs_credential_file", credentialsFilePath) + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var updateUser dataprovider.User + err = render.DecodeJSON(rr.Body, &updateUser) + assert.NoError(t, err) + assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate) + assert.Equal(t, user.FsConfig.Provider, updateUser.FsConfig.Provider) + assert.Equal(t, user.FsConfig.GCSConfig.Bucket, updateUser.FsConfig.GCSConfig.Bucket) + assert.Equal(t, user.FsConfig.GCSConfig.StorageClass, updateUser.FsConfig.GCSConfig.StorageClass) + assert.Equal(t, user.FsConfig.GCSConfig.ACL, 
updateUser.FsConfig.GCSConfig.ACL) + assert.Equal(t, user.FsConfig.GCSConfig.KeyPrefix, updateUser.FsConfig.GCSConfig.KeyPrefix) + assert.Equal(t, user.FsConfig.GCSConfig.UploadPartSize, updateUser.FsConfig.GCSConfig.UploadPartSize) + assert.Equal(t, user.FsConfig.GCSConfig.UploadPartMaxTime, updateUser.FsConfig.GCSConfig.UploadPartMaxTime) + if assert.Len(t, updateUser.Filters.FilePatterns, 1) { + assert.Equal(t, "/dir1", updateUser.Filters.FilePatterns[0].Path) + assert.Len(t, updateUser.Filters.FilePatterns[0].AllowedPatterns, 2) + assert.Contains(t, updateUser.Filters.FilePatterns[0].AllowedPatterns, "*.png") + assert.Contains(t, updateUser.Filters.FilePatterns[0].AllowedPatterns, "*.jpg") + } + assert.Equal(t, 1, updateUser.Filters.FTPSecurity) + form.Set("gcs_auto_credentials", "on") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + updateUser = dataprovider.User{} + err = render.DecodeJSON(rr.Body, &updateUser) + assert.NoError(t, err) + assert.Equal(t, 1, updateUser.FsConfig.GCSConfig.AutomaticCredentials) + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = os.Remove(credentialsFilePath) + assert.NoError(t, err) +} + +func TestWebUserHTTPFsMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, 
err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, err := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + assert.NoError(t, err) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.HTTPFilesystemProvider + user.FsConfig.HTTPConfig = vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: "https://127.0.0.1:9999/api/v1", + Username: defaultUsername, + SkipTLSVerify: true, + }, + Password: kms.NewPlainSecret(defaultPassword), + APIKey: kms.NewPlainSecret(defaultTokenAuthPass), + } + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user.Username) + form.Set("password", redactedSecret) + form.Set("home_dir", user.HomeDir) + form.Set("uid", "0") + form.Set("gid", strconv.FormatInt(int64(user.GID), 10)) + form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10)) + form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10)) + form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) + form.Set("upload_bandwidth", "0") + form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") + form.Set("external_auth_cache_time", "0") + form.Set("permissions", "*") + form.Set("status", strconv.Itoa(user.Status)) + form.Set("expiration_date", "2020-01-01 00:00:00") + form.Set("allowed_ip", "") + form.Set("denied_ip", "") + form.Set("fs_provider", "6") + form.Set("http_endpoint", user.FsConfig.HTTPConfig.Endpoint) + form.Set("http_username", user.FsConfig.HTTPConfig.Username) + form.Set("http_password", user.FsConfig.HTTPConfig.Password.GetPayload()) + form.Set("http_api_key", 
user.FsConfig.HTTPConfig.APIKey.GetPayload()) + form.Set("http_skip_tls_verify", "checked") + form.Set("directory_patterns[0][pattern_path]", "/dir1") + form.Set("directory_patterns[0][patterns]", "*.jpg,*.png") + form.Set("directory_patterns[0][pattern_type]", "allowed") + form.Set("directory_patterns[1][pattern_path]", "/dir2") + form.Set("directory_patterns[1][patterns]", "*.zip") + form.Set("directory_patterns[1][pattern_type]", "denied") + form.Set("max_upload_file_size", "0") + form.Set("default_shares_expiration", "0") + form.Set("max_shares_expiration", "0") + form.Set("password_expiration", "0") + form.Set("password_strength", "0") + form.Set("http_equality_check_mode", "true") + b, contentType, _ := getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the updated user + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var updateUser dataprovider.User + err = render.DecodeJSON(rr.Body, &updateUser) + assert.NoError(t, err) + assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate) + assert.Equal(t, 2, len(updateUser.Filters.FilePatterns)) + assert.Equal(t, user.FsConfig.HTTPConfig.Endpoint, updateUser.FsConfig.HTTPConfig.Endpoint) + assert.Equal(t, user.FsConfig.HTTPConfig.Username, updateUser.FsConfig.HTTPConfig.Username) + assert.Equal(t, user.FsConfig.HTTPConfig.SkipTLSVerify, updateUser.FsConfig.HTTPConfig.SkipTLSVerify) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.HTTPConfig.Password.GetStatus()) + assert.NotEmpty(t, updateUser.FsConfig.HTTPConfig.Password.GetPayload()) + assert.Empty(t, updateUser.FsConfig.HTTPConfig.Password.GetKey()) + assert.Empty(t, 
updateUser.FsConfig.HTTPConfig.Password.GetAdditionalData()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.HTTPConfig.APIKey.GetStatus()) + assert.NotEmpty(t, updateUser.FsConfig.HTTPConfig.APIKey.GetPayload()) + assert.Empty(t, updateUser.FsConfig.HTTPConfig.APIKey.GetKey()) + assert.Empty(t, updateUser.FsConfig.HTTPConfig.APIKey.GetAdditionalData()) + assert.Equal(t, 1, updateUser.FsConfig.HTTPConfig.EqualityCheckMode) + // now check that a redacted password is not saved + form.Set("http_equality_check_mode", "") + form.Set("http_password", " "+redactedSecret+" ") + form.Set("http_api_key", redactedSecret) + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var lastUpdatedUser dataprovider.User + err = render.DecodeJSON(rr.Body, &lastUpdatedUser) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.HTTPConfig.Password.GetStatus()) + assert.Equal(t, updateUser.FsConfig.HTTPConfig.Password.GetPayload(), lastUpdatedUser.FsConfig.HTTPConfig.Password.GetPayload()) + assert.Empty(t, lastUpdatedUser.FsConfig.HTTPConfig.Password.GetKey()) + assert.Empty(t, lastUpdatedUser.FsConfig.HTTPConfig.Password.GetAdditionalData()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.HTTPConfig.APIKey.GetStatus()) + assert.Equal(t, updateUser.FsConfig.HTTPConfig.APIKey.GetPayload(), lastUpdatedUser.FsConfig.HTTPConfig.APIKey.GetPayload()) + assert.Empty(t, lastUpdatedUser.FsConfig.HTTPConfig.APIKey.GetKey()) + assert.Empty(t, 
lastUpdatedUser.FsConfig.HTTPConfig.APIKey.GetAdditionalData()) + assert.Equal(t, 0, lastUpdatedUser.FsConfig.HTTPConfig.EqualityCheckMode) + + req, err = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + assert.NoError(t, err) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestWebUserAzureBlobMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider + user.FsConfig.AzBlobConfig.Container = "container" + user.FsConfig.AzBlobConfig.AccountName = "aname" + user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("access-skey") + user.FsConfig.AzBlobConfig.Endpoint = "http://127.0.0.1:9000/path?b=c" + user.FsConfig.AzBlobConfig.KeyPrefix = "somedir/subdir/" + user.FsConfig.AzBlobConfig.UploadPartSize = 5 + user.FsConfig.AzBlobConfig.UploadConcurrency = 4 + user.FsConfig.AzBlobConfig.DownloadPartSize = 3 + user.FsConfig.AzBlobConfig.DownloadConcurrency = 6 + user.FsConfig.AzBlobConfig.UseEmulator = true + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user.Username) + form.Set("password", redactedSecret) + form.Set("home_dir", user.HomeDir) + form.Set("uid", "0") + form.Set("gid", strconv.FormatInt(int64(user.GID), 10)) + form.Set("max_sessions", 
strconv.FormatInt(int64(user.MaxSessions), 10)) + form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10)) + form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) + form.Set("upload_bandwidth", "0") + form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") + form.Set("external_auth_cache_time", "0") + form.Set("permissions", "*") + form.Set("status", strconv.Itoa(user.Status)) + form.Set("expiration_date", "2020-01-01 00:00:00") + form.Set("allowed_ip", "") + form.Set("denied_ip", "") + form.Set("fs_provider", "3") + form.Set("az_container", user.FsConfig.AzBlobConfig.Container) + form.Set("az_account_name", user.FsConfig.AzBlobConfig.AccountName) + form.Set("az_account_key", user.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + form.Set("az_endpoint", user.FsConfig.AzBlobConfig.Endpoint) + form.Set("az_key_prefix", user.FsConfig.AzBlobConfig.KeyPrefix) + form.Set("az_use_emulator", "checked") + form.Set("directory_patterns[0][pattern_path]", "/dir1") + form.Set("directory_patterns[0][patterns]", "*.jpg,*.png") + form.Set("directory_patterns[0][pattern_type]", "allowed") + form.Set("directory_patterns[1][pattern_path]", "/dir2") + form.Set("directory_patterns[1][patterns]", "*.zip") + form.Set("directory_patterns[1][pattern_type]", "denied") + form.Set("max_upload_file_size", "0") + form.Set("default_shares_expiration", "0") + form.Set("max_shares_expiration", "0") + form.Set("password_expiration", "0") + form.Set("password_strength", "0") + // test invalid az_upload_part_size + form.Set("az_upload_part_size", "a") + b, contentType, _ := getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // test invalid az_upload_concurrency 
+ form.Set("az_upload_part_size", strconv.FormatInt(user.FsConfig.AzBlobConfig.UploadPartSize, 10)) + form.Set("az_upload_concurrency", "a") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // test invalid az_download_part_size + form.Set("az_upload_concurrency", strconv.Itoa(user.FsConfig.AzBlobConfig.UploadConcurrency)) + form.Set("az_download_part_size", "a") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // test invalid az_download_concurrency + form.Set("az_download_part_size", strconv.FormatInt(user.FsConfig.AzBlobConfig.DownloadPartSize, 10)) + form.Set("az_download_concurrency", "a") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // now add the user + form.Set("az_download_concurrency", strconv.Itoa(user.FsConfig.AzBlobConfig.DownloadConcurrency)) + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, 
rr) + var updateUser dataprovider.User + err = render.DecodeJSON(rr.Body, &updateUser) + assert.NoError(t, err) + assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.Container, user.FsConfig.AzBlobConfig.Container) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.AccountName, user.FsConfig.AzBlobConfig.AccountName) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.Endpoint, user.FsConfig.AzBlobConfig.Endpoint) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.KeyPrefix, user.FsConfig.AzBlobConfig.KeyPrefix) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.UploadPartSize, user.FsConfig.AzBlobConfig.UploadPartSize) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.UploadConcurrency, user.FsConfig.AzBlobConfig.UploadConcurrency) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.DownloadPartSize, user.FsConfig.AzBlobConfig.DownloadPartSize) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.DownloadConcurrency, user.FsConfig.AzBlobConfig.DownloadConcurrency) + assert.Equal(t, 2, len(updateUser.Filters.FilePatterns)) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.AzBlobConfig.AccountKey.GetStatus()) + assert.NotEmpty(t, updateUser.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + assert.Empty(t, updateUser.FsConfig.AzBlobConfig.AccountKey.GetKey()) + assert.Empty(t, updateUser.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) + // now check that a redacted password is not saved + form.Set("az_account_key", redactedSecret+" ") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusOK, rr) + var lastUpdatedUser dataprovider.User + err = render.DecodeJSON(rr.Body, &lastUpdatedUser) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.AzBlobConfig.AccountKey.GetStatus()) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.AccountKey.GetPayload(), lastUpdatedUser.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + assert.Empty(t, lastUpdatedUser.FsConfig.AzBlobConfig.AccountKey.GetKey()) + assert.Empty(t, lastUpdatedUser.FsConfig.AzBlobConfig.AccountKey.GetAdditionalData()) + // test SAS url + user.FsConfig.AzBlobConfig.SASURL = kms.NewPlainSecret("sasurl") + form.Set("az_account_name", "") + form.Set("az_account_key", "") + form.Set("az_container", "") + form.Set("az_sas_url", user.FsConfig.AzBlobConfig.SASURL.GetPayload()) + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + updateUser = dataprovider.User{} + err = render.DecodeJSON(rr.Body, &updateUser) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.AzBlobConfig.SASURL.GetStatus()) + assert.NotEmpty(t, updateUser.FsConfig.AzBlobConfig.SASURL.GetPayload()) + assert.Empty(t, updateUser.FsConfig.AzBlobConfig.SASURL.GetKey()) + assert.Empty(t, updateUser.FsConfig.AzBlobConfig.SASURL.GetAdditionalData()) + // now check that a redacted sas url is not saved + form.Set("az_sas_url", redactedSecret) + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + 
setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + lastUpdatedUser = dataprovider.User{} + err = render.DecodeJSON(rr.Body, &lastUpdatedUser) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.AzBlobConfig.SASURL.GetStatus()) + assert.Equal(t, updateUser.FsConfig.AzBlobConfig.SASURL.GetPayload(), lastUpdatedUser.FsConfig.AzBlobConfig.SASURL.GetPayload()) + assert.Empty(t, lastUpdatedUser.FsConfig.AzBlobConfig.SASURL.GetKey()) + assert.Empty(t, lastUpdatedUser.FsConfig.AzBlobConfig.SASURL.GetAdditionalData()) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestWebUserCryptMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.CryptedFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("crypted passphrase") + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user.Username) + 
form.Set("password", redactedSecret) + form.Set("home_dir", user.HomeDir) + form.Set("uid", "0") + form.Set("gid", strconv.FormatInt(int64(user.GID), 10)) + form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10)) + form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10)) + form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) + form.Set("upload_bandwidth", "0") + form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") + form.Set("external_auth_cache_time", "0") + form.Set("permissions", "*") + form.Set("status", strconv.Itoa(user.Status)) + form.Set("expiration_date", "2020-01-01 00:00:00") + form.Set("allowed_ip", "") + form.Set("denied_ip", "") + form.Set("fs_provider", "4") + form.Set("crypt_passphrase", "") + form.Set("cryptfs_read_buffer_size", "1") + form.Set("cryptfs_write_buffer_size", "2") + form.Set("directory_patterns[0][pattern_path]", "/dir1") + form.Set("directory_patterns[0][patterns]", "*.jpg,*.png") + form.Set("directory_patterns[0][pattern_type]", "allowed") + form.Set("directory_patterns[1][pattern_path]", "/dir2") + form.Set("directory_patterns[1][patterns]", "*.zip") + form.Set("directory_patterns[1][pattern_type]", "denied") + form.Set("max_upload_file_size", "0") + form.Set("default_shares_expiration", "0") + form.Set("max_shares_expiration", "0") + form.Set("password_expiration", "0") + form.Set("password_strength", "0") + // passphrase cannot be empty + b, contentType, _ := getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + form.Set("crypt_passphrase", user.FsConfig.CryptConfig.Passphrase.GetPayload()) + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = 
http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var updateUser dataprovider.User + err = render.DecodeJSON(rr.Body, &updateUser) + assert.NoError(t, err) + assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate) + assert.Equal(t, 2, len(updateUser.Filters.FilePatterns)) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.NotEmpty(t, updateUser.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, updateUser.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.Empty(t, updateUser.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + assert.Equal(t, 1, updateUser.FsConfig.CryptConfig.ReadBufferSize) + assert.Equal(t, 2, updateUser.FsConfig.CryptConfig.WriteBufferSize) + // now check that a redacted password is not saved + form.Set("crypt_passphrase", redactedSecret+" ") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var lastUpdatedUser dataprovider.User + err = render.DecodeJSON(rr.Body, &lastUpdatedUser) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.CryptConfig.Passphrase.GetStatus()) + assert.Equal(t, 
updateUser.FsConfig.CryptConfig.Passphrase.GetPayload(), lastUpdatedUser.FsConfig.CryptConfig.Passphrase.GetPayload()) + assert.Empty(t, lastUpdatedUser.FsConfig.CryptConfig.Passphrase.GetKey()) + assert.Empty(t, lastUpdatedUser.FsConfig.CryptConfig.Passphrase.GetAdditionalData()) + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestWebUserSFTPFsMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) + assert.NoError(t, err) + user := getTestUser() + userAsJSON := getUserAsJSON(t, user) + req, _ := http.NewRequest(http.MethodPost, userPath, bytes.NewBuffer(userAsJSON)) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + err = render.DecodeJSON(rr.Body, &user) + assert.NoError(t, err) + user.FsConfig.Provider = sdk.SFTPFilesystemProvider + user.FsConfig.SFTPConfig.Endpoint = "127.0.0.1:22" + user.FsConfig.SFTPConfig.Username = "sftpuser" + user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("pwd") + user.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(testPrivateKeyPwd) + user.FsConfig.SFTPConfig.KeyPassphrase = kms.NewPlainSecret(privateKeyPwd) + user.FsConfig.SFTPConfig.Fingerprints = []string{sftpPkeyFingerprint} + user.FsConfig.SFTPConfig.Prefix = "/home/sftpuser" + user.FsConfig.SFTPConfig.DisableCouncurrentReads = true + user.FsConfig.SFTPConfig.BufferSize = 5 + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user.Username) + form.Set("password", redactedSecret) + form.Set("home_dir", user.HomeDir) + form.Set("uid", "0") + form.Set("gid", 
strconv.FormatInt(int64(user.GID), 10)) + form.Set("max_sessions", strconv.FormatInt(int64(user.MaxSessions), 10)) + form.Set("quota_size", strconv.FormatInt(user.QuotaSize, 10)) + form.Set("quota_files", strconv.FormatInt(int64(user.QuotaFiles), 10)) + form.Set("upload_bandwidth", "0") + form.Set("download_bandwidth", "0") + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") + form.Set("external_auth_cache_time", "0") + form.Set("permissions", "*") + form.Set("status", strconv.Itoa(user.Status)) + form.Set("expiration_date", "2020-01-01 00:00:00") + form.Set("allowed_ip", "") + form.Set("denied_ip", "") + form.Set("fs_provider", "5") + form.Set("crypt_passphrase", "") + form.Set("directory_patterns[0][pattern_path]", "/dir1") + form.Set("directory_patterns[0][patterns]", "*.jpg,*.png") + form.Set("directory_patterns[0][pattern_type]", "allowed") + form.Set("directory_patterns[1][pattern_path]", "/dir2") + form.Set("directory_patterns[1][patterns]", "*.zip") + form.Set("directory_patterns[1][pattern_type]", "denied") + form.Set("max_upload_file_size", "0") + form.Set("default_shares_expiration", "0") + form.Set("max_shares_expiration", "0") + form.Set("password_expiration", "0") + form.Set("password_strength", "0") + // empty sftpconfig + b, contentType, _ := getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + form.Set("sftp_endpoint", user.FsConfig.SFTPConfig.Endpoint) + form.Set("sftp_username", user.FsConfig.SFTPConfig.Username) + form.Set("sftp_password", user.FsConfig.SFTPConfig.Password.GetPayload()) + form.Set("sftp_private_key", user.FsConfig.SFTPConfig.PrivateKey.GetPayload()) + form.Set("sftp_key_passphrase", user.FsConfig.SFTPConfig.KeyPassphrase.GetPayload()) + 
form.Set("sftp_fingerprints", user.FsConfig.SFTPConfig.Fingerprints[0]) + form.Set("sftp_prefix", user.FsConfig.SFTPConfig.Prefix) + form.Set("sftp_disable_concurrent_reads", "true") + form.Set("sftp_equality_check_mode", "true") + form.Set("sftp_buffer_size", strconv.FormatInt(user.FsConfig.SFTPConfig.BufferSize, 10)) + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var updateUser dataprovider.User + err = render.DecodeJSON(rr.Body, &updateUser) + assert.NoError(t, err) + assert.Equal(t, int64(1577836800000), updateUser.ExpirationDate) + assert.Equal(t, 2, len(updateUser.Filters.FilePatterns)) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.SFTPConfig.Password.GetStatus()) + assert.NotEmpty(t, updateUser.FsConfig.SFTPConfig.Password.GetPayload()) + assert.Empty(t, updateUser.FsConfig.SFTPConfig.Password.GetKey()) + assert.Empty(t, updateUser.FsConfig.SFTPConfig.Password.GetAdditionalData()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.SFTPConfig.PrivateKey.GetStatus()) + assert.NotEmpty(t, updateUser.FsConfig.SFTPConfig.PrivateKey.GetPayload()) + assert.Empty(t, updateUser.FsConfig.SFTPConfig.PrivateKey.GetKey()) + assert.Empty(t, updateUser.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetStatus()) + assert.NotEmpty(t, updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetPayload()) + assert.Empty(t, updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetKey()) + assert.Empty(t, 
updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetAdditionalData()) + assert.Equal(t, updateUser.FsConfig.SFTPConfig.Prefix, user.FsConfig.SFTPConfig.Prefix) + assert.Equal(t, updateUser.FsConfig.SFTPConfig.Username, user.FsConfig.SFTPConfig.Username) + assert.Equal(t, updateUser.FsConfig.SFTPConfig.Endpoint, user.FsConfig.SFTPConfig.Endpoint) + assert.True(t, updateUser.FsConfig.SFTPConfig.DisableCouncurrentReads) + assert.Len(t, updateUser.FsConfig.SFTPConfig.Fingerprints, 1) + assert.Equal(t, user.FsConfig.SFTPConfig.BufferSize, updateUser.FsConfig.SFTPConfig.BufferSize) + assert.Contains(t, updateUser.FsConfig.SFTPConfig.Fingerprints, sftpPkeyFingerprint) + assert.Equal(t, 1, updateUser.FsConfig.SFTPConfig.EqualityCheckMode) + // now check that a redacted credentials are not saved + form.Set("sftp_password", redactedSecret+" ") + form.Set("sftp_private_key", redactedSecret) + form.Set("sftp_key_passphrase", redactedSecret) + form.Set("sftp_equality_check_mode", "") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var lastUpdatedUser dataprovider.User + err = render.DecodeJSON(rr.Body, &lastUpdatedUser) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.SFTPConfig.Password.GetStatus()) + assert.Equal(t, updateUser.FsConfig.SFTPConfig.Password.GetPayload(), lastUpdatedUser.FsConfig.SFTPConfig.Password.GetPayload()) + assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.Password.GetKey()) + assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.Password.GetAdditionalData()) + assert.Equal(t, 
sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.SFTPConfig.PrivateKey.GetStatus()) + assert.Equal(t, updateUser.FsConfig.SFTPConfig.PrivateKey.GetPayload(), lastUpdatedUser.FsConfig.SFTPConfig.PrivateKey.GetPayload()) + assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.PrivateKey.GetKey()) + assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.PrivateKey.GetAdditionalData()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, lastUpdatedUser.FsConfig.SFTPConfig.KeyPassphrase.GetStatus()) + assert.Equal(t, updateUser.FsConfig.SFTPConfig.KeyPassphrase.GetPayload(), lastUpdatedUser.FsConfig.SFTPConfig.KeyPassphrase.GetPayload()) + assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.KeyPassphrase.GetKey()) + assert.Empty(t, lastUpdatedUser.FsConfig.SFTPConfig.KeyPassphrase.GetAdditionalData()) + assert.Equal(t, 0, lastUpdatedUser.FsConfig.SFTPConfig.EqualityCheckMode) + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, user.Username), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestWebUserRole(t *testing.T) { + role, resp, err := httpdtest.AddRole(getTestRole(), http.StatusCreated) + assert.NoError(t, err, string(resp)) + a := getTestAdmin() + a.Username = altAdminUsername + a.Password = altAdminPassword + a.Role = role.Name + a.Permissions = []string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers, + dataprovider.PermAdminDeleteUsers, dataprovider.PermAdminViewUsers} + admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + webToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) + assert.NoError(t, err) + user := getTestUser() + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user.Username) + form.Set("home_dir", user.HomeDir) + form.Set("password", user.Password) + 
form.Set("status", strconv.Itoa(user.Status)) + form.Set("permissions", "*") + form.Set("external_auth_cache_time", "0") + form.Set("uid", "0") + form.Set("gid", "0") + form.Set("max_sessions", "0") + form.Set("quota_size", "0") + form.Set("quota_files", "0") + form.Set("upload_bandwidth", strconv.FormatInt(user.UploadBandwidth, 10)) + form.Set("download_bandwidth", strconv.FormatInt(user.DownloadBandwidth, 10)) + form.Set("upload_data_transfer", strconv.FormatInt(user.UploadDataTransfer, 10)) + form.Set("download_data_transfer", strconv.FormatInt(user.DownloadDataTransfer, 10)) + form.Set("total_data_transfer", strconv.FormatInt(user.TotalDataTransfer, 10)) + form.Set("max_upload_file_size", "0") + form.Set("default_shares_expiration", "10") + form.Set("max_shares_expiration", "0") + form.Set("password_expiration", "0") + form.Set("password_strength", "0") + b, contentType, _ := getMultipartFormData(form, "", "") + req, err := http.NewRequest(http.MethodPost, webUserPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr := executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, role.Name, user.Role) + + form.Set("role", "") + b, contentType, _ = getMultipartFormData(form, "", "") + req, _ = http.NewRequest(http.MethodPost, path.Join(webUserPath, user.Username), &b) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, role.Name, user.Role) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveRole(role, http.StatusOK) + 
assert.NoError(t, err) +} + +func TestWebEventAction(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminEventActionPath, webToken) + assert.NoError(t, err) + action := dataprovider.BaseEventAction{ + Name: "web_action_http", + Description: "http web action", + Type: dataprovider.ActionTypeHTTP, + Options: dataprovider.BaseEventActionOptions{ + HTTPConfig: dataprovider.EventActionHTTPConfig{ + Endpoint: "https://localhost:4567/action", + Username: defaultUsername, + Headers: []dataprovider.KeyValue{ + { + Key: "Content-Type", + Value: "application/json", + }, + }, + Password: kms.NewPlainSecret(defaultPassword), + Timeout: 10, + SkipTLSVerify: true, + Method: http.MethodPost, + QueryParameters: []dataprovider.KeyValue{ + { + Key: "param1", + Value: "value1", + }, + }, + Body: `{"event":"{{.Event}}","name":"{{.Name}}"}`, + }, + }, + } + form := make(url.Values) + form.Set("name", action.Name) + form.Set("description", action.Description) + form.Set("fs_action_type", "0") + form.Set("type", "a") + req, err := http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("type", fmt.Sprintf("%d", action.Type)) + form.Set("http_timeout", "b") + req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("cmd_timeout", "20") + form.Set("pwd_expiration_threshold", "10") + form.Set("http_timeout", fmt.Sprintf("%d", action.Options.HTTPConfig.Timeout)) + form.Set("http_headers[0][http_header_key]", action.Options.HTTPConfig.Headers[0].Key) + form.Set("http_headers[0][http_header_value]", action.Options.HTTPConfig.Headers[0].Value) + form.Set("http_headers[1][http_header_key]", action.Options.HTTPConfig.Headers[0].Key) // ignored + form.Set("query_parameters[0][http_query_key]", action.Options.HTTPConfig.QueryParameters[0].Key) + form.Set("query_parameters[0][http_query_value]", action.Options.HTTPConfig.QueryParameters[0].Value) + form.Set("http_body", action.Options.HTTPConfig.Body) + form.Set("http_skip_tls_verify", "1") + form.Set("http_username", action.Options.HTTPConfig.Username) + form.Set("http_password", action.Options.HTTPConfig.Password.GetPayload()) + form.Set("http_method", action.Options.HTTPConfig.Method) + req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorURLRequired) + form.Set("http_endpoint", action.Options.HTTPConfig.Endpoint) + req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, 
bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // a new add will fail + req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // list actions + req, err = http.NewRequest(http.MethodGet, webAdminEventActionsPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, err = http.NewRequest(http.MethodGet, webAdminEventActionsPath+jsonAPISuffix, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // render add page + req, err = http.NewRequest(http.MethodGet, webAdminEventActionPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // render action page + req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventActionPath, action.Name), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // missing action + req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventActionPath, action.Name+"1"), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // check the action + actionGet, _, err := httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + assert.Equal(t, action.Description, 
actionGet.Description) + assert.Equal(t, action.Options.HTTPConfig.Body, actionGet.Options.HTTPConfig.Body) + assert.Equal(t, action.Options.HTTPConfig.Endpoint, actionGet.Options.HTTPConfig.Endpoint) + assert.Equal(t, action.Options.HTTPConfig.Headers, actionGet.Options.HTTPConfig.Headers) + assert.Equal(t, action.Options.HTTPConfig.Method, actionGet.Options.HTTPConfig.Method) + assert.Equal(t, action.Options.HTTPConfig.SkipTLSVerify, actionGet.Options.HTTPConfig.SkipTLSVerify) + assert.Equal(t, action.Options.HTTPConfig.Timeout, actionGet.Options.HTTPConfig.Timeout) + assert.Equal(t, action.Options.HTTPConfig.Username, actionGet.Options.HTTPConfig.Username) + assert.Equal(t, sdkkms.SecretStatusSecretBox, actionGet.Options.HTTPConfig.Password.GetStatus()) + assert.NotEmpty(t, actionGet.Options.HTTPConfig.Password.GetPayload()) + assert.Empty(t, actionGet.Options.HTTPConfig.Password.GetKey()) + assert.Empty(t, actionGet.Options.HTTPConfig.Password.GetAdditionalData()) + // update and check that the password is preserved and the multipart fields + form.Set("http_password", redactedSecret) + form.Set("http_body", "") + form.Set("http_timeout", "0") + form.Del("http_headers[0][http_header_key]") + form.Del("http_headers[0][http_header_val]") + form.Set("multipart_body[0][http_part_name]", "part1") + form.Set("multipart_body[0][http_part_file]", "{{.VirtualPath}}") + form.Set("multipart_body[0][http_part_body]", "") + form.Set("multipart_body[0][http_part_headers]", "X-MyHeader: a:b,c") + form.Set("multipart_body[12][http_part_name]", "part2") + form.Set("multipart_body[12][http_part_headers]", "Content-Type:application/json \r\n") + form.Set("multipart_body[12][http_part_body]", "{{.ObjectData}}") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = 
executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + dbAction, err := dataprovider.EventActionExists(action.Name) + assert.NoError(t, err) + err = dbAction.Options.HTTPConfig.Password.Decrypt() + assert.NoError(t, err) + assert.Equal(t, defaultPassword, dbAction.Options.HTTPConfig.Password.GetPayload()) + assert.Empty(t, dbAction.Options.HTTPConfig.Body) + assert.Equal(t, 0, dbAction.Options.HTTPConfig.Timeout) + if assert.Len(t, dbAction.Options.HTTPConfig.Parts, 2) { + assert.Equal(t, "part1", dbAction.Options.HTTPConfig.Parts[0].Name) + assert.Equal(t, "/{{.VirtualPath}}", dbAction.Options.HTTPConfig.Parts[0].Filepath) + assert.Empty(t, dbAction.Options.HTTPConfig.Parts[0].Body) + assert.Equal(t, "X-MyHeader", dbAction.Options.HTTPConfig.Parts[0].Headers[0].Key) + assert.Equal(t, "a:b,c", dbAction.Options.HTTPConfig.Parts[0].Headers[0].Value) + assert.Equal(t, "part2", dbAction.Options.HTTPConfig.Parts[1].Name) + assert.Equal(t, "{{.ObjectData}}", dbAction.Options.HTTPConfig.Parts[1].Body) + assert.Empty(t, dbAction.Options.HTTPConfig.Parts[1].Filepath) + assert.Equal(t, "Content-Type", dbAction.Options.HTTPConfig.Parts[1].Headers[0].Key) + assert.Equal(t, "application/json", dbAction.Options.HTTPConfig.Parts[1].Headers[0].Value) + } + // change action type + action.Type = dataprovider.ActionTypeCommand + action.Options.CmdConfig = dataprovider.EventActionCommandConfig{ + Cmd: filepath.Join(os.TempDir(), "cmd"), + Args: []string{"arg1", "arg2"}, + Timeout: 20, + EnvVars: []dataprovider.KeyValue{ + { + Key: "key", + Value: "val", + }, + }, + } + dataprovider.EnabledActionCommands = []string{action.Options.CmdConfig.Cmd} + defer func() { + dataprovider.EnabledActionCommands = nil + }() + form.Set("type", fmt.Sprintf("%d", action.Type)) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorCommandRequired) + form.Set("cmd_path", action.Options.CmdConfig.Cmd) + form.Set("cmd_timeout", "a") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("cmd_timeout", fmt.Sprintf("%d", action.Options.CmdConfig.Timeout)) + form.Set("env_vars[0][cmd_env_key]", action.Options.CmdConfig.EnvVars[0].Key) + form.Set("env_vars[0][cmd_env_value]", action.Options.CmdConfig.EnvVars[0].Value) + form.Set("cmd_arguments", "arg1 ,arg2 ") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // update a missing action + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name+"1"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // update with no csrf token + form.Del(csrfFormToken) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + 
setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + form.Set(csrfFormToken, csrfToken) + // check the update + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + assert.Equal(t, action.Options.CmdConfig.Cmd, actionGet.Options.CmdConfig.Cmd) + assert.Equal(t, action.Options.CmdConfig.Args, actionGet.Options.CmdConfig.Args) + assert.Equal(t, action.Options.CmdConfig.Timeout, actionGet.Options.CmdConfig.Timeout) + assert.Equal(t, action.Options.CmdConfig.EnvVars, actionGet.Options.CmdConfig.EnvVars) + assert.Equal(t, dataprovider.EventActionHTTPConfig{}, actionGet.Options.HTTPConfig) + assert.Equal(t, dataprovider.EventActionPasswordExpiration{}, actionGet.Options.PwdExpirationConfig) + // change action type again + action.Type = dataprovider.ActionTypeEmail + action.Options.EmailConfig = dataprovider.EventActionEmailConfig{ + Recipients: []string{"address1@example.com", "address2@example.com"}, + Bcc: []string{"address3@example.com"}, + Subject: "subject", + ContentType: 1, + Body: "body", + Attachments: []string{"/file1.txt", "/file2.txt"}, + } + form.Set("type", fmt.Sprintf("%d", action.Type)) + form.Set("email_recipients", "address1@example.com, address2@example.com") + form.Set("email_bcc", "address3@example.com") + form.Set("email_subject", action.Options.EmailConfig.Subject) + form.Set("email_content_type", fmt.Sprintf("%d", action.Options.EmailConfig.ContentType)) + form.Set("email_body", action.Options.EmailConfig.Body) + form.Set("email_attachments", "file1.txt, file2.txt") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, 
webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the update + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + assert.Equal(t, action.Options.EmailConfig.Recipients, actionGet.Options.EmailConfig.Recipients) + assert.Equal(t, action.Options.EmailConfig.Bcc, actionGet.Options.EmailConfig.Bcc) + assert.Equal(t, action.Options.EmailConfig.Subject, actionGet.Options.EmailConfig.Subject) + assert.Equal(t, action.Options.EmailConfig.ContentType, actionGet.Options.EmailConfig.ContentType) + assert.Equal(t, action.Options.EmailConfig.Body, actionGet.Options.EmailConfig.Body) + assert.Equal(t, action.Options.EmailConfig.Attachments, actionGet.Options.EmailConfig.Attachments) + assert.Equal(t, dataprovider.EventActionHTTPConfig{}, actionGet.Options.HTTPConfig) + assert.Empty(t, actionGet.Options.CmdConfig.Cmd) + assert.Equal(t, 0, actionGet.Options.CmdConfig.Timeout) + assert.Len(t, actionGet.Options.CmdConfig.EnvVars, 0) + // change action type to data retention check + action.Type = dataprovider.ActionTypeDataRetentionCheck + form.Set("type", fmt.Sprintf("%d", action.Type)) + form.Set("data_retention[10][folder_retention_path]", "p1") + form.Set("data_retention[10][folder_retention_val]", "a") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("data_retention[10][folder_retention_val]", "24") + form.Set("data_retention[10][folder_retention_options][]", "1") + form.Set("data_retention[11][folder_retention_path]", "../p2") + 
form.Set("data_retention[11][folder_retention_val]", "48") + form.Set("data_retention[11][folder_retention_options][]", "1") + form.Set("data_retention[13][folder_retention_options][]", "1") // ignored + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the update + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + if assert.Len(t, actionGet.Options.RetentionConfig.Folders, 2) { + for _, folder := range actionGet.Options.RetentionConfig.Folders { + switch folder.Path { + case "/p1": + assert.Equal(t, 24, folder.Retention) + assert.True(t, folder.DeleteEmptyDirs) + case "/p2": + assert.Equal(t, 48, folder.Retention) + assert.True(t, folder.DeleteEmptyDirs) + default: + t.Errorf("unexpected folder path %v", folder.Path) + } + } + } + action.Type = dataprovider.ActionTypeFilesystem + action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionMkdirs, + MkDirs: []string{"a ", " a/b"}, + } + form.Set("type", fmt.Sprintf("%d", action.Type)) + form.Set("fs_mkdir_paths", strings.Join(action.Options.FsConfig.MkDirs, ",")) + form.Set("fs_action_type", "invalid") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + + form.Set("fs_action_type", fmt.Sprintf("%d", 
action.Options.FsConfig.Type)) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the update + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + if assert.Len(t, actionGet.Options.FsConfig.MkDirs, 2) { + for _, dir := range actionGet.Options.FsConfig.MkDirs { + switch dir { + case "/a": + case "/a/b": + default: + t.Errorf("unexpected dir path %v", dir) + } + } + } + + action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionExist, + Exist: []string{"b ", " c/d"}, + } + form.Set("fs_action_type", fmt.Sprintf("%d", action.Options.FsConfig.Type)) + form.Set("fs_exist_paths", strings.Join(action.Options.FsConfig.Exist, ",")) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the update + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + if assert.Len(t, actionGet.Options.FsConfig.Exist, 2) { + for _, p := range actionGet.Options.FsConfig.Exist { + switch p { + case "/b": + case "/c/d": + default: + t.Errorf("unexpected path %v", p) + } + } + } + + action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionRename, + Renames: []dataprovider.RenameConfig{ + { + KeyValue: 
dataprovider.KeyValue{ + Key: "/src", + Value: "/target", + }, + }, + }, + } + form.Set("fs_action_type", fmt.Sprintf("%d", action.Options.FsConfig.Type)) + form.Set("fs_rename[0][fs_rename_source]", action.Options.FsConfig.Renames[0].Key) + form.Set("fs_rename[0][fs_rename_target]", action.Options.FsConfig.Renames[0].Value) + form.Set("fs_rename[0][fs_rename_options][]", "1") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the update + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + if assert.Len(t, actionGet.Options.FsConfig.Renames, 1) { + assert.True(t, actionGet.Options.FsConfig.Renames[0].UpdateModTime) + } + + action.Options.FsConfig = dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionCopy, + Copy: []dataprovider.KeyValue{ + { + Key: "/copy_src", + Value: "/copy_target", + }, + }, + } + form.Set("fs_action_type", fmt.Sprintf("%d", action.Options.FsConfig.Type)) + form.Set("fs_copy[0][fs_copy_source]", action.Options.FsConfig.Copy[0].Key) + form.Set("fs_copy[0][fs_copy_target]", action.Options.FsConfig.Copy[0].Value) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the update + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + 
assert.Len(t, actionGet.Options.FsConfig.Copy, 1) + + action.Type = dataprovider.ActionTypePasswordExpirationCheck + action.Options.PwdExpirationConfig.Threshold = 15 + form.Set("type", fmt.Sprintf("%d", action.Type)) + form.Set("pwd_expiration_threshold", "a") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("pwd_expiration_threshold", strconv.Itoa(action.Options.PwdExpirationConfig.Threshold)) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + assert.Equal(t, action.Options.PwdExpirationConfig.Threshold, actionGet.Options.PwdExpirationConfig.Threshold) + assert.Equal(t, 0, actionGet.Options.CmdConfig.Timeout) + assert.Len(t, actionGet.Options.CmdConfig.EnvVars, 0) + + action.Type = dataprovider.ActionTypeUserInactivityCheck + action.Options.UserInactivityConfig = dataprovider.EventActionUserInactivity{ + DisableThreshold: 10, + DeleteThreshold: 15, + } + form.Set("type", fmt.Sprintf("%d", action.Type)) + form.Set("inactivity_disable_threshold", strconv.Itoa(action.Options.UserInactivityConfig.DisableThreshold)) + form.Set("inactivity_delete_threshold", strconv.Itoa(action.Options.UserInactivityConfig.DeleteThreshold)) + req, err = http.NewRequest(http.MethodPost, 
path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + assert.Equal(t, 0, actionGet.Options.PwdExpirationConfig.Threshold) + assert.Equal(t, action.Options.UserInactivityConfig.DisableThreshold, actionGet.Options.UserInactivityConfig.DisableThreshold) + assert.Equal(t, action.Options.UserInactivityConfig.DeleteThreshold, actionGet.Options.UserInactivityConfig.DeleteThreshold) + + action.Type = dataprovider.ActionTypeIDPAccountCheck + form.Set("type", fmt.Sprintf("%d", action.Type)) + form.Set("idp_mode", "1") + form.Set("idp_user", `{"username":"user"}`) + form.Set("idp_admin", `{"username":"admin"}`) + form.Set("pwd_expiration_threshold", strconv.Itoa(action.Options.PwdExpirationConfig.Threshold)) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventActionPath, action.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + actionGet, _, err = httpdtest.GetEventActionByName(action.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, action.Type, actionGet.Type) + assert.Equal(t, 1, actionGet.Options.IDPConfig.Mode) + assert.Contains(t, actionGet.Options.IDPConfig.TemplateUser, `"user"`) + assert.Contains(t, actionGet.Options.IDPConfig.TemplateAdmin, `"admin"`) + + req, err = http.NewRequest(http.MethodDelete, path.Join(webAdminEventActionPath, action.Name), nil) + assert.NoError(t, err) + setBearerForReq(req, apiToken) + setCSRFHeaderForReq(req, 
csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) + + req, err = http.NewRequest(http.MethodDelete, path.Join(webAdminEventActionPath, action.Name), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webAdminEventActionsPath+jsonAPISuffix, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Equal(t, `[]`, rr.Body.String()) +} + +func TestWebEventRule(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminEventRulePath, webToken) + assert.NoError(t, err) + a := dataprovider.BaseEventAction{ + Name: "web_action", + Type: dataprovider.ActionTypeFilesystem, + Options: dataprovider.BaseEventActionOptions{ + FsConfig: dataprovider.EventActionFilesystemConfig{ + Type: dataprovider.FilesystemActionExist, + Exist: []string{"/dir1"}, + }, + }, + } + action, _, err := httpdtest.AddEventAction(a, http.StatusCreated) + assert.NoError(t, err) + rule := dataprovider.EventRule{ + Name: "test_web_rule", + Status: 1, + Description: "rule added using web API", + Trigger: dataprovider.EventTriggerSchedule, + Conditions: dataprovider.EventConditions{ + Schedules: []dataprovider.Schedule{ + { + Hours: "0", + DayOfWeek: "*", + DayOfMonth: "*", + Month: "*", + }, + }, + Options: dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "u*", + InverseMatch: true, + }, + }, + GroupNames: []dataprovider.ConditionPattern{ + { + Pattern: "g*", + InverseMatch: true, + }, + }, + RoleNames: []dataprovider.ConditionPattern{ + { + Pattern: "r*", + InverseMatch: true, + }, + }, + }, + 
}, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action.Name, + }, + Order: 1, + }, + }, + } + form := make(url.Values) + form.Set("name", rule.Name) + form.Set("description", rule.Description) + form.Set("status", "a") + req, err := http.NewRequest(http.MethodPost, webAdminEventRulePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("status", fmt.Sprintf("%d", rule.Status)) + form.Set("trigger", "a") + req, err = http.NewRequest(http.MethodPost, webAdminEventRulePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("trigger", fmt.Sprintf("%d", rule.Trigger)) + form.Set("schedules[0][schedule_hour]", rule.Conditions.Schedules[0].Hours) + form.Set("schedules[0][schedule_day_of_week]", rule.Conditions.Schedules[0].DayOfWeek) + form.Set("schedules[0][schedule_day_of_month]", rule.Conditions.Schedules[0].DayOfMonth) + form.Set("schedules[0][schedule_month]", rule.Conditions.Schedules[0].Month) + form.Set("name_filters[0][name_pattern]", rule.Conditions.Options.Names[0].Pattern) + form.Set("name_filters[0][type_name_pattern]", "inverse") + form.Set("group_name_filters[0][group_name_pattern]", rule.Conditions.Options.GroupNames[0].Pattern) + form.Set("group_name_filters[0][type_group_name_pattern]", "inverse") + form.Set("role_name_filters[0][role_name_pattern]", rule.Conditions.Options.RoleNames[0].Pattern) + form.Set("role_name_filters[0][type_role_name_pattern]", "inverse") + req, err = 
http.NewRequest(http.MethodPost, webAdminEventRulePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidMinSize) + form.Set("fs_min_size", "0") + req, err = http.NewRequest(http.MethodPost, webAdminEventRulePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidMaxSize) + form.Set("fs_max_size", "0") + form.Set("actions[0][action_name]", action.Name) + req, err = http.NewRequest(http.MethodPost, webAdminEventRulePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminEventRulePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // a new add will fail + req, err = http.NewRequest(http.MethodPost, webAdminEventRulePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // list rules + req, err = http.NewRequest(http.MethodGet, 
webAdminEventRulesPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, err = http.NewRequest(http.MethodGet, webAdminEventRulesPath+jsonAPISuffix, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // render add page + req, err = http.NewRequest(http.MethodGet, webAdminEventRulePath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // render rule page + req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventRulePath, rule.Name), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // missing rule + req, err = http.NewRequest(http.MethodGet, path.Join(webAdminEventRulePath, rule.Name+"1"), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // check the rule + ruleGet, _, err := httpdtest.GetEventRuleByName(rule.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, rule.Trigger, ruleGet.Trigger) + assert.Equal(t, rule.Status, ruleGet.Status) + assert.Equal(t, rule.Description, ruleGet.Description) + assert.Equal(t, rule.Conditions, ruleGet.Conditions) + if assert.Len(t, ruleGet.Actions, 1) { + assert.Equal(t, rule.Actions[0].Name, ruleGet.Actions[0].Name) + assert.Equal(t, rule.Actions[0].Order, ruleGet.Actions[0].Order) + } + // change rule trigger and status + rule.Status = 0 + rule.Trigger = dataprovider.EventTriggerFsEvent + rule.Conditions = dataprovider.EventConditions{ + FsEvents: []string{"upload", "download"}, + Options: dataprovider.ConditionOptions{ + Names: []dataprovider.ConditionPattern{ + { + Pattern: "u*", + InverseMatch: true, + }, + }, + GroupNames: []dataprovider.ConditionPattern{ + { + 
Pattern: "g*", + InverseMatch: true, + }, + }, + RoleNames: []dataprovider.ConditionPattern{ + { + Pattern: "r*", + InverseMatch: true, + }, + }, + FsPaths: []dataprovider.ConditionPattern{ + { + Pattern: "/subdir/*.txt", + }, + }, + Protocols: []string{common.ProtocolSFTP, common.ProtocolHTTP}, + MinFileSize: 1024 * 1024, + MaxFileSize: 5 * 1024 * 1024, + }, + } + form.Set("status", fmt.Sprintf("%d", rule.Status)) + form.Set("trigger", fmt.Sprintf("%d", rule.Trigger)) + for _, event := range rule.Conditions.FsEvents { + form.Add("fs_events", event) + } + form.Set("path_filters[0][fs_path_pattern]", rule.Conditions.Options.FsPaths[0].Pattern) + for _, protocol := range rule.Conditions.Options.Protocols { + form.Add("fs_protocols", protocol) + } + form.Set("fs_min_size", fmt.Sprintf("%d", rule.Conditions.Options.MinFileSize)) + form.Set("fs_max_size", fmt.Sprintf("%d", rule.Conditions.Options.MaxFileSize)) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventRulePath, rule.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the rule + ruleGet, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, rule.Status, ruleGet.Status) + assert.Equal(t, rule.Trigger, ruleGet.Trigger) + assert.Equal(t, rule.Description, ruleGet.Description) + assert.Equal(t, rule.Conditions, ruleGet.Conditions) + if assert.Len(t, ruleGet.Actions, 1) { + assert.Equal(t, rule.Actions[0].Name, ruleGet.Actions[0].Name) + assert.Equal(t, rule.Actions[0].Order, ruleGet.Actions[0].Order) + } + rule.Trigger = dataprovider.EventTriggerIDPLogin + form.Set("trigger", fmt.Sprintf("%d", rule.Trigger)) + form.Set("idp_login_event", "1") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventRulePath, rule.Name), + 
bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the rule + ruleGet, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, rule.Trigger, ruleGet.Trigger) + assert.Equal(t, 1, ruleGet.Conditions.IDPLoginEvent) + + form.Set("idp_login_event", "2") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventRulePath, rule.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the rule + ruleGet, _, err = httpdtest.GetEventRuleByName(rule.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, rule.Trigger, ruleGet.Trigger) + assert.Equal(t, 2, ruleGet.Conditions.IDPLoginEvent) + + // update a missing rule + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventRulePath, rule.Name+"1"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // update with no csrf token + form.Del(csrfFormToken) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventRulePath, rule.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + form.Set(csrfFormToken, csrfToken) + // update with no action defined + 
form.Del("actions[0][action_name]") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventRulePath, rule.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorRuleActionRequired) + // invalid trigger + form.Set("trigger", "a") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminEventRulePath, rule.Name), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + + req, err = http.NewRequest(http.MethodDelete, path.Join(webAdminEventRulePath, rule.Name), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodDelete, path.Join(webAdminEventActionPath, action.Name), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestWebIPListEntries(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, webToken) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webIPListPath+"/mode", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, webIPListPath+"/mode/a", nil) + 
assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webIPListPath, "/1/a"), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, webIPListPath+"/1", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webIPListsPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + entry := dataprovider.IPListEntry{ + IPOrNet: "12.34.56.78/20", + Type: dataprovider.IPListTypeDefender, + Mode: dataprovider.ListModeDeny, + Description: "note", + Protocols: 5, + } + form := make(url.Values) + form.Set("ipornet", entry.IPOrNet) + form.Set("description", entry.Description) + form.Set("mode", "a") + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/mode", bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), util.I18nError400Message) + + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1", bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/2", bytes.NewBuffer([]byte(form.Encode()))) + 
assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + + form.Set("mode", "2") + form.Set("protocols", "a") + form.Add("protocols", "1") + form.Add("protocols", "4") + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/"+strconv.Itoa(int(entry.Type)), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + entry1, _, err := httpdtest.GetIPListEntry(entry.IPOrNet, dataprovider.IPListTypeDefender, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, entry.Description, entry1.Description) + assert.Equal(t, entry.Mode, entry1.Mode) + assert.Equal(t, entry.Protocols, entry1.Protocols) + + form.Set("ipornet", "1111.11.11.11") + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1", bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorIPInvalid) + + form.Set("ipornet", entry.IPOrNet) + form.Set("mode", "invalid") // ignored for list type 1 + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1", bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + entry2, _, err := httpdtest.GetIPListEntry(entry.IPOrNet, dataprovider.IPListTypeAllowList, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, entry.Description, 
entry2.Description) + assert.Equal(t, dataprovider.ListModeAllow, entry2.Mode) + assert.Equal(t, entry.Protocols, entry2.Protocols) + + req, err = http.NewRequest(http.MethodGet, path.Join(webIPListPath, "1", url.PathEscape(entry2.IPOrNet)), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + form.Set("protocols", "1") + req, err = http.NewRequest(http.MethodPost, path.Join(webIPListPath, "1", url.PathEscape(entry.IPOrNet)), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + entry2, _, err = httpdtest.GetIPListEntry(entry.IPOrNet, dataprovider.IPListTypeAllowList, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, entry.Description, entry2.Description) + assert.Equal(t, dataprovider.ListModeAllow, entry2.Mode) + assert.Equal(t, 1, entry2.Protocols) + + form.Del(csrfFormToken) + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1/"+url.PathEscape(entry.IPOrNet), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/a/"+url.PathEscape(entry.IPOrNet), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1/"+url.PathEscape(entry.IPOrNet)+"a", + 
bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + form.Set("mode", "a") + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/2/"+url.PathEscape(entry.IPOrNet), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + + form.Set("mode", "100") + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/2/"+url.PathEscape(entry.IPOrNet), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + + _, err = httpdtest.RemoveIPListEntry(entry1, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveIPListEntry(entry2, http.StatusOK) + assert.NoError(t, err) +} + +func TestWebRole(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webAdminRolePath, webToken) + assert.NoError(t, err) + role := getTestRole() + form := make(url.Values) + form.Set("name", "") + form.Set("description", role.Description) + req, err := http.NewRequest(http.MethodPost, webAdminRolePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, 
rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminRolePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorNameRequired) + form.Set("name", role.Name) + req, err = http.NewRequest(http.MethodPost, webAdminRolePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // a new add will fail + req, err = http.NewRequest(http.MethodPost, webAdminRolePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // list roles + req, err = http.NewRequest(http.MethodGet, webAdminRolesPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, err = http.NewRequest(http.MethodGet, webAdminRolesPath+jsonAPISuffix, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // render the new role page + req, err = http.NewRequest(http.MethodGet, webAdminRolePath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, err = http.NewRequest(http.MethodGet, path.Join(webAdminRolePath, role.Name), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) 
+ req, err = http.NewRequest(http.MethodGet, path.Join(webAdminRolePath, "missing_role"), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + // parse form error + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name)+"?param=p%C4%AO%GH", + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + // update role + form.Set("description", "new desc") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name), bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check the changes + role, _, err = httpdtest.GetRoleByName(role.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, "new desc", role.Description) + // no CSRF token + form.Set(csrfFormToken, "") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name), bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + // missing role + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, "missing"), bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + setJWTCookieForReq(req, webToken) + rr = 
executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + _, err = httpdtest.RemoveRole(role, http.StatusOK) + assert.NoError(t, err) +} + +func TestAddWebGroup(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webGroupPath, webToken) + assert.NoError(t, err) + group := getTestGroup() + group.UserSettings = dataprovider.GroupUserSettings{ + BaseGroupUserSettings: sdk.BaseGroupUserSettings{ + HomeDir: filepath.Join(os.TempDir(), util.GenerateUniqueID()), + Permissions: make(map[string][]string), + MaxSessions: 2, + QuotaSize: 123, + QuotaFiles: 10, + UploadBandwidth: 128, + DownloadBandwidth: 256, + ExpiresIn: 10, + }, + } + form := make(url.Values) + form.Set("name", group.Name) + form.Set("description", group.Description) + form.Set("home_dir", group.UserSettings.HomeDir) + b, contentType, err := getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("max_sessions", strconv.FormatInt(int64(group.UserSettings.MaxSessions), 10)) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidQuotaSize) + form.Set("quota_files", 
strconv.FormatInt(int64(group.UserSettings.QuotaFiles), 10)) + form.Set("quota_size", strconv.FormatInt(group.UserSettings.QuotaSize, 10)) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("upload_bandwidth", strconv.FormatInt(group.UserSettings.UploadBandwidth, 10)) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("download_bandwidth", strconv.FormatInt(group.UserSettings.DownloadBandwidth, 10)) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) 
+ form.Set("expires_in", strconv.Itoa(group.UserSettings.ExpiresIn)) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidMaxFilesize) + form.Set("max_upload_file_size", "0") + form.Set("default_shares_expiration", "0") + form.Set("max_shares_expiration", "0") + form.Set("password_expiration", "0") + form.Set("password_strength", "0") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("external_auth_cache_time", "0") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + form.Set(csrfFormToken, csrfToken) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath+"?b=%2", &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) // error parsing the multipart form + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = 
http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // a new add will fail + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webGroupPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // list groups + req, err = http.NewRequest(http.MethodGet, webGroupsPath, nil) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, err = http.NewRequest(http.MethodGet, webGroupsPath+jsonAPISuffix, nil) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // render the new group page + req, err = http.NewRequest(http.MethodGet, path.Join(webGroupPath, group.Name), nil) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // check the added group + groupGet, _, err := httpdtest.GetGroupByName(group.Name, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, group.UserSettings, groupGet.UserSettings) + assert.Equal(t, group.Name, groupGet.Name) + assert.Equal(t, group.Description, groupGet.Description) + // cleanup + req, err = http.NewRequest(http.MethodDelete, path.Join(groupPath, group.Name), nil) + assert.NoError(t, err) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webGroupPath, group.Name), nil) + 
assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestAddWebFoldersMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) + assert.NoError(t, err) + mappedPath := filepath.Clean(os.TempDir()) + folderName := filepath.Base(mappedPath) + folderDesc := "a simple desc" + form := make(url.Values) + form.Set("mapped_path", mappedPath) + form.Set("name", folderName) + form.Set("description", folderDesc) + form.Set("osfs_read_buffer_size", "3") + form.Set("osfs_write_buffer_size", "4") + b, contentType, err := getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, webFolderPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webFolderPath, &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // adding the same folder will fail since the name must be unique + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webFolderPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + 
checkResponseCode(t, http.StatusOK, rr) + // invalid form + req, err = http.NewRequest(http.MethodPost, webFolderPath, strings.NewReader(form.Encode())) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", "text/plain; boundary=") + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + // now render the add folder page + req, err = http.NewRequest(http.MethodGet, webFolderPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + var folder vfs.BaseVirtualFolder + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folder) + assert.NoError(t, err) + assert.Equal(t, mappedPath, folder.MappedPath) + assert.Equal(t, folderName, folder.Name) + assert.Equal(t, folderDesc, folder.Description) + assert.Equal(t, 3, folder.FsConfig.OSConfig.ReadBufferSize) + assert.Equal(t, 4, folder.FsConfig.OSConfig.WriteBufferSize) + // cleanup + req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestHTTPFsWebFolderMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) + assert.NoError(t, err) + mappedPath := filepath.Clean(os.TempDir()) + folderName := filepath.Base(mappedPath) + httpfsConfig := vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: "https://127.0.0.1:9998/api/v1", + Username: folderName, + SkipTLSVerify: true, + }, + Password: 
kms.NewPlainSecret(defaultPassword), + APIKey: kms.NewPlainSecret(defaultTokenAuthPass), + } + form := make(url.Values) + form.Set("mapped_path", mappedPath) + form.Set("name", folderName) + form.Set("fs_provider", "6") + form.Set("http_endpoint", httpfsConfig.Endpoint) + form.Set("http_username", "%name%") + form.Set("http_password", httpfsConfig.Password.GetPayload()) + form.Set("http_api_key", httpfsConfig.APIKey.GetPayload()) + form.Set("http_skip_tls_verify", "checked") + form.Set(csrfFormToken, csrfToken) + b, contentType, err := getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, webFolderPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr := executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check + var folder vfs.BaseVirtualFolder + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folder) + assert.NoError(t, err) + assert.Equal(t, mappedPath, folder.MappedPath) + assert.Equal(t, folderName, folder.Name) + assert.Equal(t, sdk.HTTPFilesystemProvider, folder.FsConfig.Provider) + assert.Equal(t, httpfsConfig.Endpoint, folder.FsConfig.HTTPConfig.Endpoint) + assert.Equal(t, httpfsConfig.Username, folder.FsConfig.HTTPConfig.Username) + assert.Equal(t, httpfsConfig.SkipTLSVerify, folder.FsConfig.HTTPConfig.SkipTLSVerify) + assert.Equal(t, sdkkms.SecretStatusSecretBox, folder.FsConfig.HTTPConfig.Password.GetStatus()) + assert.NotEmpty(t, folder.FsConfig.HTTPConfig.Password.GetPayload()) + assert.Empty(t, folder.FsConfig.HTTPConfig.Password.GetKey()) + assert.Empty(t, folder.FsConfig.HTTPConfig.Password.GetAdditionalData()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, folder.FsConfig.HTTPConfig.APIKey.GetStatus()) + assert.NotEmpty(t, 
folder.FsConfig.HTTPConfig.APIKey.GetPayload()) + assert.Empty(t, folder.FsConfig.HTTPConfig.APIKey.GetKey()) + assert.Empty(t, folder.FsConfig.HTTPConfig.APIKey.GetAdditionalData()) + // update + form.Set("http_password", redactedSecret) + form.Set("http_api_key", redactedSecret) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + // check + var updateFolder vfs.BaseVirtualFolder + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &updateFolder) + assert.NoError(t, err) + assert.Equal(t, mappedPath, updateFolder.MappedPath) + assert.Equal(t, folderName, updateFolder.Name) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateFolder.FsConfig.HTTPConfig.Password.GetStatus()) + assert.Equal(t, folder.FsConfig.HTTPConfig.Password.GetPayload(), updateFolder.FsConfig.HTTPConfig.Password.GetPayload()) + assert.Empty(t, updateFolder.FsConfig.HTTPConfig.Password.GetKey()) + assert.Empty(t, updateFolder.FsConfig.HTTPConfig.Password.GetAdditionalData()) + assert.Equal(t, sdkkms.SecretStatusSecretBox, updateFolder.FsConfig.HTTPConfig.APIKey.GetStatus()) + assert.Equal(t, folder.FsConfig.HTTPConfig.APIKey.GetPayload(), updateFolder.FsConfig.HTTPConfig.APIKey.GetPayload()) + assert.Empty(t, updateFolder.FsConfig.HTTPConfig.APIKey.GetKey()) + assert.Empty(t, updateFolder.FsConfig.HTTPConfig.APIKey.GetAdditionalData()) + + // cleanup + req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, 
rr) +} + +func TestS3WebFolderMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) + assert.NoError(t, err) + mappedPath := filepath.Clean(os.TempDir()) + folderName := filepath.Base(mappedPath) + folderDesc := "a simple desc" + S3Bucket := "test" + S3Region := "eu-west-1" + S3AccessKey := "access-key" + S3AccessSecret := kms.NewPlainSecret("folder-access-secret") + S3SSEKey := kms.NewPlainSecret("folder-sse-key") + S3SessionToken := "fake session token" + S3RoleARN := "arn:aws:iam::123456789012:user/Development/product_1234/*" + S3Endpoint := "http://127.0.0.1:9000/path?b=c" + S3StorageClass := "Standard" + S3ACL := "public-read-write" + S3KeyPrefix := "somedir/subdir/" + S3UploadPartSize := 5 + S3UploadConcurrency := 4 + S3MaxPartDownloadTime := 120 + S3MaxPartUploadTime := 60 + S3DownloadPartSize := 6 + S3DownloadConcurrency := 3 + form := make(url.Values) + form.Set("mapped_path", mappedPath) + form.Set("name", folderName) + form.Set("description", folderDesc) + form.Set("fs_provider", "1") + form.Set("s3_bucket", S3Bucket) + form.Set("s3_region", S3Region) + form.Set("s3_access_key", S3AccessKey) + form.Set("s3_access_secret", S3AccessSecret.GetPayload()) + form.Set("s3_sse_customer_key", S3SSEKey.GetPayload()) + form.Set("s3_session_token", S3SessionToken) + form.Set("s3_role_arn", S3RoleARN) + form.Set("s3_storage_class", S3StorageClass) + form.Set("s3_acl", S3ACL) + form.Set("s3_endpoint", S3Endpoint) + form.Set("s3_key_prefix", S3KeyPrefix) + form.Set("s3_upload_part_size", strconv.Itoa(S3UploadPartSize)) + form.Set("s3_download_part_max_time", strconv.Itoa(S3MaxPartDownloadTime)) + form.Set("s3_download_part_size", strconv.Itoa(S3DownloadPartSize)) + 
form.Set("s3_download_concurrency", strconv.Itoa(S3DownloadConcurrency)) + form.Set("s3_upload_part_max_time", strconv.Itoa(S3MaxPartUploadTime)) + form.Set("s3_upload_concurrency", "a") + form.Set(csrfFormToken, csrfToken) + b, contentType, err := getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, webFolderPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + form.Set("s3_upload_concurrency", strconv.Itoa(S3UploadConcurrency)) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webFolderPath, &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + var folder vfs.BaseVirtualFolder + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folder) + assert.NoError(t, err) + assert.Equal(t, mappedPath, folder.MappedPath) + assert.Equal(t, folderName, folder.Name) + assert.Equal(t, folderDesc, folder.Description) + assert.Equal(t, sdk.S3FilesystemProvider, folder.FsConfig.Provider) + assert.Equal(t, S3Bucket, folder.FsConfig.S3Config.Bucket) + assert.Equal(t, S3Region, folder.FsConfig.S3Config.Region) + assert.Equal(t, S3AccessKey, folder.FsConfig.S3Config.AccessKey) + assert.NotEmpty(t, folder.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.NotEmpty(t, folder.FsConfig.S3Config.SSECustomerKey.GetPayload()) + assert.Equal(t, S3Endpoint, folder.FsConfig.S3Config.Endpoint) + assert.Equal(t, S3StorageClass, folder.FsConfig.S3Config.StorageClass) + assert.Equal(t, S3ACL, folder.FsConfig.S3Config.ACL) + 
assert.Equal(t, S3KeyPrefix, folder.FsConfig.S3Config.KeyPrefix) + assert.Equal(t, S3UploadConcurrency, folder.FsConfig.S3Config.UploadConcurrency) + assert.Equal(t, int64(S3UploadPartSize), folder.FsConfig.S3Config.UploadPartSize) + assert.Equal(t, S3MaxPartDownloadTime, folder.FsConfig.S3Config.DownloadPartMaxTime) + assert.Equal(t, S3MaxPartUploadTime, folder.FsConfig.S3Config.UploadPartMaxTime) + assert.Equal(t, S3DownloadConcurrency, folder.FsConfig.S3Config.DownloadConcurrency) + assert.Equal(t, int64(S3DownloadPartSize), folder.FsConfig.S3Config.DownloadPartSize) + assert.False(t, folder.FsConfig.S3Config.ForcePathStyle) + assert.False(t, folder.FsConfig.S3Config.SkipTLSVerify) + // update + S3UploadConcurrency = 10 + form.Set("s3_upload_concurrency", "b") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + form.Set("s3_upload_concurrency", strconv.Itoa(S3UploadConcurrency)) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + folder = vfs.BaseVirtualFolder{} + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folder) + assert.NoError(t, err) + assert.Equal(t, mappedPath, folder.MappedPath) + assert.Equal(t, folderName, folder.Name) + assert.Equal(t, folderDesc, folder.Description) + assert.Equal(t, 
sdk.S3FilesystemProvider, folder.FsConfig.Provider) + assert.Equal(t, S3Bucket, folder.FsConfig.S3Config.Bucket) + assert.Equal(t, S3Region, folder.FsConfig.S3Config.Region) + assert.Equal(t, S3AccessKey, folder.FsConfig.S3Config.AccessKey) + assert.Equal(t, S3RoleARN, folder.FsConfig.S3Config.RoleARN) + assert.NotEmpty(t, folder.FsConfig.S3Config.AccessSecret.GetPayload()) + assert.NotEmpty(t, folder.FsConfig.S3Config.SSECustomerKey.GetPayload()) + assert.Equal(t, S3Endpoint, folder.FsConfig.S3Config.Endpoint) + assert.Equal(t, S3StorageClass, folder.FsConfig.S3Config.StorageClass) + assert.Equal(t, S3KeyPrefix, folder.FsConfig.S3Config.KeyPrefix) + assert.Equal(t, S3UploadConcurrency, folder.FsConfig.S3Config.UploadConcurrency) + assert.Equal(t, int64(S3UploadPartSize), folder.FsConfig.S3Config.UploadPartSize) + + // cleanup + req, _ = http.NewRequest(http.MethodDelete, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestUpdateWebGroupMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webGroupPath, webToken) + assert.NoError(t, err) + group, _, err := httpdtest.AddGroup(getTestGroup(), http.StatusCreated) + assert.NoError(t, err) + + group.UserSettings = dataprovider.GroupUserSettings{ + BaseGroupUserSettings: sdk.BaseGroupUserSettings{ + HomeDir: filepath.Join(os.TempDir(), util.GenerateUniqueID()), + Permissions: make(map[string][]string), + }, + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: defaultUsername, + BufferSize: 1, + }, + }, + }, + } + form := make(url.Values) + 
form.Set("name", group.Name) + form.Set("description", group.Description) + form.Set("home_dir", group.UserSettings.HomeDir) + form.Set("max_sessions", strconv.FormatInt(int64(group.UserSettings.MaxSessions), 10)) + form.Set("quota_files", strconv.FormatInt(int64(group.UserSettings.QuotaFiles), 10)) + form.Set("quota_size", strconv.FormatInt(group.UserSettings.QuotaSize, 10)) + form.Set("upload_bandwidth", strconv.FormatInt(group.UserSettings.UploadBandwidth, 10)) + form.Set("download_bandwidth", strconv.FormatInt(group.UserSettings.DownloadBandwidth, 10)) + form.Set("upload_data_transfer", "0") + form.Set("download_data_transfer", "0") + form.Set("total_data_transfer", "0") + form.Set("max_upload_file_size", "0") + form.Set("default_shares_expiration", "0") + form.Set("max_shares_expiration", "0") + form.Set("expires_in", "0") + form.Set("password_expiration", "0") + form.Set("password_strength", "0") + form.Set("external_auth_cache_time", "0") + form.Set("fs_provider", strconv.FormatInt(int64(group.UserSettings.FsConfig.Provider), 10)) + form.Set("sftp_endpoint", group.UserSettings.FsConfig.SFTPConfig.Endpoint) + form.Set("sftp_username", group.UserSettings.FsConfig.SFTPConfig.Username) + b, contentType, err := getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + form.Set("sftp_buffer_size", strconv.FormatInt(group.UserSettings.FsConfig.SFTPConfig.BufferSize, 10)) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, 
webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorFsCredentialsRequired) + + form.Set("sftp_password", defaultPassword) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + req, err = http.NewRequest(http.MethodDelete, path.Join(groupPath, group.Name), nil) + assert.NoError(t, err) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webGroupPath, group.Name), &b) + assert.NoError(t, err) + req.Header.Set("Content-Type", contentType) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestUpdateWebFolderMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webFolderPath, webToken) + assert.NoError(t, err) + folderName := 
"vfolderupdate" + folderDesc := "updated desc" + folder := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), "folderupdate"), + Description: "dsc", + } + _, _, err = httpdtest.AddFolder(folder, http.StatusCreated) + newMappedPath := folder.MappedPath + "1" + assert.NoError(t, err) + form := make(url.Values) + form.Set("mapped_path", newMappedPath) + form.Set("name", folderName) + form.Set("description", folderDesc) + form.Set(csrfFormToken, "") + b, contentType, err := getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr := executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + form.Set(csrfFormToken, csrfToken) + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusSeeOther, rr) + + req, _ = http.NewRequest(http.MethodGet, path.Join(folderPath, folderName), nil) + setBearerForReq(req, apiToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err = render.DecodeJSON(rr.Body, &folder) + assert.NoError(t, err) + assert.Equal(t, newMappedPath, folder.MappedPath) + assert.Equal(t, folderName, folder.Name) + assert.Equal(t, folderDesc, folder.Description) + + // parse form error + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName)+"??a=a%B3%A2%G3", &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + 
req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName+"1"), &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + form.Set("mapped_path", "arelative/path") + b, contentType, err = getMultipartFormData(form, "", "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(webFolderPath, folderName), &b) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + req.Header.Set("Content-Type", contentType) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + // render update folder page + req, err = http.NewRequest(http.MethodGet, path.Join(webFolderPath, folderName), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, path.Join(webFolderPath, folderName+"1"), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(webFolderPath, folderName), nil) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Invalid token") + + req, _ = http.NewRequest(http.MethodDelete, path.Join(webFolderPath, folderName), nil) + setJWTCookieForReq(req, apiToken) // api token is not accepted + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusFound, rr) + assert.Equal(t, webLoginPath, rr.Header().Get("Location")) 
+ + req, _ = http.NewRequest(http.MethodDelete, path.Join(webFolderPath, folderName), nil) + setJWTCookieForReq(req, webToken) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestWebFoldersMock(t *testing.T) { + webToken, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + apiToken, err := getJWTAPITokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + mappedPath1 := filepath.Join(os.TempDir(), "vfolder1") + mappedPath2 := filepath.Join(os.TempDir(), "vfolder2") + folderName1 := filepath.Base(mappedPath1) + folderName2 := filepath.Base(mappedPath2) + folderDesc1 := "vfolder1 desc" + folderDesc2 := "vfolder2 desc" + folders := []vfs.BaseVirtualFolder{ + { + Name: folderName1, + MappedPath: mappedPath1, + Description: folderDesc1, + }, + { + Name: folderName2, + MappedPath: mappedPath2, + Description: folderDesc2, + }, + } + for _, folder := range folders { + folderAsJSON, err := json.Marshal(folder) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, folderPath, bytes.NewBuffer(folderAsJSON)) + assert.NoError(t, err) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusCreated, rr) + } + + req, err := http.NewRequest(http.MethodGet, folderPath, nil) + assert.NoError(t, err) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + var foldersGet []vfs.BaseVirtualFolder + err = render.DecodeJSON(rr.Body, &foldersGet) + assert.NoError(t, err) + numFound := 0 + for _, f := range foldersGet { + if f.Name == folderName1 { + assert.Equal(t, mappedPath1, f.MappedPath) + assert.Equal(t, folderDesc1, f.Description) + numFound++ + } + if f.Name == folderName2 { + assert.Equal(t, mappedPath2, f.MappedPath) + assert.Equal(t, folderDesc2, f.Description) + numFound++ + } + } + assert.Equal(t, 2, numFound) + + req, 
err = http.NewRequest(http.MethodGet, webFoldersPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + req, err = http.NewRequest(http.MethodGet, webFoldersPath+jsonAPISuffix, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, webToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + for _, folder := range folders { + req, _ := http.NewRequest(http.MethodDelete, path.Join(folderPath, folder.Name), nil) + setBearerForReq(req, apiToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + } +} + +func TestAdminForgotPassword(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 3525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + a := getTestAdmin() + a.Username = altAdminUsername + a.Password = altAdminPassword + a.Filters.RequirePasswordChange = true + admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webAdminForgotPwdPath, nil) + assert.NoError(t, err) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webAdminResetPwdPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webLoginPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + + form := make(url.Values) + form.Set("username", "") + // no csrf token + req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusForbidden, rr.Code) + // empty username + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorUsernameRequired) + + lastResetCode = "" + form.Set("username", altAdminUsername) + // disable the admin + admin.Status = 0 + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Len(t, lastResetCode, 0) + + admin.Status = 1 + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + + form = make(url.Values) + req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusForbidden, rr.Code) + // no password + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric) + // no code + form.Set("password", defaultPassword) + form.Set("confirm_password", defaultPassword) + req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric) + // disable the admin + admin.Status = 0 + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + form.Set("code", lastResetCode) + req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric) + + admin.Status = 1 + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + // ok + req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + 
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form.Set(csrfFormToken, csrfToken) + form.Set("username", altAdminUsername) + req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + + // not working smtp server + smtpCfg = smtp.Config{ + Host: "127.0.0.1", + Port: 3526, + From: "notification@example.com", + TemplatesPath: "templates", + } + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + form = make(url.Values) + form.Set("username", altAdminUsername) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorPwdResetSendEmail) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + form.Set("username", altAdminUsername) + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, 
rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorPwdResetGeneric) + + req, err = http.NewRequest(http.MethodGet, webAdminForgotPwdPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, webAdminResetPwdPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + admin, _, err = httpdtest.GetAdminByUsername(admin.Username, http.StatusOK) + assert.NoError(t, err) + assert.False(t, admin.Filters.RequirePasswordChange) + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) +} + +func TestUserForgotPassword(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 3525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + u := getTestUser() + u.Email = "user@test.com" + u.Filters.WebClient = []string{sdk.WebClientPasswordResetDisabled} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webClientForgotPwdPath, nil) + assert.NoError(t, err) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webClientResetPwdPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + + form := make(url.Values) + form.Set("username", "") + // no csrf token + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + 
setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusForbidden, rr.Code) + // empty username + form.Set(csrfFormToken, csrfToken) + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorUsernameRequired) + // user cannot reset the password + form.Set("username", user.Username) + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorPwdResetForbidded) + user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Hour)) + user.Filters.WebClient = []string{sdk.WebClientAPIKeyAuthChangeDisabled} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + // user is expired + lastResetCode = "" + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Len(t, lastResetCode, 0) + + user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour)) + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err 
= http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + // no login token + form = make(url.Values) + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusForbidden, rr.Code) + // no password + form.Set(csrfFormToken, csrfToken) + form.Set("password", "") + form.Set("confirm_password", "") + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric) + // passwords mismatch + form.Set("password", altAdminPassword) + form.Set("code", lastResetCode) + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdNoMatch) + // no code + form.Del("code") + form.Set("confirm_password", altAdminPassword) + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + 
assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric) + // Invalid login condition + form.Set("code", lastResetCode) + user.Filters.DeniedProtocols = []string{common.ProtocolHTTP} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric) + // ok + user.Filters.DeniedProtocols = []string{common.ProtocolFTP} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusFound, rr.Code) + + loginCookie, csrfToken, err = getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + assert.NoError(t, err) + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", user.Username) + lastResetCode = "" + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, 
http.StatusFound, rr.Code) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientForgotPwdPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, webClientResetPwdPath, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + // user does not exist anymore + form = make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("code", lastResetCode) + form.Set("password", "pwd") + form.Set("confirm_password", "pwd") + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorChangePwdGeneric) +} + +func TestAPIForgotPassword(t *testing.T) { + smtpCfg := smtp.Config{ + Host: "127.0.0.1", + Port: 3525, + From: "notification@example.com", + TemplatesPath: "templates", + } + err := smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + a := getTestAdmin() + a.Username = altAdminUsername + a.Password = altAdminPassword + a.Email = "" + admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + // no email, forgot pwd will not work + lastResetCode = "" + req, err := http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/forgot-password"), nil) + assert.NoError(t, err) + rr := executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, 
rr) + assert.Contains(t, rr.Body.String(), "Your account does not have an email address") + + admin.Email = "admin@test.com" + admin, _, err = httpdtest.UpdateAdmin(admin, http.StatusOK) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/forgot-password"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + + // invalid JSON + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/reset-password"), bytes.NewBuffer([]byte(`{`))) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + + resetReq := make(map[string]string) + resetReq["code"] = lastResetCode + resetReq["password"] = defaultPassword + asJSON, err := json.Marshal(resetReq) + assert.NoError(t, err) + + // a user cannot use an admin code + req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "invalid confirmation code") + + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/reset-password"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + // the same code cannot be reused + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/reset-password"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "confirmation code not found") + + admin, err = dataprovider.AdminExists(altAdminUsername) + assert.NoError(t, err) + + match, err := admin.CheckPassword(defaultPassword) + assert.NoError(t, err) + assert.True(t, match) 
+ lastResetCode = "" + // now the same for a user + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/forgot-password"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "Your account does not have an email address") + + user.Email = "user@test.com" + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/forgot-password"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + + // invalid JSON + req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"), bytes.NewBuffer([]byte(`{`))) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + // remove the reset password permission + user.Filters.WebClient = []string{sdk.WebClientPasswordResetDisabled} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + resetReq["code"] = lastResetCode + resetReq["password"] = altAdminPassword + asJSON, err = json.Marshal(resetReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "you are not allowed to reset your password") + + user.Filters.WebClient = []string{sdk.WebClientSharesDisabled} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"), 
bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + // the same code cannot be reused + req, err = http.NewRequest(http.MethodPost, path.Join(userPath, defaultUsername, "/reset-password"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "confirmation code not found") + + user, err = dataprovider.UserExists(defaultUsername, "") + assert.NoError(t, err) + err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(altAdminPassword)) + assert.NoError(t, err) + + lastResetCode = "" + // a request for a missing admin/user will be silently ignored + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, "missing-admin", "/forgot-password"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Empty(t, lastResetCode) + + req, err = http.NewRequest(http.MethodPost, path.Join(userPath, "missing-user", "/forgot-password"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.Empty(t, lastResetCode) + + lastResetCode = "" + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/forgot-password"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + assert.GreaterOrEqual(t, len(lastResetCode), 20) + + smtpCfg = smtp.Config{} + err = smtpCfg.Initialize(configDir, true) + require.NoError(t, err) + + // without an smtp configuration reset password is not available + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/forgot-password"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "No SMTP configuration") + + req, err = http.NewRequest(http.MethodPost, 
path.Join(userPath, defaultUsername, "/forgot-password"), nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "No SMTP configuration") + + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + // the admin does not exist anymore + resetReq["code"] = lastResetCode + resetReq["password"] = altAdminPassword + asJSON, err = json.Marshal(resetReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, altAdminUsername, "/reset-password"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusBadRequest, rr) + assert.Contains(t, rr.Body.String(), "unable to associate the confirmation code with an existing admin") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestProviderClosedMock(t *testing.T) { + token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + csrfToken, err := getCSRFTokenFromInternalPageMock(webConfigsPath, token) + assert.NoError(t, err) + // create a role admin + role, resp, err := httpdtest.AddRole(getTestRole(), http.StatusCreated) + assert.NoError(t, err, string(resp)) + a := getTestAdmin() + a.Username = altAdminUsername + a.Password = altAdminPassword + a.Role = role.Name + a.Permissions = []string{dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers, + dataprovider.PermAdminDeleteUsers, dataprovider.PermAdminViewUsers} + admin, _, err := httpdtest.AddAdmin(a, http.StatusCreated) + assert.NoError(t, err) + altToken, err := getJWTWebTokenFromTestServer(altAdminUsername, altAdminPassword) + assert.NoError(t, err) + + dataprovider.Close() + + testReq := make(map[string]any) + testReq["password"] = redactedSecret + asJSON, err := json.Marshal(testReq) + 
assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, path.Join(webConfigsPath, "smtp", "test"), bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, token) + req.Header.Set("X-CSRF-TOKEN", csrfToken) + rr := executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + testReq["base_redirect_url"] = "http://localhost" + testReq["client_secret"] = redactedSecret + asJSON, err = json.Marshal(testReq) + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + setJWTCookieForReq(req, token) + req.Header.Set("X-CSRF-TOKEN", csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, err = http.NewRequest(http.MethodGet, webConfigsPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, err = http.NewRequest(http.MethodPost, webConfigsPath, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, token) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + getJSONFolders := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + req, _ := http.NewRequest(http.MethodGet, webFoldersPath+jsonAPISuffix, nil) + setJWTCookieForReq(req, token) + executeRequest(req) + } + getJSONFolders() + + getJSONGroups := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + req, _ := http.NewRequest(http.MethodGet, webGroupsPath+jsonAPISuffix, nil) + setJWTCookieForReq(req, token) + executeRequest(req) + } + getJSONGroups() + + getJSONUsers := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + req, _ := http.NewRequest(http.MethodGet, webUsersPath+jsonAPISuffix, 
nil) + setJWTCookieForReq(req, token) + executeRequest(req) + } + getJSONUsers() + + req, _ = http.NewRequest(http.MethodGet, webUserPath+"/0", nil) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", "test") + req, _ = http.NewRequest(http.MethodPost, webUserPath+"/0", strings.NewReader(form.Encode())) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + req, _ = http.NewRequest(http.MethodGet, path.Join(webAdminPath, defaultTokenAuthUser), nil) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, _ = http.NewRequest(http.MethodPost, path.Join(webAdminPath, defaultTokenAuthUser), strings.NewReader(form.Encode())) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + getJSONAdmins := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + req, _ := http.NewRequest(http.MethodGet, webAdminsPath+jsonAPISuffix, nil) + setJWTCookieForReq(req, token) + executeRequest(req) + } + getJSONAdmins() + + req, _ = http.NewRequest(http.MethodGet, path.Join(webFolderPath, defaultTokenAuthUser), nil) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, _ = http.NewRequest(http.MethodPost, path.Join(webFolderPath, defaultTokenAuthUser), strings.NewReader(form.Encode())) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, _ = http.NewRequest(http.MethodPost, webUserPath, strings.NewReader(form.Encode())) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, _ = 
http.NewRequest(http.MethodPost, webUserPath, strings.NewReader(form.Encode())) + setJWTCookieForReq(req, altToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, err = http.NewRequest(http.MethodGet, webIPListPath+"/1/a", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, err = http.NewRequest(http.MethodPost, webIPListPath+"/1/a", nil) + assert.NoError(t, err) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + getJSONRoles := func() { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + req, err := http.NewRequest(http.MethodGet, webAdminRolesPath+jsonAPISuffix, nil) + assert.NoError(t, err) + setJWTCookieForReq(req, token) + executeRequest(req) + } + getJSONRoles() + + req, err = http.NewRequest(http.MethodGet, path.Join(webAdminRolePath, role.Name), nil) + assert.NoError(t, err) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminRolePath, role.Name), strings.NewReader(form.Encode())) + assert.NoError(t, err) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusInternalServerError, rr) + + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.BackupsPath = backupsPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + if config.GetProviderConf().Driver != dataprovider.MemoryDataProviderName { + _, err = httpdtest.RemoveAdmin(admin, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveRole(role, http.StatusOK) + assert.NoError(t, err) + } +} + +func TestWebConnectionsMock(t *testing.T) { + token, err := 
getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, webConnectionsPath, nil) + setJWTCookieForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, _ = http.NewRequest(http.MethodDelete, path.Join(webConnectionsPath, "id"), nil) + setJWTCookieForReq(req, token) + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Invalid token") + + req, _ = http.NewRequest(http.MethodDelete, path.Join(webConnectionsPath, "id"), nil) + setJWTCookieForReq(req, token) + setCSRFHeaderForReq(req, "csrfToken") + rr = executeRequest(req) + checkResponseCode(t, http.StatusForbidden, rr) + assert.Contains(t, rr.Body.String(), "Invalid token") + + csrfToken, err := getCSRFTokenFromInternalPageMock(webUserPath, token) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodDelete, path.Join(webConnectionsPath, "id"), nil) + setJWTCookieForReq(req, token) + setCSRFHeaderForReq(req, csrfToken) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) +} + +func TestGetWebStatusMock(t *testing.T) { + oldConfig := config.GetCommonConfig() + + cfg := config.GetCommonConfig() + cfg.RateLimitersConfig = []common.RateLimiterConfig{ + { + Average: 1, + Period: 1000, + Burst: 1, + Type: 1, + Protocols: []string{common.ProtocolFTP}, + }, + } + + err := common.Initialize(cfg, 0) + assert.NoError(t, err) + + token, err := getJWTWebTokenFromTestServer(defaultTokenAuthUser, defaultTokenAuthPass) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodGet, webStatusPath, nil) + setJWTCookieForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + err = common.Initialize(oldConfig, 0) + assert.NoError(t, err) +} + +func TestStaticFilesMock(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/static/favicon.png", nil) + assert.NoError(t, err) 
+ rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, "/openapi/openapi.yaml", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + + req, err = http.NewRequest(http.MethodGet, "/static", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusMovedPermanently, rr) + location := rr.Header().Get("Location") + assert.Equal(t, "/static/", location) + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusNotFound, rr) + + req, err = http.NewRequest(http.MethodGet, "/openapi", nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusMovedPermanently, rr) + location = rr.Header().Get("Location") + assert.Equal(t, "/openapi/", location) + req, err = http.NewRequest(http.MethodGet, location, nil) + assert.NoError(t, err) + rr = executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) +} + +func TestPasswordChangeRequired(t *testing.T) { + user := getTestUser() + assert.False(t, user.MustChangePassword()) + user.Filters.RequirePasswordChange = true + assert.True(t, user.MustChangePassword()) + user.Filters.RequirePasswordChange = false + assert.False(t, user.MustChangePassword()) + user.Filters.PasswordExpiration = 2 + user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now()) + assert.False(t, user.MustChangePassword()) + user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now().Add(49 * time.Hour)) + assert.False(t, user.MustChangePassword()) + user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now().Add(-49 * time.Hour)) + assert.True(t, user.MustChangePassword()) +} + +func TestPasswordExpiresIn(t *testing.T) { + user := getTestUser() + user.Filters.PasswordExpiration = 30 + user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now().Add(-15*24*time.Hour + 1*time.Hour)) + res 
:= user.PasswordExpiresIn() + assert.Equal(t, 15, res) + user.Filters.PasswordExpiration = 15 + res = user.PasswordExpiresIn() + assert.Equal(t, 1, res) + user.LastPasswordChange = util.GetTimeAsMsSinceEpoch(time.Now().Add(-15*24*time.Hour - 1*time.Hour)) + res = user.PasswordExpiresIn() + assert.Equal(t, 0, res) + user.Filters.PasswordExpiration = 5 + res = user.PasswordExpiresIn() + assert.Equal(t, -10, res) +} + +func TestSecondFactorRequirements(t *testing.T) { + user := getTestUser() + user.Filters.TwoFactorAuthProtocols = []string{common.ProtocolHTTP, common.ProtocolSSH} + assert.True(t, user.MustSetSecondFactor()) + assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolFTP)) + assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolHTTP)) + assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolSSH)) + + user.Filters.TOTPConfig.Enabled = true + assert.True(t, user.MustSetSecondFactor()) + assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolFTP)) + assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolHTTP)) + assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolSSH)) + + user.Filters.TOTPConfig.Protocols = []string{common.ProtocolHTTP} + assert.True(t, user.MustSetSecondFactor()) + assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolFTP)) + assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolHTTP)) + assert.True(t, user.MustSetSecondFactorForProtocol(common.ProtocolSSH)) + + user.Filters.TOTPConfig.Protocols = []string{common.ProtocolHTTP, common.ProtocolSSH} + assert.False(t, user.MustSetSecondFactor()) + assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolFTP)) + assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolHTTP)) + assert.False(t, user.MustSetSecondFactorForProtocol(common.ProtocolSSH)) +} + +func TestIsNameValid(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"simple name", 
"user", true}, + {"alphanumeric", "User123", true}, + {"unicode allowed", "你好", true}, + {"emoji allowed", "user😊", true}, + {"name with dot", "file.txt", true}, + {"name with multiple dots", "archive.tar.gz", true}, + {"control char", "abc\u0001", false}, + {"newline", "abc\n", false}, + {"tab", "abc\t", false}, + {"slash", "user/name", false}, + {"backslash", "user\\name", false}, + {"colon", "user:name", false}, + {"single dot", ".", false}, + {"double dot", "..", false}, + {"dot with suffix allowed", ".hidden", true}, + {"name ending with dot", "file.", false}, + {"name ending with space", "file ", false}, + {"CON", "CON", false}, + {"con lowercase", "con", false}, + {"con with extension", "con.txt", false}, + {"LPT1", "LPT1", false}, + {"lpt1 lowercase", "lpt1", false}, + {"COM5 uppercase", "COM5", false}, + {"com9 with extension", "com9.log", false}, + {"NUL", "NUL", false}, + {"Valid because suffix changes base", "con123", true}, + {"base name split", "aux.pdf", false}, + {"valid long name", "auxiliary", true}, + {"space only", " ", false}, + {"dot inside", "ab.cd.ef", true}, + {"unicode that ends with dot", "你好.", false}, + {"unicode that ends with space", "你好 ", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := util.IsNameValid(tt.input) + if result != tt.expected { + t.Errorf("IsNameValid(%q) = %v, expected %v", tt.input, result, tt.expected) + } + }) + } +} + +func startOIDCMockServer() { + go func() { + http.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprintf(w, "OK\n") + }) + http.HandleFunc("/auth/realms/sftpgo/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, 
`{"issuer":"http://127.0.0.1:11111/auth/realms/sftpgo","authorization_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/auth","token_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/token","introspection_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/token/introspect","userinfo_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/userinfo","end_session_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/logout","frontchannel_logout_session_supported":true,"frontchannel_logout_supported":true,"jwks_uri":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/certs","check_session_iframe":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/login-status-iframe.html","grant_types_supported":["authorization_code","implicit","refresh_token","password","client_credentials","urn:ietf:params:oauth:grant-type:device_code","urn:openid:params:grant-type:ciba"],"response_types_supported":["code","none","id_token","token","id_token token","code id_token","code token","code id_token 
token"],"subject_types_supported":["public","pairwise"],"id_token_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"id_token_encryption_alg_values_supported":["RSA-OAEP","RSA-OAEP-256","RSA1_5"],"id_token_encryption_enc_values_supported":["A256GCM","A192GCM","A128GCM","A128CBC-HS256","A192CBC-HS384","A256CBC-HS512"],"userinfo_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512","none"],"request_object_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512","none"],"request_object_encryption_alg_values_supported":["RSA-OAEP","RSA-OAEP-256","RSA1_5"],"request_object_encryption_enc_values_supported":["A256GCM","A192GCM","A128GCM","A128CBC-HS256","A192CBC-HS384","A256CBC-HS512"],"response_modes_supported":["query","fragment","form_post","query.jwt","fragment.jwt","form_post.jwt","jwt"],"registration_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/clients-registrations/openid-connect","token_endpoint_auth_methods_supported":["private_key_jwt","client_secret_basic","client_secret_post","tls_client_auth","client_secret_jwt"],"token_endpoint_auth_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"introspection_endpoint_auth_methods_supported":["private_key_jwt","client_secret_basic","client_secret_post","tls_client_auth","client_secret_jwt"],"introspection_endpoint_auth_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"authorization_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"authorization_encryption_alg_values_supported":["RSA-OAEP","RSA-OAEP-256","RSA1_5"],"authorization_encryption_enc_values_supported":["A256
GCM","A192GCM","A128GCM","A128CBC-HS256","A192CBC-HS384","A256CBC-HS512"],"claims_supported":["aud","sub","iss","auth_time","name","given_name","family_name","preferred_username","email","acr"],"claim_types_supported":["normal"],"claims_parameter_supported":true,"scopes_supported":["openid","phone","email","web-origins","offline_access","microprofile-jwt","profile","address","roles"],"request_parameter_supported":true,"request_uri_parameter_supported":true,"require_request_uri_registration":true,"code_challenge_methods_supported":["plain","S256"],"tls_client_certificate_bound_access_tokens":true,"revocation_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/revoke","revocation_endpoint_auth_methods_supported":["private_key_jwt","client_secret_basic","client_secret_post","tls_client_auth","client_secret_jwt"],"revocation_endpoint_auth_signing_alg_values_supported":["PS384","ES384","RS384","HS256","HS512","ES256","RS256","HS384","ES512","PS256","PS512","RS512"],"backchannel_logout_supported":true,"backchannel_logout_session_supported":true,"device_authorization_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/auth/device","backchannel_token_delivery_modes_supported":["poll","ping"],"backchannel_authentication_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/ext/ciba/auth","backchannel_authentication_request_signing_alg_values_supported":["PS384","ES384","RS384","ES256","RS256","ES512","PS256","PS512","RS512"],"require_pushed_authorization_requests":false,"pushed_authorization_request_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/ext/par/request","mtls_endpoint_aliases":{"token_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/token","revocation_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/revoke","introspection_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/token/intros
pect","device_authorization_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/auth/device","registration_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/clients-registrations/openid-connect","userinfo_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/userinfo","pushed_authorization_request_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/ext/par/request","backchannel_authentication_endpoint":"http://127.0.0.1:11111/auth/realms/sftpgo/protocol/openid-connect/ext/ciba/auth"}}`) + }) + http.HandleFunc("/404", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, "Not found\n") + }) + if err := http.ListenAndServe(oidcMockAddr, nil); err != nil { + logger.ErrorToConsole("could not start HTTP notification server: %v", err) + os.Exit(1) + } + }() + waitTCPListening(oidcMockAddr) +} + +func waitForUsersQuotaScan(t *testing.T, token string) { + for { + var scans []common.ActiveQuotaScan + req, _ := http.NewRequest(http.MethodGet, quotaScanPath, nil) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err := render.DecodeJSON(rr.Body, &scans) + + if !assert.NoError(t, err, "Error getting active scans") { + break + } + if len(scans) == 0 { + break + } + time.Sleep(100 * time.Millisecond) + } +} + +func waitForFoldersQuotaScanPath(t *testing.T, token string) { + var scans []common.ActiveVirtualFolderQuotaScan + for { + req, _ := http.NewRequest(http.MethodGet, quotaScanVFolderPath, nil) + setBearerForReq(req, token) + rr := executeRequest(req) + checkResponseCode(t, http.StatusOK, rr) + err := render.DecodeJSON(rr.Body, &scans) + if !assert.NoError(t, err, "Error getting active folders scans") { + break + } + if len(scans) == 0 { + break + } + time.Sleep(100 * time.Millisecond) + } +} + +func waitTCPListening(address string) { + for { + conn, err := net.Dial("tcp", address) + if err != nil { + 
			// Tail of waitTCPListening's retry loop: the dial above failed,
			// so log and retry every 100ms until the server accepts connections.
			logger.WarnToConsole("tcp server %v not listening: %v", address, err)
			time.Sleep(100 * time.Millisecond)
			continue
		}
		logger.InfoToConsole("tcp server %v now listening", address)
		conn.Close()
		break
	}
}

// startSMTPServer starts an in-process SMTP sink used by password-reset tests.
// Each received message is scanned for a reset code of the form `code is "..."`
// and the extracted value is stored in the package-level lastResetCode
// (declared elsewhere in this file) for later assertions.
// The process exits if the server cannot be started.
func startSMTPServer() {
	go func() {
		if err := smtpd.ListenAndServe(smtpServerAddr, func(_ net.Addr, _ string, _ []string, data []byte) error {
			// Non-greedy match so only the quoted code itself is captured.
			re := regexp.MustCompile(`code is ".*?"`)
			code := strings.TrimPrefix(string(re.Find(data)), "code is ")
			lastResetCode = strings.ReplaceAll(code, "\"", "")
			return nil
		}, "SFTPGo test", "localhost"); err != nil {
			logger.ErrorToConsole("could not start SMTP server: %v", err)
			os.Exit(1)
		}
	}()
	// Block until the SMTP listener is reachable so tests can rely on it.
	waitTCPListening(smtpServerAddr)
}

// getTestAdmin returns a fixture admin with full permissions.
func getTestAdmin() dataprovider.Admin {
	return dataprovider.Admin{
		Username:    defaultTokenAuthUser,
		Password:    defaultTokenAuthPass,
		Status:      1,
		Permissions: []string{dataprovider.PermAdminAny},
		Email:       "admin@example.com",
		Description: "test admin",
	}
}

// getTestGroup returns a fixture group.
func getTestGroup() dataprovider.Group {
	return dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name:        "test_group",
			Description: "test group description",
		},
	}
}

// getTestRole returns a fixture role.
func getTestRole() dataprovider.Role {
	return dataprovider.Role{
		Name:        "test_role",
		Description: "test role description",
	}
}

// getTestUser returns a fixture local-filesystem user with default
// permissions granted on "/".
func getTestUser() dataprovider.User {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:    defaultUsername,
			Password:    defaultPassword,
			HomeDir:     filepath.Join(homeBasePath, defaultUsername),
			Status:      1,
			Description: "test user",
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = defaultPerms
	return user
}

// getTestSFTPUser returns a fixture user whose storage backend is the
// SFTP filesystem provider, pointing at the local test SFTP server and
// authenticating as the plain getTestUser account.
func getTestSFTPUser() dataprovider.User {
	u := getTestUser()
	u.Username = u.Username + "_sftp"
	u.FsConfig.Provider = sdk.SFTPFilesystemProvider
	u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr
	u.FsConfig.SFTPConfig.Username = defaultUsername
	u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
	return u
+} + +func getUserAsJSON(t *testing.T, user dataprovider.User) []byte { + json, err := json.Marshal(user) + assert.NoError(t, err) + return json +} + +func getCSRFTokenFromInternalPageMock(urlPath, token string) (string, error) { + req, err := http.NewRequest(http.MethodGet, urlPath, nil) + if err != nil { + return "", err + } + req.RequestURI = urlPath + setJWTCookieForReq(req, token) + rr := executeRequest(req) + if rr.Code != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", rr.Code) + } + return getCSRFTokenFromBody(rr.Body) +} + +func getCSRFTokenMock(loginURLPath, remoteAddr string) (string, string, error) { + req, err := http.NewRequest(http.MethodGet, loginURLPath, nil) + if err != nil { + return "", "", err + } + req.RemoteAddr = remoteAddr + rr := executeRequest(req) + cookie := rr.Header().Get("Set-Cookie") + if cookie == "" { + return "", "", errors.New("unable to get login cookie") + } + token, err := getCSRFTokenFromBody(bytes.NewBuffer(rr.Body.Bytes())) + return cookie, token, err +} + +func getCSRFToken(url string) (string, string, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return "", "", err + } + resp, err := httpclient.GetHTTPClient().Do(req) + if err != nil { + return "", "", err + } + cookie := resp.Header.Get("Set-Cookie") + if cookie == "" { + return "", "", errors.New("no login cookie") + } + + defer resp.Body.Close() + + token, err := getCSRFTokenFromBody(resp.Body) + return cookie, token, err +} + +func getCSRFTokenFromBody(body io.Reader) (string, error) { + doc, err := html.Parse(body) + if err != nil { + return "", err + } + + var csrfToken string + var f func(*html.Node) + + f = func(n *html.Node) { + if n.Type == html.ElementNode && n.Data == "input" { + var name, value string + for _, attr := range n.Attr { + if attr.Key == "value" { + value = attr.Val + } + if attr.Key == "name" { + name = attr.Val + } + } + if name == csrfFormToken { + csrfToken = value + return + } + } 

		// Recurse into the children of the current node.
		for c := n.FirstChild; c != nil; c = c.NextSibling {
			f(c)
		}
	}

	f(doc)

	if csrfToken == "" {
		return "", errors.New("CSRF token not found")
	}

	return csrfToken, nil
}

// getLoginForm builds the url-encoded login form, including the CSRF token.
func getLoginForm(username, password, csrfToken string) url.Values {
	form := make(url.Values)
	form.Set("username", username)
	form.Set("password", password)
	form.Set(csrfFormToken, csrfToken)
	return form
}

// setCSRFHeaderForReq sets the CSRF token header used by the REST API.
func setCSRFHeaderForReq(req *http.Request, csrfToken string) {
	req.Header.Set("X-CSRF-TOKEN", csrfToken)
}

// setBearerForReq sets the JWT as a Bearer token in the Authorization header.
func setBearerForReq(req *http.Request, jwtToken string) {
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", jwtToken))
}

// setAPIKeyForReq sets the SFTPGo API key header. If username is not empty
// it is appended to the key ("key.username" form) for key-scoped impersonation.
func setAPIKeyForReq(req *http.Request, apiKey, username string) {
	if username != "" {
		apiKey += "." + username
	}
	req.Header.Set("X-SFTPGO-API-KEY", apiKey)
}

// setLoginCookie sets a previously obtained login cookie on the request.
func setLoginCookie(req *http.Request, cookie string) {
	req.Header.Set("Cookie", cookie)
}

// setJWTCookieForReq sets the JWT cookie and pins the default remote address.
func setJWTCookieForReq(req *http.Request, jwtToken string) {
	req.RemoteAddr = defaultRemoteAddr
	req.Header.Set("Cookie", fmt.Sprintf("jwt=%v", jwtToken))
}

// getJWTAPITokenFromTestServer obtains an admin API token without a passcode.
func getJWTAPITokenFromTestServer(username, password string) (string, error) {
	return getJWTAPITokenFromTestServerWithPasscode(username, password, "")
}

// getJWTAPITokenFromTestServerWithPasscode obtains an admin API token via
// basic auth, optionally providing a one-time passcode, and returns the
// "access_token" field from the JSON response.
// NOTE(review): the type assertion on "access_token" will panic if the
// field is missing or not a string — acceptable in a test helper.
func getJWTAPITokenFromTestServerWithPasscode(username, password, passcode string) (string, error) {
	req, _ := http.NewRequest(http.MethodGet, tokenPath, nil)
	req.SetBasicAuth(username, password)
	if passcode != "" {
		req.Header.Set("X-SFTPGO-OTP", passcode)
	}
	rr := executeRequest(req)
	if rr.Code != http.StatusOK {
		return "", fmt.Errorf("unexpected status code %v", rr.Code)
	}
	responseHolder := make(map[string]any)
	err := render.DecodeJSON(rr.Body, &responseHolder)
	if err != nil {
		return "", err
	}
	return responseHolder["access_token"].(string), nil
}

// getJWTAPIUserTokenFromTestServer obtains an end-user API token via basic auth.
func getJWTAPIUserTokenFromTestServer(username, password string) (string, error) {
	req, _ := http.NewRequest(http.MethodGet, userTokenPath, nil)
	req.SetBasicAuth(username, password)
	rr := executeRequest(req)
	if rr.Code != http.StatusOK {
		return "", fmt.Errorf("unexpected status code %v", rr.Code)
	}
	responseHolder := make(map[string]any)
	err := render.DecodeJSON(rr.Body, &responseHolder)
	if err != nil {
		return "", err
	}
	return responseHolder["access_token"].(string), nil
}

// getJWTWebToken performs a real HTTP login against the admin web UI and
// returns the JWT extracted from the Set-Cookie header. Redirects are
// suppressed so the expected 302 login response can be inspected directly.
func getJWTWebToken(username, password string) (string, error) {
	loginCookie, csrfToken, err := getCSRFToken(httpBaseURL + webLoginPath)
	if err != nil {
		return "", err
	}
	form := getLoginForm(username, password, csrfToken)
	req, _ := http.NewRequest(http.MethodPost, httpBaseURL+webLoginPath,
		bytes.NewBuffer([]byte(form.Encode())))
	setLoginCookie(req, loginCookie)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	client := &http.Client{
		Timeout: 10 * time.Second,
		// Do not follow the post-login redirect: we need its headers.
		CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusFound {
		return "", fmt.Errorf("unexpected status code %v", resp.StatusCode)
	}
	// Strip the leading "jwt=" from the raw Set-Cookie header value.
	// NOTE(review): any cookie attributes after the value are kept as-is;
	// presumably the test server sets none here — confirm against callers.
	cookie := resp.Header.Get("Set-Cookie")
	if strings.HasPrefix(cookie, "jwt=") {
		return cookie[4:], nil
	}
	return "", errors.New("no cookie found")
}

// getCookieFromResponse extracts the JWT value from a recorded response's
// Set-Cookie header, discarding cookie attributes after the first ";".
func getCookieFromResponse(rr *httptest.ResponseRecorder) (string, error) {
	cookie := strings.Split(rr.Header().Get("Set-Cookie"), ";")
	if strings.HasPrefix(cookie[0], "jwt=") {
		return cookie[0][4:], nil
	}
	return "", errors.New("no cookie found")
}

// getJWTWebClientTokenFromTestServerWithAddr logs into the web client UI
// through the in-process test handler, using the given remote address,
// and returns the JWT cookie value.
func getJWTWebClientTokenFromTestServerWithAddr(username, password, remoteAddr string) (string, error) {
	loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, remoteAddr)
	if err != nil {
		return "", err
	}
	form := getLoginForm(username, password, csrfToken)
	req, _ := http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode())))
req.RemoteAddr = remoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := executeRequest(req) + if rr.Code != http.StatusFound { + return "", fmt.Errorf("unexpected status code %v", rr) + } + return getCookieFromResponse(rr) +} + +func getJWTWebClientTokenFromTestServer(username, password string) (string, error) { + loginCookie, csrfToken, err := getCSRFTokenMock(webClientLoginPath, defaultRemoteAddr) + if err != nil { + return "", err + } + form := getLoginForm(username, password, csrfToken) + req, _ := http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + req.Header.Set("Cookie", loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := executeRequest(req) + if rr.Code != http.StatusFound { + return "", fmt.Errorf("unexpected status code %v", rr) + } + return getCookieFromResponse(rr) +} + +func getJWTWebTokenFromTestServer(username, password string) (string, error) { + loginCookie, csrfToken, err := getCSRFTokenMock(webLoginPath, defaultRemoteAddr) + if err != nil { + return "", err + } + form := getLoginForm(username, password, csrfToken) + req, _ := http.NewRequest(http.MethodPost, webLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + req.RemoteAddr = defaultRemoteAddr + setLoginCookie(req, loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := executeRequest(req) + if rr.Code != http.StatusFound { + return "", fmt.Errorf("unexpected status code %v", rr) + } + return getCookieFromResponse(rr) +} + +func executeRequest(req *http.Request) *httptest.ResponseRecorder { + rr := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + return rr +} + +func checkResponseCode(t *testing.T, expected int, rr *httptest.ResponseRecorder) { + assert.Equal(t, expected, rr.Code, rr.Body.String()) +} + +func getSftpClient(user dataprovider.User) 
(*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if user.Password != "" { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return conn, sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + +func createTestFile(path string, size int64) error { + baseDir := filepath.Dir(path) + if _, err := os.Stat(baseDir); errors.Is(err, fs.ErrNotExist) { + err = os.MkdirAll(baseDir, os.ModePerm) + if err != nil { + return err + } + } + content := make([]byte, size) + if size > 0 { + _, err := rand.Read(content) + if err != nil { + return err + } + } + return os.WriteFile(path, content, os.ModePerm) +} + +func getExitCodeScriptContent(exitCode int) []byte { + content := []byte("#!/bin/sh\n\n") + content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...) 
	return content
}

// getMultipartFormData builds a multipart/form-data body containing the
// given form values and, when both fileFieldName and filePath are set,
// the file at filePath as a file part. It returns the encoded buffer and
// the matching Content-Type header value.
func getMultipartFormData(values url.Values, fileFieldName, filePath string) (bytes.Buffer, string, error) {
	var b bytes.Buffer
	w := multipart.NewWriter(&b)
	for k, v := range values {
		for _, s := range v {
			if err := w.WriteField(k, s); err != nil {
				return b, "", err
			}
		}
	}
	if len(fileFieldName) > 0 && len(filePath) > 0 {
		fw, err := w.CreateFormFile(fileFieldName, filepath.Base(filePath))
		if err != nil {
			return b, "", err
		}
		f, err := os.Open(filePath)
		if err != nil {
			return b, "", err
		}
		defer f.Close()
		if _, err = io.Copy(fw, f); err != nil {
			return b, "", err
		}
	}
	// Close writes the terminating multipart boundary.
	err := w.Close()
	return b, w.FormDataContentType(), err
}

// generateTOTPPasscode returns a 6-digit SHA1 TOTP passcode for the given
// secret, with a 30s period and a skew of 1 period.
func generateTOTPPasscode(secret string) (string, error) {
	return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{
		Period:    30,
		Skew:      1,
		Digits:    otp.DigitsSix,
		Algorithm: otp.AlgorithmSHA1,
	})
}

// isDbDefenderSupported reports whether the active data provider supports
// the database-backed defender.
func isDbDefenderSupported() bool {
	// SQLite shares the implementation with other SQL-based provider but it makes no sense
	// to use it outside test cases
	switch dataprovider.GetProviderStatus().Driver {
	case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName,
		dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName:
		return true
	default:
		return false
	}
}

// createTestPNG writes a width x height PNG filled with imgColor to `name`.
func createTestPNG(name string, width, height int, imgColor color.Color) error {
	upLeft := image.Point{0, 0}
	lowRight := image.Point{width, height}
	img := image.NewRGBA(image.Rectangle{upLeft, lowRight})
	for x := 0; x < width; x++ {
		for y := 0; y < height; y++ {
			img.Set(x, y, imgColor)
		}
	}
	f, err := os.Create(name)
	if err != nil {
		return err
	}
	defer f.Close()
	return png.Encode(f, img)
}

// BenchmarkSecretDecryption measures cloning and decrypting an encrypted
// KMS secret.
func BenchmarkSecretDecryption(b *testing.B) {
	s := kms.NewPlainSecret("test data")
	s.SetAdditionalData("username")
	err := s.Encrypt()
	require.NoError(b, err)
	for i := 0; i < b.N; i++ {
		err = s.Clone().Decrypt()
require.NoError(b, err) + } +} diff --git a/internal/httpd/internal_test.go b/internal/httpd/internal_test.go new file mode 100644 index 00000000..987f0263 --- /dev/null +++ b/internal/httpd/internal_test.go @@ -0,0 +1,4374 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "html/template" + "io" + "io/fs" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-jose/go-jose/v4" + josejwt "github.com/go-jose/go-jose/v4/jwt" + "github.com/klauspost/compress/zip" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + sdkkms "github.com/sftpgo/sdk/kms" + "github.com/sftpgo/sdk/plugin/notifier" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/html" + + "github.com/drakkan/sftpgo/v2/internal/acme" + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + httpdCert = `-----BEGIN CERTIFICATE----- 
+MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw +RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu +dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw +OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD +VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA +IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA +NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM +3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME +GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG +SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY +/8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI +dV4vKmHUzwK/eIx+8Ay3neE= +-----END CERTIFICATE-----` + httpdKey = `-----BEGIN EC PARAMETERS----- +BgUrgQQAIg== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 +UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq +WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV +CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= +-----END EC PRIVATE KEY-----` + caCRT = `-----BEGIN CERTIFICATE----- +MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 +QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT +CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m +fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 +v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh +yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL +CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U +4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb +KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo +NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb +S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i +M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY 
+/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy +OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD +AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu +XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F +sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki +pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE +jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX +y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy +WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV +oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV +L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 +DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E +eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli +SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w +kK8u1YiiVenmoQ== +-----END CERTIFICATE-----` + caKey = `-----BEGIN RSA PRIVATE KEY----- +MIIJKgIBAAKCAgEA7WHW216mfi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7x +b64rkpdzx1aWetSiCrEyc3D1v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvD +OBUYZgtMqHZzpE6xRrqQ84zhyzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKz +n/2uEVt33qmO85WtN3RzbSqLCdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj +7B5P5MeamkkogwbExUjdHp3U4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZ +De67V/Q8iB2May1k7zBz1ZtbKF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmOb +cn8AIfH6smLQrn0C3cs7CYfoNlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLq +R/BJTbbyXUB0imne1u00fuzbS7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb7 +8x7ivdyXSF5LVQJ1JvhhWu6iM6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpB +P8d2jcRZVUVrSXGc2mAGuGOY/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2Z +xugCXULtRWJ9p4C9zUl40HEyOQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEA +AQKCAgEA4x0OoceG54ZrVxifqVaQd8qw3uRmUKUMIMdfuMlsdideeLO97ynmSlRY +00kGo/I4Lp6mNEjI9gUie9+uBrcUhri4YLcujHCH+YlNnCBDbGjwbe0ds9SLCWaa +KztZHMSlW5Q4Bqytgu+MpOnxSgqjlOk+vz9TcGFKVnUkHIkAcqKFJX8gOFxPZA/t 
+Ob1kJaz4kuv5W2Kur/ISKvQtvFvOtQeV0aJyZm8LqXnvS4cPI7yN4329NDU0HyDR +y/deqS2aqV4zII3FFqbz8zix/m1xtVQzWCugZGMKrz0iuJMfNeCABb8rRGc6GsZz ++465v/kobqgeyyneJ1s5rMFrLp2o+dwmnIVMNsFDUiN1lIZDHLvlgonaUO3IdTZc +9asamFWKFKUMgWqM4zB1vmUO12CKowLNIIKb0L+kf1ixaLLDRGf/f9vLtSHE+oyx +lATiS18VNA8+CGsHF6uXMRwf2auZdRI9+s6AAeyRISSbO1khyWKHo+bpOvmPAkDR +nknTjbYgkoZOV+mrsU5oxV8s6vMkuvA3rwFhT2gie8pokuACFcCRrZi9MVs4LmUQ +u0GYTHvp2WJUjMWBm6XX7Hk3g2HV842qpk/mdtTjNsXws81djtJPn4I/soIXSgXz +pY3SvKTuOckP9OZVF0yqKGeZXKpD288PKpC+MAg3GvEJaednagECggEBAPsfLwuP +L1kiDjXyMcRoKlrQ6Q/zBGyBmJbZ5uVGa02+XtYtDAzLoVupPESXL0E7+r8ZpZ39 +0dV4CEJKpbVS/BBtTEkPpTK5kz778Ib04TAyj+YLhsZjsnuja3T5bIBZXFDeDVDM +0ZaoFoKpIjTu2aO6pzngsgXs6EYbo2MTuJD3h0nkGZsICL7xvT9Mw0P1p2Ftt/hN ++jKk3vN220wTWUsq43AePi45VwK+PNP12ZXv9HpWDxlPo3j0nXtgYXittYNAT92u +BZbFAzldEIX9WKKZgsWtIzLaASjVRntpxDCTby/nlzQ5dw3DHU1DV3PIqxZS2+Oe +KV+7XFWgZ44YjYECggEBAPH+VDu3QSrqSahkZLkgBtGRkiZPkZFXYvU6kL8qf5wO +Z/uXMeqHtznAupLea8I4YZLfQim/NfC0v1cAcFa9Ckt9g3GwTSirVcN0AC1iOyv3 +/hMZCA1zIyIcuUplNr8qewoX71uPOvCNH0dix77423mKFkJmNwzy4Q+rV+qkRdLn +v+AAgh7g5N91pxNd6LQJjoyfi1Ka6rRP2yGXM5v7QOwD16eN4JmExUxX1YQ7uNuX +pVS+HRxnBquA+3/DB1LtBX6pa2cUa+LRUmE/NCPHMvJcyuNkYpJKlNTd9vnbfo0H +RNSJSWm+aGxDFMjuPjV3JLj2OdKMPwpnXdh2vBZCPpMCggEAM+yTvrEhmi2HgLIO +hkz/jP2rYyfdn04ArhhqPLgd0dpuI5z24+Jq/9fzZT9ZfwSW6VK1QwDLlXcXRhXH +Q8Hf6smev3CjuORURO61IkKaGWwrAucZPAY7ToNQ4cP9ImDXzMTNPgrLv3oMBYJR +V16X09nxX+9NABqnQG/QjdjzDc6Qw7+NZ9f2bvzvI5qMuY2eyW91XbtJ45ThoLfP +ymAp03gPxQwL0WT7z85kJ3OrROxzwaPvxU0JQSZbNbqNDPXmFTiECxNDhpRAAWlz +1DC5Vg2l05fkMkyPdtD6nOQWs/CYSfB5/EtxiX/xnBszhvZUIe6KFvuKFIhaJD5h +iykagQKCAQEAoBRm8k3KbTIo4ZzvyEq4V/+dF3zBRczx6FkCkYLygXBCNvsQiR2Y +BjtI8Ijz7bnQShEoOmeDriRTAqGGrspEuiVgQ1+l2wZkKHRe/aaij/Zv+4AuhH8q +uZEYvW7w5Uqbs9SbgQzhp2kjTNy6V8lVnjPLf8cQGZ+9Y9krwktC6T5m/i435WdN +38h7amNP4XEE/F86Eb3rDrZYtgLIoCF4E+iCyxMehU+AGH1uABhls9XAB6vvo+8/ +SUp8lEqWWLP0U5KNOtYWfCeOAEiIHDbUq+DYUc4BKtbtV1cx3pzlPTOWw6XBi5Lq +jttdL4HyYvnasAQpwe8GcMJqIRyCVZMiwwKCAQEAhQTTS3CC8PwcoYrpBdTjW1ck 
+vVFeF1YbfqPZfYxASCOtdx6wRnnEJ+bjqntagns9e88muxj9UhxSL6q9XaXQBD8+ +2AmKUxphCZQiYFZcTucjQEQEI2nN+nAKgRrUSMMGiR8Ekc2iFrcxBU0dnSohw+aB +PbMKVypQCREu9PcDFIp9rXQTeElbaNsIg1C1w/SQjODbmN/QFHTVbRODYqLeX1J/ +VcGsykSIq7hv6bjn7JGkr2JTdANbjk9LnMjMdJFsKRYxPKkOQfYred6Hiojp5Sor +PW5am8ejnNSPhIfqQp3uV3KhwPDKIeIpzvrB4uPfTjQWhekHCb8cKSWux3flqw== +-----END RSA PRIVATE KEY-----` + caCRL = `-----BEGIN X509 CRL----- +MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN +MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg +ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r +rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV ++jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 +XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE +jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF +DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD +MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 +iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE +ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 +hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 +cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 +sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn +w7bXJGfJcIMrrKs= +-----END X509 CRL-----` + client1Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 +HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE +OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 +Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf +XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO 
+c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo +Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU +OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC +sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I +l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb +BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz +MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 +QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz +J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ +offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX +G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr +kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au +MU3Bo0A= +-----END CERTIFICATE-----` + client1Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki +J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi +85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW +eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 +ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy +bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO +hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw +eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl +X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ +CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL +j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU +NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf +sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB +Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB +LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g 
+mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 +T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn +RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa +Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW +asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ +DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 +0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls +ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY +nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe +yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== +-----END RSA PRIVATE KEY-----` + // client 2 crt is revoked + client2Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b +Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua +eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ +cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT +A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 +6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f +Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz +PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY +xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D +EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB +4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO +78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 +7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e 
+qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N +f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv +/ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S +ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR +E0+r2+4= +-----END CERTIFICATE-----` + client2Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV +jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP +kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK +h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P +QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu +4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq +op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM +qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt +Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL +VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO +qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy +VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl +NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 +4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl +/owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd +aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu +khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz +3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ +eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i +vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB +GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb +Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D +8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl +RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA +ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl +-----END 
RSA PRIVATE KEY-----` + defaultAdminUsername = "admin" + defaultAdminPass = "password" + defeaultUsername = "test_user" +) + +var ( + configDir = filepath.Join(".", "..", "..") +) + +type failingWriter struct { +} + +func (r *failingWriter) Write(_ []byte) (n int, err error) { + return 0, errors.New("write error") +} + +func (r *failingWriter) WriteHeader(_ int) {} + +func (r *failingWriter) Header() http.Header { + return make(http.Header) +} + +type failingJoseSigner struct{} + +func (s *failingJoseSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) { + return nil, errors.New("sign test error") +} + +func (s *failingJoseSigner) Options() jose.SignerOptions { + return jose.SignerOptions{} +} + +func TestShouldBind(t *testing.T) { + c := Conf{ + Bindings: []Binding{ + { + Port: 10000, + }, + }, + } + require.False(t, c.ShouldBind()) + c.Bindings[0].EnableRESTAPI = true + require.True(t, c.ShouldBind()) + + c.Bindings[0].Port = 0 + require.False(t, c.ShouldBind()) + + if runtime.GOOS != osWindows { + c.Bindings[0].Address = "/absolute/path" + require.True(t, c.ShouldBind()) + } +} + +func TestBrandingValidation(t *testing.T) { + b := Binding{ + Branding: Branding{ + WebAdmin: UIBranding{ + LogoPath: "path1", + DefaultCSS: []string{"my.css"}, + }, + WebClient: UIBranding{ + FaviconPath: "favicon1.ico", + DisclaimerPath: "../path2", + ExtraCSS: []string{"1.css"}, + }, + }, + } + b.checkBranding() + assert.Equal(t, "/favicon.png", b.Branding.WebAdmin.FaviconPath) + assert.Equal(t, "/path1", b.Branding.WebAdmin.LogoPath) + assert.Equal(t, []string{"/my.css"}, b.Branding.WebAdmin.DefaultCSS) + assert.Len(t, b.Branding.WebAdmin.ExtraCSS, 0) + assert.Equal(t, "/favicon1.ico", b.Branding.WebClient.FaviconPath) + assert.Equal(t, path.Join(webStaticFilesPath, "/path2"), b.Branding.WebClient.DisclaimerPath) + if assert.Len(t, b.Branding.WebClient.ExtraCSS, 1) { + assert.Equal(t, "/1.css", b.Branding.WebClient.ExtraCSS[0]) + } + b.Branding.WebAdmin.DisclaimerPath = 
"https://example.com" + b.checkBranding() + assert.Equal(t, "https://example.com", b.Branding.WebAdmin.DisclaimerPath) +} + +func TestRedactedConf(t *testing.T) { + c := Conf{ + SigningPassphrase: "passphrase", + Setup: SetupConfig{ + InstallationCode: "123", + }, + } + redactedField := "[redacted]" + redactedConf := c.getRedacted() + assert.Equal(t, redactedField, redactedConf.SigningPassphrase) + assert.Equal(t, redactedField, redactedConf.Setup.InstallationCode) + assert.NotEqual(t, c.SigningPassphrase, redactedConf.SigningPassphrase) + assert.NotEqual(t, c.Setup.InstallationCode, redactedConf.Setup.InstallationCode) +} + +func TestGetRespStatus(t *testing.T) { + var err error + err = util.NewMethodDisabledError("") + respStatus := getRespStatus(err) + assert.Equal(t, http.StatusForbidden, respStatus) + err = fmt.Errorf("generic error") + respStatus = getRespStatus(err) + assert.Equal(t, http.StatusInternalServerError, respStatus) + respStatus = getRespStatus(plugin.ErrNoSearcher) + assert.Equal(t, http.StatusNotImplemented, respStatus) +} + +func TestMappedStatusCode(t *testing.T) { + err := os.ErrPermission + code := getMappedStatusCode(err) + assert.Equal(t, http.StatusForbidden, code) + err = os.ErrNotExist + code = getMappedStatusCode(err) + assert.Equal(t, http.StatusNotFound, code) + err = common.ErrQuotaExceeded + code = getMappedStatusCode(err) + assert.Equal(t, http.StatusRequestEntityTooLarge, code) + err = os.ErrClosed + code = getMappedStatusCode(err) + assert.Equal(t, http.StatusInternalServerError, code) + err = &http.MaxBytesError{} + code = getMappedStatusCode(err) + assert.Equal(t, http.StatusRequestEntityTooLarge, code) +} + +func TestGCSWebInvalidFormFile(t *testing.T) { + form := make(url.Values) + form.Set("username", "test_username") + form.Set("fs_provider", "2") + req, _ := http.NewRequest(http.MethodPost, webUserPath, strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + err := 
req.ParseForm() + assert.NoError(t, err) + _, err = getFsConfigFromPostFields(req) + assert.EqualError(t, err, http.ErrNotMultipart.Error()) +} + +func TestBrandingInvalidFormFile(t *testing.T) { + form := make(url.Values) + req, _ := http.NewRequest(http.MethodPost, webConfigsPath, strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + err := req.ParseForm() + assert.NoError(t, err) + _, err = getBrandingConfigFromPostFields(req, &dataprovider.BrandingConfigs{}) + assert.EqualError(t, err, http.ErrNotMultipart.Error()) +} + +func TestTokenDuration(t *testing.T) { + assert.Equal(t, shareTokenDuration, getTokenDuration(tokenAudienceWebShare)) + assert.Equal(t, apiTokenDuration, getTokenDuration(tokenAudienceAPI)) + assert.Equal(t, apiTokenDuration, getTokenDuration(tokenAudienceAPIUser)) + assert.Equal(t, cookieTokenDuration, getTokenDuration(tokenAudienceWebAdmin)) + assert.Equal(t, csrfTokenDuration, getTokenDuration(tokenAudienceCSRF)) + assert.Equal(t, 20*time.Minute, getTokenDuration("")) + + updateTokensDuration(30, 660, 360) + assert.Equal(t, 30*time.Minute, apiTokenDuration) + assert.Equal(t, 11*time.Hour, cookieTokenDuration) + assert.Equal(t, 11*time.Hour, csrfTokenDuration) + assert.Equal(t, 6*time.Hour, shareTokenDuration) + assert.Equal(t, 11*time.Hour, getMaxCookieDuration()) + + csrfTokenDuration = 1 * time.Hour + assert.Equal(t, 11*time.Hour, getMaxCookieDuration()) +} + +func TestVerifyCSRFToken(t *testing.T) { + server := httpdServer{} + err := server.initializeRouter() + require.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, webAdminEventActionPath, nil) + require.NoError(t, err) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, fs.ErrPermission)) + + rr := httptest.NewRecorder() + tokenString := createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath) + assert.NotEmpty(t, tokenString) + + claims, err := jwt.VerifyToken(server.csrfTokenAuth, 
tokenString) + require.NoError(t, err) + assert.Empty(t, claims.Ref) + + req.Form = url.Values{} + req.Form.Set(csrfFormToken, tokenString) + err = verifyCSRFToken(req, server.csrfTokenAuth) + assert.ErrorIs(t, err, fs.ErrPermission) + + req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, nil) + require.NoError(t, err) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{Claims: josejwt.Claims{ID: xid.New().String()}}, nil)) + req.Form = url.Values{} + req.Form.Set(csrfFormToken, tokenString) + err = verifyCSRFToken(req, server.csrfTokenAuth) + assert.ErrorContains(t, err, "unexpected form token") + + claims = jwt.NewClaims(tokenAudienceCSRF, "", getTokenDuration(tokenAudienceCSRF)) + tokenString, err = josejwt.Signed(server.csrfTokenAuth.Signer()).Claims(claims).Serialize() + assert.NoError(t, err) + req, err = http.NewRequest(http.MethodPost, webAdminEventActionPath, nil) + require.NoError(t, err) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{Claims: josejwt.Claims{ID: xid.New().String()}}, nil)) + req.Form = url.Values{} + req.Form.Set(csrfFormToken, tokenString) + err = verifyCSRFToken(req, server.csrfTokenAuth) + assert.ErrorContains(t, err, "the form token is not valid") +} + +func TestInvalidToken(t *testing.T) { + server := httpdServer{} + err := server.initializeRouter() + require.NoError(t, err) + admin := dataprovider.Admin{ + Username: "admin", + } + errFake := errors.New("fake error") + asJSON, err := json.Marshal(admin) + assert.NoError(t, err) + req, _ := http.NewRequest(http.MethodPut, path.Join(adminPath, admin.Username), bytes.NewBuffer(asJSON)) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("username", admin.Username) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + req = req.WithContext(context.WithValue(req.Context(), jwt.ErrorCtxKey, errFake)) + rr := httptest.NewRecorder() + updateAdmin(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + rr = 
httptest.NewRecorder() + deleteAdmin(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + adminPwd := pwdChange{ + CurrentPassword: "old", + NewPassword: "new", + } + asJSON, err = json.Marshal(adminPwd) + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodPut, "", bytes.NewBuffer(asJSON)) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + req = req.WithContext(context.WithValue(req.Context(), jwt.ErrorCtxKey, errFake)) + rr = httptest.NewRecorder() + changeAdminPassword(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + adm := getAdminFromToken(req) + assert.Empty(t, adm.Username) + + rr = httptest.NewRecorder() + readUserFolder(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getUserFile(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getUserFilesAsZipStream(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getShares(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getShareByID(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + addShare(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateShare(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + deleteShare(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid 
token claims") + + rr = httptest.NewRecorder() + generateTOTPSecret(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + saveTOTPConfig(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getRecoveryCodes(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + generateRecoveryCodes(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getUserProfile(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateUserProfile(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getWebTask(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getAdminProfile(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateAdminProfile(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + loadData(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + loadDataFromRequest(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + addUser(rr, req) + assert.Equal(t, 
http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + disableUser2FA(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateUser(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + deleteUser(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getActiveConnections(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + handleCloseConnection(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + server.handleWebRestore(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebAddUserPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebUpdateUserPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebTemplateFolderPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebTemplateUserPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + getAllAdmins(rr, req) + assert.Equal(t, http.StatusForbidden, 
rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + getAllUsers(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + addFolder(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateFolder(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getFolderByName(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + deleteFolder(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + server.handleWebAddFolderPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebUpdateFolderPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebGetConnections(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebConfigsPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + addAdmin(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + disableAdmin2FA(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), 
"Invalid token claims") + + rr = httptest.NewRecorder() + addAPIKey(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateAPIKey(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + deleteAPIKey(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + addGroup(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateGroup(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getGroupByName(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + deleteGroup(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + addEventAction(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getEventActionByName(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateEventAction(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + deleteEventAction(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getEventRuleByName(rr, req) + assert.Equal(t, http.StatusBadRequest, 
rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + addEventRule(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateEventRule(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + deleteEventRule(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getUsersQuotaScans(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateUserTransferQuotaUsage(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + doUpdateUserQuotaUsage(rr, req, "", quotaUsage{}) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + doStartUserQuotaScan(rr, req, "") + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getRetentionChecks(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + addRole(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateRole(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + deleteRole(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + 
rr = httptest.NewRecorder() + getUsers(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + getUserByUsername(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + searchFsEvents(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + searchProviderEvents(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + searchLogEvents(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + addIPListEntry(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + updateIPListEntry(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + deleteIPListEntry(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") + + rr = httptest.NewRecorder() + server.handleGetWebUsers(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebUpdateUserGet(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebUpdateRolePost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + 
server.handleWebAddRolePost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebAddAdminPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebAddGroupPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebUpdateGroupPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebAddEventActionPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebUpdateEventActionPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebAddEventRulePost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebUpdateEventRulePost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebUpdateIPListEntryPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + server.handleWebClientTwoFactorRecoveryPost(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + server.handleWebClientTwoFactorPost(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + 
server.handleWebAdminTwoFactorRecoveryPost(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + server.handleWebAdminTwoFactorPost(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + server.handleWebUpdateIPListEntryPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + form := make(url.Values) + req, _ = http.NewRequest(http.MethodPost, webIPListPath+"/1", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rctx = chi.NewRouteContext() + rctx.URLParams.Add("type", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + rr = httptest.NewRecorder() + server.handleWebAddIPListEntryPost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) +} + +func TestTokenSignatureValidation(t *testing.T) { + tokenValidationMode = 0 + server := httpdServer{ + binding: Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: true, + EnableRESTAPI: true, + }, + enableWebAdmin: true, + enableWebClient: true, + enableRESTAPI: true, + } + err := server.initializeRouter() + require.NoError(t, err) + testServer := httptest.NewServer(server.router) + defer testServer.Close() + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, tokenPath, nil) + require.NoError(t, err) + req.SetBasicAuth(defaultAdminUsername, defaultAdminPass) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + var resp map[string]any + err = json.Unmarshal(rr.Body.Bytes(), &resp) + assert.NoError(t, err) + accessToken := resp["access_token"] + require.NotEmpty(t, accessToken) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + require.NoError(t, err) + 
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + // change the token validation mode + tokenValidationMode = 2 + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + require.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + // Now update the admin + admin, err := dataprovider.AdminExists(defaultAdminUsername) + assert.NoError(t, err) + err = dataprovider.UpdateAdmin(&admin, "", "", "") + assert.NoError(t, err) + // token validation mode is 0, the old token is still valid + tokenValidationMode = 0 + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + require.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + // change the token validation mode + tokenValidationMode = 2 + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + require.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + // the token is invalidated, changing the validation mode has no effect + tokenValidationMode = 0 + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, versionPath, nil) + require.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + + userPwd := "pwd" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: defeaultUsername, + Password: userPwd, + HomeDir: filepath.Join(os.TempDir(), defeaultUsername), + 
Status: 1, + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + + defer func() { + dataprovider.DeleteUser(defeaultUsername, "", "", "") //nolint:errcheck + }() + + tokenValidationMode = 2 + req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil) + require.NoError(t, err) + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + loginCookie := strings.Split(rr.Header().Get("Set-Cookie"), ";")[0] + assert.NotEmpty(t, loginCookie) + csrfToken, err := getCSRFTokenFromBody(rr.Body) + assert.NoError(t, err) + assert.NotEmpty(t, csrfToken) + // Now login + form := make(url.Values) + form.Set(csrfFormToken, csrfToken) + form.Set("username", defeaultUsername) + form.Set("password", userPwd) + req, err = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.Header.Set("Cookie", loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code) + userCookie := strings.Split(rr.Header().Get("Set-Cookie"), ";")[0] + assert.NotEmpty(t, userCookie) + // Test a WebClient page and a JSON API + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + require.NoError(t, err) + req.Header.Set("Cookie", userCookie) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil) + require.NoError(t, err) + req.Header.Set("Cookie", userCookie) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + csrfToken, err = getCSRFTokenFromBody(rr.Body) + assert.NoError(t, err) + 
assert.NotEmpty(t, csrfToken) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, webClientFilePath+"?path=missing.txt", nil) + require.NoError(t, err) + req.Header.Set("Cookie", userCookie) + req.Header.Set(csrfHeaderToken, csrfToken) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + tokenValidationMode = 0 + err = dataprovider.DeleteUser(defeaultUsername, "", "", "") + assert.NoError(t, err) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, webClientFilePath+"?path=missing.txt", nil) + require.NoError(t, err) + req.Header.Set("Cookie", userCookie) + req.Header.Set(csrfHeaderToken, csrfToken) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + tokenValidationMode = 2 + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, webClientFilePath+"?path=missing.txt", nil) + require.NoError(t, err) + req.Header.Set("Cookie", userCookie) + req.Header.Set(csrfHeaderToken, csrfToken) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code) + + tokenValidationMode = 0 +} + +func TestUpdateWebAdminInvalidClaims(t *testing.T) { + server := httpdServer{} + err := server.initializeRouter() + require.NoError(t, err) + + rr := httptest.NewRecorder() + admin := dataprovider.Admin{ + Username: "", + Password: "password", + } + c := &jwt.Claims{ + Username: admin.Username, + Permissions: admin.Permissions, + } + c.Subject = admin.GetSignature() + token, err := server.tokenAuth.SignWithParams(c, tokenAudienceWebAdmin, "", 10*time.Minute) + assert.NoError(t, err) + resp := c.BuildTokenResponse(token) + + req, err := http.NewRequest(http.MethodGet, webAdminPath, nil) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", resp.Token)) + parsedToken, err := jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx := req.Context() + ctx = 
jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form := make(url.Values) + form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath)) + form.Set("status", "1") + form.Set("default_users_expiration", "30") + req, err = http.NewRequest(http.MethodPost, path.Join(webAdminPath, "admin"), bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("username", "admin") + req = req.WithContext(ctx) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", resp.Token)) + server.handleWebUpdateAdminPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) +} + +func TestUpdateSMTPSecrets(t *testing.T) { + currentConfigs := &dataprovider.SMTPConfigs{ + OAuth2: dataprovider.SMTPOAuth2{ + ClientSecret: kms.NewPlainSecret("client secret"), + RefreshToken: kms.NewPlainSecret("refresh token"), + }, + } + redactedClientSecret := kms.NewPlainSecret("secret") + redactedRefreshToken := kms.NewPlainSecret("token") + redactedClientSecret.SetStatus(sdkkms.SecretStatusRedacted) + redactedRefreshToken.SetStatus(sdkkms.SecretStatusRedacted) + newConfigs := &dataprovider.SMTPConfigs{ + Password: kms.NewPlainSecret("pwd"), + OAuth2: dataprovider.SMTPOAuth2{ + ClientSecret: redactedClientSecret, + RefreshToken: redactedRefreshToken, + }, + } + updateSMTPSecrets(newConfigs, currentConfigs) + assert.Nil(t, currentConfigs.Password) + assert.NotNil(t, newConfigs.Password) + assert.Equal(t, currentConfigs.OAuth2.ClientSecret, newConfigs.OAuth2.ClientSecret) + assert.Equal(t, currentConfigs.OAuth2.RefreshToken, newConfigs.OAuth2.RefreshToken) + + clientSecret := kms.NewPlainSecret("plain secret") + refreshToken := kms.NewPlainSecret("plain token") + newConfigs = 
&dataprovider.SMTPConfigs{ + Password: kms.NewPlainSecret("pwd"), + OAuth2: dataprovider.SMTPOAuth2{ + ClientSecret: clientSecret, + RefreshToken: refreshToken, + }, + } + updateSMTPSecrets(newConfigs, currentConfigs) + assert.Equal(t, clientSecret, newConfigs.OAuth2.ClientSecret) + assert.Equal(t, refreshToken, newConfigs.OAuth2.RefreshToken) +} + +func TestOAuth2Redirect(t *testing.T) { + server := httpdServer{} + err := server.initializeRouter() + require.NoError(t, err) + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, webOAuth2RedirectPath+"?state=invalid", nil) + assert.NoError(t, err) + server.handleOAuth2TokenRedirect(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nOAuth2ErrorTitle) + + ip := "127.1.1.4" + tokenString := createOAuth2Token(server.csrfTokenAuth, xid.New().String(), ip) + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, webOAuth2RedirectPath+"?state="+tokenString, nil) //nolint:goconst + assert.NoError(t, err) + req.RemoteAddr = ip + server.handleOAuth2TokenRedirect(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nOAuth2ErrorValidateState) +} + +func TestOAuth2Token(t *testing.T) { + server := httpdServer{} + err := server.initializeRouter() + require.NoError(t, err) + // invalid token + _, err = verifyOAuth2Token(server.csrfTokenAuth, "token", "") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to verify OAuth2 state") + } + // bad audience + claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + + tokenString, err := server.csrfTokenAuth.Sign(claims) + assert.NoError(t, err) + _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid OAuth2 state") + } + // bad IP + tokenString = createOAuth2Token(server.csrfTokenAuth, "state", 
"127.1.1.1") + _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.2") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid OAuth2 state") + } + // ok + state := xid.New().String() + tokenString = createOAuth2Token(server.csrfTokenAuth, state, "127.1.1.3") + s, err := verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.3") + assert.NoError(t, err) + assert.Equal(t, state, s) + // no jti + claims = jwt.NewClaims(tokenAudienceOAuth2, "127.1.1.4", getTokenDuration(tokenAudienceOAuth2)) + tokenString, err = josejwt.Signed(server.csrfTokenAuth.Signer()).Claims(claims).Serialize() + assert.NoError(t, err) + _, err = verifyOAuth2Token(server.csrfTokenAuth, tokenString, "127.1.1.4") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid OAuth2 state") + } + // encode error + server.csrfTokenAuth.SetSigner(&failingJoseSigner{}) + tokenString = createOAuth2Token(server.csrfTokenAuth, xid.New().String(), "") + assert.Empty(t, tokenString) + + rr := httptest.NewRecorder() + testReq := make(map[string]any) + testReq["base_redirect_url"] = "http://localhost:8082" + asJSON, err := json.Marshal(testReq) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodPost, webOAuth2TokenPath, bytes.NewBuffer(asJSON)) + assert.NoError(t, err) + server.handleSMTPOAuth2TokenRequestPost(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), "unable to create state token") +} + +func TestCSRFToken(t *testing.T) { + server := httpdServer{} + err := server.initializeRouter() + require.NoError(t, err) + // invalid token + req := &http.Request{} + err = verifyCSRFToken(req, server.csrfTokenAuth) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to verify form token") + } + // bad audience + claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + tokenString, err := server.csrfTokenAuth.Sign(claims) + assert.NoError(t, err) + 
values := url.Values{} + values.Set(csrfFormToken, tokenString) + req.Form = values + err = verifyCSRFToken(req, server.csrfTokenAuth) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "form token is not valid") + } + + // bad IP + req.RemoteAddr = "127.1.1.1" + tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath) + values.Set(csrfFormToken, tokenString) + req.Form = values + req.RemoteAddr = "127.1.1.2" + err = verifyCSRFToken(req, server.csrfTokenAuth) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "form token is not valid") + } + + claims = jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + tokenString, err = server.csrfTokenAuth.Sign(claims) + assert.NoError(t, err) + assert.NotEmpty(t, tokenString) + + r, err := GetHTTPRouter(Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: true, + EnableRESTAPI: true, + RenderOpenAPI: true, + }) + assert.NoError(t, err) + fn := server.verifyCSRFHeader(r) + rr := httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodDelete, path.Join(userPath, "username"), nil) + fn.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token") + + // invalid audience + req.Header.Set(csrfHeaderToken, tokenString) + rr = httptest.NewRecorder() + fn.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), "the token is not valid") + + // invalid IP + tokenString = createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath) + req.Header.Set(csrfHeaderToken, tokenString) + req.RemoteAddr = "172.16.1.2" + rr = httptest.NewRecorder() + fn.ServeHTTP(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), "the token is not valid") + + csrfTokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) 
+ csrfTokenAuth.SetSigner(&failingJoseSigner{}) + tokenString = createCSRFToken(httptest.NewRecorder(), req, csrfTokenAuth, "", webBaseAdminPath) + assert.Empty(t, tokenString) + rr = httptest.NewRecorder() + createLoginCookie(rr, req, csrfTokenAuth, "", webBaseAdminPath, req.RemoteAddr) + assert.Empty(t, rr.Header().Get("Set-Cookie")) +} + +func TestCreateShareCookieError(t *testing.T) { + username := "share_user" + pwd := util.GenerateUniqueID() + user := &dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: pwd, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + err := dataprovider.AddUser(user, "", "", "") + assert.NoError(t, err) + share := &dataprovider.Share{ + Name: "test_share_cookie_error", + ShareID: util.GenerateUniqueID(), + Scope: dataprovider.ShareScopeRead, + Password: pwd, + Paths: []string{"/"}, + Username: username, + } + err = dataprovider.AddShare(share, "", "", "") + assert.NoError(t, err) + + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + tokenAuth.SetSigner(&failingJoseSigner{}) + csrfTokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + + server := httpdServer{ + tokenAuth: tokenAuth, + csrfTokenAuth: csrfTokenAuth, + } + + c := jwt.NewClaims(tokenAudienceWebLogin, "127.0.0.1", getTokenDuration(tokenAudienceWebLogin)) + token, err := server.csrfTokenAuth.Sign(c) + assert.NoError(t, err) + resp := c.BuildTokenResponse(token) + parsedToken, err := jwt.VerifyToken(server.csrfTokenAuth, resp.Token) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, path.Join(webClientPubSharesPath, share.ShareID, "login"), nil) + assert.NoError(t, err) + req.RemoteAddr = "127.0.0.1:4567" + ctx := req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form := make(url.Values) + 
form.Set("share_password", pwd) + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseClientPath)) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("id", share.ShareID) + rr := httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, path.Join(webClientPubSharesPath, share.ShareID, "login"), + bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = "127.0.0.1:2345" + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", resp.Token)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req = req.WithContext(ctx) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + server.handleClientShareLoginPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nError500Message) + + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) +} + +func TestCreateTokenError(t *testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + tokenAuth.SetSigner(&failingJoseSigner{}) + csrfTokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + + server := httpdServer{ + tokenAuth: tokenAuth, + csrfTokenAuth: csrfTokenAuth, + } + rr := httptest.NewRecorder() + admin := dataprovider.Admin{ + Username: defaultAdminUsername, + Password: "password", + } + req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) + + server.generateAndSendToken(rr, req, admin, "") + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + rr = httptest.NewRecorder() + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "u", + Password: util.GenerateUniqueID(), + }, + } + req, _ = http.NewRequest(http.MethodGet, userTokenPath, nil) + + server.generateAndSendUserToken(rr, req, "", user) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + c := &jwt.Claims{} + c.ID = 
xid.New().String() + c.SetExpiry(time.Now().Add(1 * time.Minute)) + tokenString, err := server.csrfTokenAuth.SignWithParams(c, tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + assert.NoError(t, err) + token := c.BuildTokenResponse(tokenString) + + req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token.Token)) + parsedToken, err := jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx := req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + rr = httptest.NewRecorder() + form := make(url.Values) + form.Set("username", admin.Username) + form.Set("password", admin.Password) + form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, xid.New().String(), webBaseAdminPath)) + cookie := rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, cookie) + req, _ = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Cookie", cookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + server.handleWebAdminLoginPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + // req with no content type + req, _ = http.NewRequest(http.MethodPost, webAdminLoginPath, nil) + rr = httptest.NewRecorder() + server.handleWebAdminLoginPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + req, _ = http.NewRequest(http.MethodPost, webAdminSetupPath, nil) + rr = httptest.NewRecorder() + server.loginAdmin(rr, req, &admin, false, nil, "") + // req with no POST body + req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%AO%GG", nil) + req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebAdminLoginPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + + req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%A1%G2", nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebAdminChangePwdPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodGet, webAdminLoginPath+"?a=a%C3%A2%G3", nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + _, err = getAdminFromPostFields(req) + assert.Error(t, err) + + req, _ = http.NewRequest(http.MethodPost, webAdminEventActionPath+"?a=a%C3%A2%GG", nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + _, err = getEventActionFromPostFields(req) + assert.Error(t, err) + + req, _ = http.NewRequest(http.MethodPost, webAdminEventRulePath+"?a=a%C3%A3%GG", nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + _, err = getEventRuleFromPostFields(req) + assert.Error(t, err) + + req, _ = http.NewRequest(http.MethodPost, webIPListPath+"/1?a=a%C3%AO%GG", nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + _, err = getIPListEntryFromPostFields(req, dataprovider.IPListTypeAllowList) + assert.Error(t, err) + + req, _ = http.NewRequest(http.MethodPost, path.Join(webClientSharePath, "shareID", "login?a=a%C3%AO%GG"), bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleClientShareLoginPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + + req, _ = http.NewRequest(http.MethodPost, webClientLoginPath+"?a=a%C3%AO%GG", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebClientLoginPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + + req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath+"?a=a%C3%AO%GA", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebClientChangePwdPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath+"?a=a%C3%AO%GB", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebClientProfilePost(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code, rr.Body.String()) + + req, _ = http.NewRequest(http.MethodPost, webAdminProfilePath+"?a=a%C3%AO%GB", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebAdminProfilePost(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code, rr.Body.String()) + + req, _ = http.NewRequest(http.MethodPost, webAdminTwoFactorPath+"?a=a%C3%AO%GC", bytes.NewBuffer([]byte(form.Encode()))) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebAdminTwoFactorPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webAdminTwoFactorRecoveryPath+"?a=a%C3%AO%GD", bytes.NewBuffer([]byte(form.Encode()))) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil)) + req.Header.Set("Content-Type", 
"application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebAdminTwoFactorRecoveryPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webClientTwoFactorPath+"?a=a%C3%AO%GC", bytes.NewBuffer([]byte(form.Encode()))) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebClientTwoFactorPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webClientTwoFactorRecoveryPath+"?a=a%C3%AO%GD", bytes.NewBuffer([]byte(form.Encode()))) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{}, nil)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebClientTwoFactorRecoveryPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webAdminForgotPwdPath+"?a=a%C3%A1%GD", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebAdminForgotPwdPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webClientForgotPwdPath+"?a=a%C2%A1%GD", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebClientForgotPwdPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), 
util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webAdminResetPwdPath+"?a=a%C3%AO%JD", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebAdminPasswordResetPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webAdminRolePath+"?a=a%C3%AO%JE", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebAddRolePost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webClientResetPwdPath+"?a=a%C3%AO%JD", bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + server.handleWebClientPasswordResetPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidForm) + + req, _ = http.NewRequest(http.MethodPost, webChangeClientPwdPath+"?a=a%K3%AO%GA", bytes.NewBuffer([]byte(form.Encode()))) + _, err = getShareFromPostFields(req) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid URL escape") + } + + username := "webclientuser" + user = dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: "clientpwd", + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + Description: "test user", + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + user.Filters.AllowAPIKeyAuth = true + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil) + 
assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token.Token)) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + rr = httptest.NewRecorder() + form = make(url.Values) + form.Set("username", user.Username) + form.Set("password", "clientpwd") + form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath)) + req, _ = http.NewRequest(http.MethodPost, webClientLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleWebClientLoginPost(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + + err = authenticateUserWithAPIKey(username, "", server.tokenAuth, req) + assert.Error(t, err) + + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.HomeDir) + assert.NoError(t, err) + + admin.Username += "1" + admin.Status = 1 + admin.Filters.AllowAPIKeyAuth = true + admin.Permissions = []string{dataprovider.PermAdminAny} + err = dataprovider.AddAdmin(&admin, "", "", "") + assert.NoError(t, err) + + err = authenticateAdminWithAPIKey(admin.Username, "", server.tokenAuth, req) + assert.Error(t, err) + + err = dataprovider.DeleteAdmin(admin.Username, "", "", "") + assert.NoError(t, err) +} + +func TestAPIKeyAuthForbidden(t *testing.T) { + r, err := GetHTTPRouter(Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: true, + EnableRESTAPI: true, + RenderOpenAPI: true, + }) + require.NoError(t, err) + fn := forbidAPIKeyAuthentication(r) + rr := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, versionPath, nil) + fn.ServeHTTP(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "Invalid token claims") +} + +func TestJWTTokenValidation(t 
*testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + claims := &jwt.Claims{ + Username: defaultAdminUsername, + } + claims.SetExpiry(time.Now().UTC().Add(-1 * time.Hour)) + _, err = tokenAuth.SignWithParams(claims, tokenAudienceWebAdmin, "", getTokenDuration(tokenAudienceWebAdmin)) + require.NoError(t, err) + + server := httpdServer{ + binding: Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: true, + EnableRESTAPI: true, + RenderOpenAPI: true, + }, + } + err = server.initializeRouter() + require.NoError(t, err) + r := server.router + fn := jwtAuthenticatorAPI(r) + rr := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, userPath, nil) + ctx := jwt.NewContext(req.Context(), claims, nil) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + + fn = jwtAuthenticatorWebAdmin(r) + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webUserPath, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) + + fn = jwtAuthenticatorWebClient(r) + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + + errTest := errors.New("test error") + permFn := server.checkPerms(dataprovider.PermAdminAny) + fn = permFn(r) + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, userPath, nil) + ctx = jwt.NewContext(req.Context(), claims, errTest) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + permFn = server.checkPerms(dataprovider.PermAdminAny) + fn = permFn(r) + rr 
= httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webUserPath, nil) + req.RequestURI = webUserPath + ctx = jwt.NewContext(req.Context(), claims, errTest) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + permClientFn := server.checkHTTPUserPerm(sdk.WebClientPubKeyChangeDisabled) + fn = permClientFn(r) + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, nil) + req.RequestURI = webClientProfilePath + ctx = jwt.NewContext(req.Context(), claims, errTest) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPost, userProfilePath, nil) + req.RequestURI = userProfilePath + ctx = jwt.NewContext(req.Context(), claims, errTest) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + fn = server.checkAuthRequirements(r) + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPost, webClientProfilePath, nil) + req.RequestURI = webClientProfilePath + ctx = jwt.NewContext(req.Context(), claims, errTest) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + fn = server.checkAuthRequirements(r) + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPost, webGroupsPath, nil) + req.RequestURI = webGroupsPath + ctx = jwt.NewContext(req.Context(), claims, errTest) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusBadRequest, rr.Code) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPost, userSharesPath, nil) + req.RequestURI = userSharesPath + ctx = jwt.NewContext(req.Context(), claims, errTest) + fn.ServeHTTP(rr, req.WithContext(ctx)) + assert.Equal(t, http.StatusBadRequest, rr.Code) +} + +func TestUpdateContextFromCookie(t *testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + 
require.NoError(t, err) + server := httpdServer{ + tokenAuth: tokenAuth, + } + req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) + claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + _, err = server.tokenAuth.Sign(claims) + assert.NoError(t, err) + + ctx := jwt.NewContext(req.Context(), claims, nil) + req = server.updateContextFromCookie(req.WithContext(ctx)) + token, err := jwt.FromContext(req.Context()) + require.NoError(t, err) + require.True(t, token.Audience.Contains(tokenAudienceWebClient)) + require.NotEmpty(t, token.ID) +} + +func TestCookieExpiration(t *testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + server := httpdServer{ + tokenAuth: tokenAuth, + } + err = errors.New("test error") + rr := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, tokenPath, nil) + ctx := jwt.NewContext(req.Context(), nil, err) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie := rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) + claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + _, err = server.tokenAuth.Sign(claims) + assert.NoError(t, err) + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + admin := dataprovider.Admin{ + Username: "newtestadmin", + Password: "password", + Permissions: []string{dataprovider.PermAdminAny}, + } + claims = jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + claims.Username = admin.Username + claims.Permissions = admin.Permissions + claims.Subject = admin.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + _, err = server.tokenAuth.Sign(claims) + assert.NoError(t, err) + + req, _ = http.NewRequest(http.MethodGet, tokenPath, 
nil) + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + admin.Status = 0 + err = dataprovider.AddAdmin(&admin, "", "", "") + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + admin.Status = 1 + admin.Filters.AllowList = []string{"172.16.1.0/24"} + err = dataprovider.UpdateAdmin(&admin, "", "", "") + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + admin, err = dataprovider.AdminExists(admin.Username) + assert.NoError(t, err) + tokenID := xid.New().String() + claims = jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + claims.ID = tokenID + claims.Username = admin.Username + claims.Permissions = admin.Permissions + claims.Subject = admin.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + _, err = server.tokenAuth.Sign(claims) + assert.NoError(t, err) + + req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) + req.RemoteAddr = "192.168.8.1:1234" + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + req, _ = http.NewRequest(http.MethodGet, tokenPath, nil) + req.RemoteAddr = "172.16.1.12:4567" + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.True(t, strings.HasPrefix(cookie, "jwt=")) + req.Header.Set("Cookie", cookie) + c, err := 
jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) + if assert.NoError(t, err) { + assert.Equal(t, tokenID, c.ID) + } + + err = dataprovider.DeleteAdmin(admin.Username, "", "", "") + assert.NoError(t, err) + // now check client cookie expiration + username := "client" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: "clientpwd", + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + Description: "test user", + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{"*"} + + claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.ID = tokenID + claims.Username = user.Username + claims.Permissions = user.Filters.WebClient + claims.Subject = user.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + _, err = server.tokenAuth.Sign(claims) + assert.NoError(t, err) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + // the password will be hashed and so the signature will change + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + user, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + user.Filters.AllowedIP = []string{"172.16.4.0/24"} + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + + user, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + issuedAt := time.Now().Add(-1 * time.Minute) + expiresAt := time.Now().Add(1 * time.Minute) + + claims = 
jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.ID = tokenID + claims.Username = user.Username + claims.Permissions = user.Filters.WebClient + claims.Subject = user.GetSignature() + claims.SetExpiry(expiresAt) + claims.SetIssuedAt(issuedAt) + _, err = server.tokenAuth.Sign(claims) + assert.NoError(t, err) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.RemoteAddr = "172.16.3.12:4567" + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.RemoteAddr = "172.16.4.16:4567" + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, cookie) + req.Header.Set("Cookie", cookie) + c, err = jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) + if assert.NoError(t, err) { + assert.Equal(t, tokenID, c.ID) + assert.Equal(t, issuedAt.Unix(), c.IssuedAt.Time().Unix()) + assert.NotEqual(t, expiresAt.Unix(), c.Expiry.Time().Unix()) + } + // test a cookie issued more that 12 hours ago + claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.ID = tokenID + claims.Username = user.Username + claims.Permissions = user.Filters.WebClient + claims.Subject = user.GetSignature() + claims.SetExpiry(expiresAt) + claims.SetIssuedAt(time.Now().Add(-24 * time.Hour)) + _, err = server.tokenAuth.Sign(claims) + assert.NoError(t, err) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.RemoteAddr = "172.16.4.16:6789" + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + // test a disabled user + 
user.Status = 0 + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + user, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + + claims = jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.ID = tokenID + claims.Username = user.Username + claims.Permissions = user.Filters.WebClient + claims.Subject = user.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + claims.SetIssuedAt(issuedAt) + _, err = server.tokenAuth.Sign(claims) + assert.NoError(t, err) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + ctx = jwt.NewContext(req.Context(), claims, nil) + server.checkCookieExpiration(rr, req.WithContext(ctx)) + cookie = rr.Header().Get("Set-Cookie") + assert.Empty(t, cookie) + + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) +} + +func TestGetURLParam(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, adminPwdPath, nil) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("val", "testuser%C3%A0") + rctx.URLParams.Add("inval", "testuser%C3%AO%GG") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + escaped := getURLParam(req, "val") + assert.Equal(t, "testuserà", escaped) + escaped = getURLParam(req, "inval") + assert.Equal(t, "testuser%C3%AO%GG", escaped) +} + +func TestChangePwdValidationErrors(t *testing.T) { + err := doChangeAdminPassword(nil, "", "", "") + require.Error(t, err) + err = doChangeAdminPassword(nil, "a", "b", "c") + require.Error(t, err) + err = doChangeAdminPassword(nil, "a", "a", "a") + require.Error(t, err) + + req, _ := http.NewRequest(http.MethodPut, adminPwdPath, nil) + req = req.WithContext(jwt.NewContext(req.Context(), &jwt.Claims{Claims: josejwt.Claims{ID: xid.New().String()}}, nil)) + err = doChangeAdminPassword(req, "currentpwd", "newpwd", "newpwd") + assert.Error(t, err) +} + +func TestRenderUnexistingFolder(t 
*testing.T) { + rr := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, folderPath, nil) + renderFolder(rr, req, "path not mapped", &jwt.Claims{}, http.StatusOK) + assert.Equal(t, http.StatusNotFound, rr.Code) +} + +func TestCloseConnectionHandler(t *testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + claims.Username = defaultAdminUsername + claims.SetExpiry(time.Now().UTC().Add(1 * time.Hour)) + _, err = tokenAuth.Sign(claims) + assert.NoError(t, err) + req, err := http.NewRequest(http.MethodDelete, activeConnectionsPath+"/connectionID", nil) + assert.NoError(t, err) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("connectionID", "") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + req = req.WithContext(context.WithValue(req.Context(), jwt.TokenCtxKey, claims)) + rr := httptest.NewRecorder() + handleCloseConnection(rr, req) + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "connectionID is mandatory") +} + +func TestRenderInvalidTemplate(t *testing.T) { + tmpl, err := template.New("test").Parse("{{.Count}}") + if assert.NoError(t, err) { + noMatchTmpl := "no_match" + adminTemplates[noMatchTmpl] = tmpl + rw := httptest.NewRecorder() + renderAdminTemplate(rw, noMatchTmpl, map[string]string{}) + assert.Equal(t, http.StatusInternalServerError, rw.Code) + clientTemplates[noMatchTmpl] = tmpl + renderClientTemplate(rw, noMatchTmpl, map[string]string{}) + assert.Equal(t, http.StatusInternalServerError, rw.Code) + } +} + +func TestQuotaScanInvalidFs(t *testing.T) { + user := &dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "test", + HomeDir: os.TempDir(), + }, + FsConfig: vfs.Filesystem{ + Provider: sdk.S3FilesystemProvider, + }, + } + common.QuotaScans.AddUserQuotaScan(user.Username, "") + err := doUserQuotaScan(user) + 
assert.Error(t, err) +} + +func TestVerifyTLSConnection(t *testing.T) { + oldCertMgr := certMgr + + caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt") + certPath := filepath.Join(os.TempDir(), "testh.crt") + keyPath := filepath.Join(os.TempDir(), "testh.key") + err := os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(certPath, []byte(httpdCert), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(keyPath, []byte(httpdKey), os.ModePerm) + assert.NoError(t, err) + + keyPairs := []common.TLSKeyPair{ + { + Cert: certPath, + Key: keyPath, + ID: common.DefaultTLSKeyPaidID, + }, + } + certMgr, err = common.NewCertManager(keyPairs, "", "httpd_test") + assert.NoError(t, err) + + certMgr.SetCARevocationLists([]string{caCrlPath}) + err = certMgr.LoadCRLs() + assert.NoError(t, err) + + crt, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + x509crt, err := x509.ParseCertificate(crt.Certificate[0]) + assert.NoError(t, err) + + server := httpdServer{} + state := tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{x509crt}, + } + + err = server.verifyTLSConnection(state) + assert.Error(t, err) // no verified certification chain + + crt, err = tls.X509KeyPair([]byte(caCRT), []byte(caKey)) + assert.NoError(t, err) + + x509CAcrt, err := x509.ParseCertificate(crt.Certificate[0]) + assert.NoError(t, err) + + state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crt, x509CAcrt}) + err = server.verifyTLSConnection(state) + assert.NoError(t, err) + + crt, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) + assert.NoError(t, err) + x509crtRevoked, err := x509.ParseCertificate(crt.Certificate[0]) + assert.NoError(t, err) + + state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crtRevoked, x509CAcrt}) + state.PeerCertificates = []*x509.Certificate{x509crtRevoked} + err = server.verifyTLSConnection(state) + assert.EqualError(t, 
err, common.ErrCrtRevoked.Error()) + + err = os.Remove(caCrlPath) + assert.NoError(t, err) + err = os.Remove(certPath) + assert.NoError(t, err) + err = os.Remove(keyPath) + assert.NoError(t, err) + + certMgr = oldCertMgr +} + +func TestGetFolderFromTemplate(t *testing.T) { + folder := vfs.BaseVirtualFolder{ + MappedPath: "Folder%name%", + Description: "Folder %name% desc", + } + folderName := "folderTemplate" + folderTemplate := getFolderFromTemplate(folder, folderName) + require.Equal(t, folderName, folderTemplate.Name) + require.Equal(t, fmt.Sprintf("Folder%v", folderName), folderTemplate.MappedPath) + require.Equal(t, fmt.Sprintf("Folder %v desc", folderName), folderTemplate.Description) + + folder.FsConfig.Provider = sdk.CryptedFilesystemProvider + folder.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("%name%") + folderTemplate = getFolderFromTemplate(folder, folderName) + require.Equal(t, folderName, folderTemplate.FsConfig.CryptConfig.Passphrase.GetPayload()) + + folder.FsConfig.Provider = sdk.GCSFilesystemProvider + folder.FsConfig.GCSConfig.KeyPrefix = "prefix%name%/" + folderTemplate = getFolderFromTemplate(folder, folderName) + require.Equal(t, fmt.Sprintf("prefix%v/", folderName), folderTemplate.FsConfig.GCSConfig.KeyPrefix) + + folder.FsConfig.Provider = sdk.AzureBlobFilesystemProvider + folder.FsConfig.AzBlobConfig.KeyPrefix = "a%name%" + folder.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("pwd%name%") + folderTemplate = getFolderFromTemplate(folder, folderName) + require.Equal(t, "a"+folderName, folderTemplate.FsConfig.AzBlobConfig.KeyPrefix) + require.Equal(t, "pwd"+folderName, folderTemplate.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + + folder.FsConfig.Provider = sdk.SFTPFilesystemProvider + folder.FsConfig.SFTPConfig.Prefix = "%name%" + folder.FsConfig.SFTPConfig.Username = "sftp_%name%" + folder.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp%name%") + folderTemplate = getFolderFromTemplate(folder, folderName) + 
require.Equal(t, folderName, folderTemplate.FsConfig.SFTPConfig.Prefix) + require.Equal(t, "sftp_"+folderName, folderTemplate.FsConfig.SFTPConfig.Username) + require.Equal(t, "sftp"+folderName, folderTemplate.FsConfig.SFTPConfig.Password.GetPayload()) +} + +func TestGetUserFromTemplate(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Status: 1, + }, + } + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: "Folder%username%", + }, + }) + + username := "userTemplate" + password := "pwdTemplate" + templateFields := userTemplateFields{ + Username: username, + Password: password, + } + + userTemplate := getUserFromTemplate(user, templateFields) + require.Len(t, userTemplate.VirtualFolders, 1) + require.Equal(t, "Folder"+username, userTemplate.VirtualFolders[0].Name) + + user.FsConfig.Provider = sdk.CryptedFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("%password%") + userTemplate = getUserFromTemplate(user, templateFields) + require.Equal(t, password, userTemplate.FsConfig.CryptConfig.Passphrase.GetPayload()) + + user.FsConfig.Provider = sdk.GCSFilesystemProvider + user.FsConfig.GCSConfig.KeyPrefix = "%username%%password%" + userTemplate = getUserFromTemplate(user, templateFields) + require.Equal(t, username+password, userTemplate.FsConfig.GCSConfig.KeyPrefix) + + user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider + user.FsConfig.AzBlobConfig.KeyPrefix = "a%username%" + user.FsConfig.AzBlobConfig.AccountKey = kms.NewPlainSecret("pwd%password%%username%") + userTemplate = getUserFromTemplate(user, templateFields) + require.Equal(t, "a"+username, userTemplate.FsConfig.AzBlobConfig.KeyPrefix) + require.Equal(t, "pwd"+password+username, userTemplate.FsConfig.AzBlobConfig.AccountKey.GetPayload()) + + user.FsConfig.Provider = sdk.SFTPFilesystemProvider + user.FsConfig.SFTPConfig.Prefix = "%username%" + user.FsConfig.SFTPConfig.Username = 
"sftp_%username%" + user.FsConfig.SFTPConfig.Password = kms.NewPlainSecret("sftp%password%") + userTemplate = getUserFromTemplate(user, templateFields) + require.Equal(t, username, userTemplate.FsConfig.SFTPConfig.Prefix) + require.Equal(t, "sftp_"+username, userTemplate.FsConfig.SFTPConfig.Username) + require.Equal(t, "sftp"+password, userTemplate.FsConfig.SFTPConfig.Password.GetPayload()) +} + +func TestJWTTokenCleanup(t *testing.T) { + tokenAuth, err := jwt.NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + server := httpdServer{ + tokenAuth: tokenAuth, + } + admin := dataprovider.Admin{ + Username: "newtestadmin", + Password: "password", + Permissions: []string{dataprovider.PermAdminAny}, + } + claims := jwt.NewClaims(tokenAudienceAPI, "", getTokenDuration(tokenAudienceAPI)) + claims.Username = admin.Username + claims.Permissions = admin.Permissions + claims.Subject = admin.GetSignature() + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + token, err := server.tokenAuth.Sign(claims) + assert.NoError(t, err) + + req, _ := http.NewRequest(http.MethodGet, versionPath, nil) + assert.True(t, isTokenInvalidated(req)) + + fakeToken := "abc" + invalidateTokenString(req, fakeToken, -100*time.Millisecond) + assert.True(t, invalidatedJWTTokens.Get(fakeToken)) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + invalidatedJWTTokens.Add(token, time.Now().Add(-getTokenDuration(tokenAudienceWebAdmin)).UTC()) + require.True(t, isTokenInvalidated(req)) + startCleanupTicker(100 * time.Millisecond) + assert.Eventually(t, func() bool { return !isTokenInvalidated(req) }, 1*time.Second, 200*time.Millisecond) + assert.False(t, invalidatedJWTTokens.Get(fakeToken)) + stopCleanupTicker() +} + +func TestDbTokenManager(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + mgr := newTokenManager(1) + dbTokenManager := mgr.(*dbTokenManager) + testToken := 
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOlsiV2ViQWRtaW4iLCI6OjEiXSwiZXhwIjoxNjk4NjYwMDM4LCJqdGkiOiJja3ZuazVrYjF1aHUzZXRmZmhyZyIsIm5iZiI6MTY5ODY1ODgwOCwicGVybWlzc2lvbnMiOlsiKiJdLCJzdWIiOiIxNjk3ODIwNDM3NTMyIiwidXNlcm5hbWUiOiJhZG1pbiJ9.LXuFFksvnSuzHqHat6r70yR0jEulNRju7m7SaWrOfy8; csrftoken=mP0C7DqjwpAXsptO2gGCaYBkYw3oNMWB" + key := dbTokenManager.getKey(testToken) + require.Len(t, key, 64) + dbTokenManager.Add(testToken, time.Now().Add(-getTokenDuration(tokenAudienceWebClient)).UTC()) + isInvalidated := dbTokenManager.Get(testToken) + assert.True(t, isInvalidated) + dbTokenManager.Cleanup() + isInvalidated = dbTokenManager.Get(testToken) + assert.False(t, isInvalidated) + dbTokenManager.Add(testToken, time.Now().Add(getTokenDuration(tokenAudienceWebAdmin)).UTC()) + isInvalidated = dbTokenManager.Get(testToken) + assert.True(t, isInvalidated) + dbTokenManager.Cleanup() + isInvalidated = dbTokenManager.Get(testToken) + assert.True(t, isInvalidated) + err := dataprovider.DeleteSharedSession(key, dataprovider.SessionTypeInvalidToken) + assert.NoError(t, err) +} + +func TestDatabaseSharedSessions(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + session1 := dataprovider.Session{ + Key: "1", + Data: map[string]string{"a": "b"}, + Type: dataprovider.SessionTypeOIDCAuth, + Timestamp: 10, + } + err := dataprovider.AddSharedSession(session1) + assert.NoError(t, err) + // Adding another session with the same key but a different type should work + session2 := session1 + session2.Type = dataprovider.SessionTypeOIDCToken + err = dataprovider.AddSharedSession(session2) + assert.NoError(t, err) + err = dataprovider.DeleteSharedSession(session1.Key, dataprovider.SessionTypeInvalidToken) + assert.ErrorIs(t, err, util.ErrNotFound) + _, err = dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeResetCode) + assert.ErrorIs(t, err, util.ErrNotFound) + session1Get, err := 
dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) + assert.NoError(t, err) + assert.Equal(t, session1.Timestamp, session1Get.Timestamp) + var stored map[string]string + err = json.Unmarshal(session1Get.Data.([]byte), &stored) + assert.NoError(t, err) + assert.Equal(t, session1.Data, stored) + session1.Timestamp = 20 + session1.Data = map[string]string{"c": "d"} + err = dataprovider.AddSharedSession(session1) + assert.NoError(t, err) + session1Get, err = dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) + assert.NoError(t, err) + assert.Equal(t, session1.Timestamp, session1Get.Timestamp) + stored = make(map[string]string) + err = json.Unmarshal(session1Get.Data.([]byte), &stored) + assert.NoError(t, err) + assert.Equal(t, session1.Data, stored) + err = dataprovider.DeleteSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) + assert.NoError(t, err) + err = dataprovider.DeleteSharedSession(session2.Key, dataprovider.SessionTypeOIDCToken) + assert.NoError(t, err) + _, err = dataprovider.GetSharedSession(session1.Key, dataprovider.SessionTypeOIDCAuth) + assert.ErrorIs(t, err, util.ErrNotFound) + _, err = dataprovider.GetSharedSession(session2.Key, dataprovider.SessionTypeOIDCToken) + assert.ErrorIs(t, err, util.ErrNotFound) +} + +func TestAllowedProxyUnixDomainSocket(t *testing.T) { + b := Binding{ + Address: filepath.Join(os.TempDir(), "sock"), + ProxyAllowed: []string{"127.0.0.1", "127.0.1.1"}, + } + err := b.parseAllowedProxy() + assert.NoError(t, err) + if assert.Len(t, b.allowHeadersFrom, 1) { + assert.True(t, b.allowHeadersFrom[0](nil)) + } +} + +func TestProxyListenerWrapper(t *testing.T) { + b := Binding{ + ProxyMode: 0, + } + require.Nil(t, b.listenerWrapper()) + b.ProxyMode = 1 + require.NotNil(t, b.listenerWrapper()) +} + +func TestProxyHeaders(t *testing.T) { + username := "adminTest" + password := "testPwd" + admin := dataprovider.Admin{ + Username: username, + Password: password, + 
Permissions: []string{dataprovider.PermAdminAny}, + Status: 1, + Filters: dataprovider.AdminFilters{ + AllowList: []string{"172.19.2.0/24"}, + }, + } + + err := dataprovider.AddAdmin(&admin, "", "", "") + assert.NoError(t, err) + + testIP := "10.29.1.9" + validForwardedFor := "172.19.2.6" + b := Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: false, + EnableRESTAPI: true, + ProxyAllowed: []string{testIP, "10.8.0.0/30"}, + ClientIPProxyHeader: "x-forwarded-for", + } + err = b.parseAllowedProxy() + assert.NoError(t, err) + server := newHttpdServer(b, "", "", CorsConfig{Enabled: true}, "") + err = server.initializeRouter() + require.NoError(t, err) + testServer := httptest.NewServer(server.router) + defer testServer.Close() + + req, err := http.NewRequest(http.MethodGet, tokenPath, nil) + assert.NoError(t, err) + req.Header.Set("X-Forwarded-For", validForwardedFor) + req.Header.Set(xForwardedProto, "https") + req.RemoteAddr = "127.0.0.1:123" + req.SetBasicAuth(username, password) + rr := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.NotContains(t, rr.Body.String(), "login from IP 127.0.0.1 not allowed") + + req.RemoteAddr = testIP + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + + req.RemoteAddr = "10.8.0.2" + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + + req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.RemoteAddr = testIP + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + cookie := rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, cookie) + req.Header.Set("Cookie", cookie) + parsedToken, err := jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) + 
assert.NoError(t, err) + ctx := req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form := make(url.Values) + form.Set("username", username) + form.Set("password", password) + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) + req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Cookie", cookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCredentials) + + req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.RemoteAddr = validForwardedFor + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + loginCookie := rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, loginCookie) + req.Header.Set("Cookie", loginCookie) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) + req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Cookie", loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Forwarded-For", validForwardedFor) + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) + cookie = 
rr.Header().Get("Set-Cookie") + assert.NotContains(t, cookie, "Secure") + + // The login cookie is invalidated after a successful login, the same request will fail + req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Cookie", loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Forwarded-For", validForwardedFor) + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidCSRF) + + req, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.RemoteAddr = validForwardedFor + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + loginCookie = rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, loginCookie) + req.Header.Set("Cookie", loginCookie) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) + req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Cookie", loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Forwarded-For", validForwardedFor) + req.Header.Set(xForwardedProto, "https") + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) + cookie = rr.Header().Get("Set-Cookie") + assert.Contains(t, cookie, "Secure") + + req, 
err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + req.RemoteAddr = validForwardedFor + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + loginCookie = rr.Header().Get("Set-Cookie") + assert.NotEmpty(t, loginCookie) + req.Header.Set("Cookie", loginCookie) + parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, req, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form.Set(csrfFormToken, createCSRFToken(httptest.NewRecorder(), req, server.csrfTokenAuth, "", webBaseAdminPath)) + req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req.RemoteAddr = testIP + req.Header.Set("Cookie", loginCookie) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-Forwarded-For", validForwardedFor) + req.Header.Set(xForwardedProto, "http") + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) + cookie = rr.Header().Get("Set-Cookie") + assert.NotContains(t, cookie, "Secure") + + err = dataprovider.DeleteAdmin(username, "", "", "") + assert.NoError(t, err) +} + +func TestRecoverer(t *testing.T) { + recoveryPath := "/recovery" + b := Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: false, + EnableRESTAPI: true, + } + server := newHttpdServer(b, "../static", "", CorsConfig{}, "../openapi") + err := server.initializeRouter() + require.NoError(t, err) + server.router.Get(recoveryPath, func(_ http.ResponseWriter, _ *http.Request) { + panic("panic") + }) + testServer := httptest.NewServer(server.router) + defer testServer.Close() + + req, err := http.NewRequest(http.MethodGet, recoveryPath, nil) + assert.NoError(t, err) + rr := 
httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code, rr.Body.String()) + + server.router = chi.NewRouter() + server.router.Use(middleware.Recoverer) + server.router.Get(recoveryPath, func(_ http.ResponseWriter, _ *http.Request) { + panic("panic") + }) + testServer = httptest.NewServer(server.router) + defer testServer.Close() + + req, err = http.NewRequest(http.MethodGet, recoveryPath, nil) + assert.NoError(t, err) + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code, rr.Body.String()) +} + +func TestStreamJSONArray(t *testing.T) { + dataGetter := func(_, _ int) ([]byte, int, error) { + return nil, 0, nil + } + rr := httptest.NewRecorder() + streamJSONArray(rr, 10, dataGetter) + assert.Equal(t, `[]`, rr.Body.String()) + + data := []int{} + for i := 0; i < 10; i++ { + data = append(data, i) + } + + dataGetter = func(_, offset int) ([]byte, int, error) { + if offset >= len(data) { + return nil, 0, nil + } + val := data[offset] + data, err := json.Marshal([]int{val}) + return data, 1, err + } + + rr = httptest.NewRecorder() + streamJSONArray(rr, 1, dataGetter) + assert.Equal(t, `[0,1,2,3,4,5,6,7,8,9]`, rr.Body.String()) +} + +func TestCompressorAbortHandler(t *testing.T) { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", dataprovider.User{}), + nil, + nil, + ) + share := &dataprovider.Share{} + renderCompressedFiles(&failingWriter{}, connection, "", nil, share) +} + +func TestStreamDataAbortHandler(t *testing.T) { + defer func() { + rcv := recover() + assert.Equal(t, http.ErrAbortHandler, rcv) + }() + + streamData(&failingWriter{}, []byte(`["a":"b"]`)) +} + +func TestZipErrors(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: 
filepath.Clean(os.TempDir()), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) + + testDir := filepath.Join(os.TempDir(), "testDir") + err := os.MkdirAll(testDir, os.ModePerm) + assert.NoError(t, err) + + wr := zip.NewWriter(&failingWriter{}) + err = wr.Close() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "write error") + } + + err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", nil, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "write error") + } + err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", nil, 2000) + assert.ErrorIs(t, err, util.ErrRecursionTooDeep) + + err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), path.Join("/", filepath.Base(testDir), "dir"), nil, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is outside base dir") + } + + testFilePath := filepath.Join(testDir, "ziptest.zip") + err = os.WriteFile(testFilePath, util.GenerateRandomBytes(65535), os.ModePerm) + assert.NoError(t, err) + err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)), + "/"+filepath.Base(testDir), nil, 0) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "write error") + } + + connection.User.Permissions["/"] = []string{dataprovider.PermListItems} + err = addZipEntry(wr, connection, path.Join("/", filepath.Base(testDir), filepath.Base(testFilePath)), + "/"+filepath.Base(testDir), nil, 0) + assert.ErrorIs(t, err, os.ErrPermission) + + // creating a virtual folder to a missing path stat is ok but readdir fails + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: filepath.Join(os.TempDir(), "mapped"), + }, + VirtualPath: "/vpath", + }) + connection.User = user + 
wr = zip.NewWriter(bytes.NewBuffer(make([]byte, 0))) + err = addZipEntry(wr, connection, user.VirtualFolders[0].VirtualPath, "/", nil, 0) + assert.Error(t, err) + + user.Filters.FilePatterns = append(user.Filters.FilePatterns, sdk.PatternsFilter{ + Path: "/", + DeniedPatterns: []string{"*.zip"}, + }) + err = addZipEntry(wr, connection, "/"+filepath.Base(testDir), "/", nil, 0) + assert.ErrorIs(t, err, os.ErrPermission) + + err = os.RemoveAll(testDir) + assert.NoError(t, err) +} + +func TestWebAdminRedirect(t *testing.T) { + b := Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: false, + EnableRESTAPI: true, + } + server := newHttpdServer(b, "../static", "", CorsConfig{}, "../openapi") + err := server.initializeRouter() + require.NoError(t, err) + testServer := httptest.NewServer(server.router) + defer testServer.Close() + + req, err := http.NewRequest(http.MethodGet, webRootPath, nil) + assert.NoError(t, err) + rr := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) + assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) + + req, err = http.NewRequest(http.MethodGet, webBasePath, nil) + assert.NoError(t, err) + rr = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusFound, rr.Code, rr.Body.String()) + assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) +} + +func TestParseRangeRequests(t *testing.T) { + // curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=24-24" + fileSize := int64(169740) + rangeHeader := "bytes=24-24" + offset, size, err := parseRangeRequest(rangeHeader[6:], fileSize) + require.NoError(t, err) + resp := fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize) + assert.Equal(t, "bytes 24-24/169740", resp) + require.Equal(t, int64(1), size) + // curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: 
bytes=24-" + rangeHeader = "bytes=24-" + offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.NoError(t, err) + resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize) + assert.Equal(t, "bytes 24-169739/169740", resp) + require.Equal(t, int64(169716), size) + // curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=-1" + rangeHeader = "bytes=-1" + offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.NoError(t, err) + resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize) + assert.Equal(t, "bytes 169739-169739/169740", resp) + require.Equal(t, int64(1), size) + // curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=-100" + rangeHeader = "bytes=-100" + offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.NoError(t, err) + resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize) + assert.Equal(t, "bytes 169640-169739/169740", resp) + require.Equal(t, int64(100), size) + // curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=20-30" + rangeHeader = "bytes=20-30" + offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.NoError(t, err) + resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize) + assert.Equal(t, "bytes 20-30/169740", resp) + require.Equal(t, int64(11), size) + // curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=20-169739" + rangeHeader = "bytes=20-169739" + offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.NoError(t, err) + resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize) + assert.Equal(t, "bytes 20-169739/169740", resp) + require.Equal(t, int64(169720), size) + // curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=20-169740" + rangeHeader = "bytes=20-169740" + offset, size, err = 
parseRangeRequest(rangeHeader[6:], fileSize) + require.NoError(t, err) + resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize) + assert.Equal(t, "bytes 20-169739/169740", resp) + require.Equal(t, int64(169720), size) + // curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=20-169741" + rangeHeader = "bytes=20-169741" + offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.NoError(t, err) + resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize) + assert.Equal(t, "bytes 20-169739/169740", resp) + require.Equal(t, int64(169720), size) + //curl --verbose "http://127.0.0.1:8080/static/css/sb-admin-2.min.css" -H "Range: bytes=0-" > /dev/null + rangeHeader = "bytes=0-" + offset, size, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.NoError(t, err) + resp = fmt.Sprintf("bytes %d-%d/%d", offset, offset+size-1, fileSize) + assert.Equal(t, "bytes 0-169739/169740", resp) + require.Equal(t, int64(169740), size) + // now test errors + rangeHeader = "bytes=0-a" + _, _, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.Error(t, err) + rangeHeader = "bytes=" + _, _, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.Error(t, err) + rangeHeader = "bytes=-" + _, _, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.Error(t, err) + rangeHeader = "bytes=500-300" + _, _, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.Error(t, err) + rangeHeader = "bytes=5000000" + _, _, err = parseRangeRequest(rangeHeader[6:], fileSize) + require.Error(t, err) +} + +func TestRequestHeaderErrors(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.Header.Set("If-Unmodified-Since", "not a date") + res := checkIfUnmodifiedSince(req, time.Now()) + assert.Equal(t, condNone, res) + + req, _ = http.NewRequest(http.MethodPost, webClientFilesPath, nil) + res = checkIfModifiedSince(req, time.Now()) + assert.Equal(t, condNone, res) + + 
req, _ = http.NewRequest(http.MethodPost, webClientFilesPath, nil) + res = checkIfRange(req, time.Now()) + assert.Equal(t, condNone, res) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.Header.Set("If-Modified-Since", "not a date") + res = checkIfModifiedSince(req, time.Now()) + assert.Equal(t, condNone, res) + + req, _ = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + req.Header.Set("If-Range", time.Now().Format(http.TimeFormat)) + res = checkIfRange(req, time.Time{}) + assert.Equal(t, condFalse, res) + + req.Header.Set("If-Range", "invalid if range date") + res = checkIfRange(req, time.Now()) + assert.Equal(t, condFalse, res) + modTime := getFileObjectModTime(time.Time{}) + assert.Empty(t, modTime) +} + +func TestConnection(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "test_httpd_user", + HomeDir: filepath.Clean(os.TempDir()), + }, + FsConfig: vfs.Filesystem{ + Provider: sdk.GCSFilesystemProvider, + GCSConfig: vfs.GCSFsConfig{ + BaseGCSFsConfig: sdk.BaseGCSFsConfig{ + Bucket: "test_bucket_name", + }, + Credentials: kms.NewPlainSecret("invalid JSON payload"), + }, + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) + assert.Empty(t, connection.GetClientVersion()) + assert.Empty(t, connection.GetRemoteAddress()) + assert.Empty(t, connection.GetCommand()) + name := "missing file name" + _, err := connection.getFileReader(name, 0, http.MethodGet) + assert.Error(t, err) + connection.User.FsConfig.Provider = sdk.LocalFilesystemProvider + _, err = connection.getFileReader(name, 0, http.MethodGet) + assert.ErrorIs(t, err, os.ErrNotExist) +} + +func TestGetFileWriterErrors(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "test_httpd_user", + HomeDir: "invalid", + }, + } + 
user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) + _, err := connection.getFileWriter("name") + assert.Error(t, err) + + user.FsConfig.Provider = sdk.S3FilesystemProvider + user.FsConfig.S3Config = vfs.S3FsConfig{ + BaseS3FsConfig: sdk.BaseS3FsConfig{ + Bucket: "b", + Region: "us-west-1", + AccessKey: "key", + }, + AccessSecret: kms.NewPlainSecret("secret"), + } + connection = newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) + _, err = connection.getFileWriter("/path") + assert.Error(t, err) +} + +func TestThrottledHandler(t *testing.T) { + tr := &throttledReader{ + r: io.NopCloser(bytes.NewBuffer(nil)), + } + assert.Equal(t, int64(0), tr.GetTruncatedSize()) + err := tr.Close() + assert.NoError(t, err) + assert.Empty(t, tr.GetRealFsPath("real path")) + assert.False(t, tr.SetTimes("p", time.Now(), time.Now())) + _, err = tr.Truncate("", 0) + assert.ErrorIs(t, err, vfs.ErrVfsUnsupported) + err = tr.GetAbortError() + assert.ErrorIs(t, err, common.ErrTransferAborted) +} + +func TestHTTPDFile(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "test_httpd_user", + HomeDir: filepath.Clean(os.TempDir()), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + connection := newConnection( + common.NewBaseConnection(xid.New().String(), common.ProtocolHTTP, "", "", user), + nil, + nil, + ) + + fs, err := user.GetFilesystem("") + assert.NoError(t, err) + + name := "fileName" + p := filepath.Join(os.TempDir(), name) + err = os.WriteFile(p, []byte("contents"), os.ModePerm) + assert.NoError(t, err) + file, err := os.Open(p) + assert.NoError(t, err) + err = file.Close() + assert.NoError(t, err) + + baseTransfer := common.NewBaseTransfer(file, 
connection.BaseConnection, nil, p, p, name, common.TransferDownload, + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + httpdFile := newHTTPDFile(baseTransfer, nil, nil) + // the file is closed, read should fail + buf := make([]byte, 100) + _, err = httpdFile.Read(buf) + assert.Error(t, err) + err = httpdFile.Close() + assert.Error(t, err) + err = httpdFile.Close() + assert.ErrorIs(t, err, common.ErrTransferClosed) + err = os.Remove(p) + assert.NoError(t, err) + + httpdFile.writer = file + httpdFile.File = nil + httpdFile.ErrTransfer = nil + err = httpdFile.closeIO() + assert.Error(t, err) + assert.Error(t, httpdFile.ErrTransfer) + assert.Equal(t, err, httpdFile.ErrTransfer) + httpdFile.SignalClose(nil) + _, err = httpdFile.Write(nil) + assert.ErrorIs(t, err, common.ErrQuotaExceeded) +} + +func TestChangeUserPwd(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, webChangeClientPwdPath, nil) + err := doChangeUserPassword(req, "", "", "") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "please provide the current password and the new one two times") + } + err = doChangeUserPassword(req, "a", "b", "c") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "the two password fields do not match") + } + err = doChangeUserPassword(req, "a", "b", "b") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), errInvalidTokenClaims.Error()) + } +} + +func TestWebUserInvalidClaims(t *testing.T) { + server := httpdServer{} + err := server.initializeRouter() + require.NoError(t, err) + + rr := httptest.NewRecorder() + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "", + Password: "pwd", + }, + } + c := &jwt.Claims{ + Username: user.Username, + Permissions: nil, + } + c.Subject = user.GetSignature() + c.SetExpiry(time.Now().Add(10 * time.Minute)) + c.Audience = []string{tokenAudienceAPI} + token, err := server.tokenAuth.Sign(c) + assert.NoError(t, err) + + req, _ := http.NewRequest(http.MethodGet, webClientFilesPath, 
nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleClientGetFiles(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientDirsPath, nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleClientGetDirContents(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorDirList403) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientDownloadZipPath, nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleWebClientDownloadZip(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientEditFilePath, nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleClientEditFile(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientSharePath, nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleClientAddShareGet(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientSharePath, nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleClientUpdateShareGet(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPost, webClientSharePath, nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleClientAddSharePost(rr, 
req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodPost, webClientSharePath+"/id", nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleClientUpdateSharePost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientSharesPath+jsonAPISuffix, nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + getAllShares(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) + + rr = httptest.NewRecorder() + req, _ = http.NewRequest(http.MethodGet, webClientViewPDFPath, nil) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleClientGetPDF(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) +} + +func TestInvalidClaims(t *testing.T) { + server := httpdServer{} + err := server.initializeRouter() + require.NoError(t, err) + + rr := httptest.NewRecorder() + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "", + Password: "pwd", + }, + } + c := &jwt.Claims{ + Username: user.Username, + Permissions: nil, + } + c.Subject = user.GetSignature() + token, err := server.tokenAuth.SignWithParams(c, tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + assert.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, webClientProfilePath, nil) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + parsedToken, err := jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx := req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form := make(url.Values) + 
form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseClientPath)) + form.Set("public_keys", "") + req, err = http.NewRequest(http.MethodPost, webClientProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleWebClientProfilePost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + + admin := dataprovider.Admin{ + Username: "", + Password: user.Password, + } + c = &jwt.Claims{ + Username: admin.Username, + Permissions: nil, + } + c.Subject = admin.GetSignature() + token, err = server.tokenAuth.SignWithParams(c, tokenAudienceWebAdmin, "", getTokenDuration(tokenAudienceWebAdmin)) + assert.NoError(t, err) + + req, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil) + assert.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + parsedToken, err = jwt.VerifyRequest(server.tokenAuth, req, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx = req.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + req = req.WithContext(ctx) + + form = make(url.Values) + form.Set(csrfFormToken, createCSRFToken(rr, req, server.csrfTokenAuth, "", webBaseAdminPath)) + form.Set("allow_api_key_auth", "") + req, err = http.NewRequest(http.MethodPost, webAdminProfilePath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + server.handleWebAdminProfilePost(rr, req) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorInvalidToken) +} + +func TestTLSReq(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil) + assert.NoError(t, err) + req.TLS = &tls.ConnectionState{} + assert.True(t, 
isTLS(req)) + req.TLS = nil + ctx := context.WithValue(req.Context(), forwardedProtoKey, "https") + assert.True(t, isTLS(req.WithContext(ctx))) + ctx = context.WithValue(req.Context(), forwardedProtoKey, "http") + assert.False(t, isTLS(req.WithContext(ctx))) + assert.Equal(t, "context value forwarded proto", forwardedProtoKey.String()) +} + +func TestSigningKey(t *testing.T) { + signingPassphrase := "test" + server1 := httpdServer{ + signingPassphrase: signingPassphrase, + } + err := server1.initializeRouter() + require.NoError(t, err) + + server2 := httpdServer{ + signingPassphrase: signingPassphrase, + } + err = server2.initializeRouter() + require.NoError(t, err) + + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "", + Password: "pwd", + }, + } + c := &jwt.Claims{ + Username: user.Username, + Permissions: nil, + } + c.Subject = user.GetSignature() + token, err := server1.tokenAuth.SignWithParams(c, tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + assert.NoError(t, err) + assert.NotEmpty(t, token) + _, err = jwt.VerifyToken(server1.tokenAuth, token) + assert.NoError(t, err) + _, err = jwt.VerifyToken(server2.tokenAuth, token) + assert.NoError(t, err) +} + +func TestLoginLinks(t *testing.T) { + b := Binding{ + EnableWebAdmin: true, + EnableWebClient: false, + EnableRESTAPI: true, + } + assert.False(t, b.showClientLoginURL()) + b = Binding{ + EnableWebAdmin: false, + EnableWebClient: true, + EnableRESTAPI: true, + } + assert.False(t, b.showAdminLoginURL()) + b = Binding{ + EnableWebAdmin: true, + EnableWebClient: true, + EnableRESTAPI: true, + } + assert.True(t, b.showAdminLoginURL()) + assert.True(t, b.showClientLoginURL()) + b.HideLoginURL = 3 + assert.False(t, b.showAdminLoginURL()) + assert.False(t, b.showClientLoginURL()) + b.HideLoginURL = 1 + assert.True(t, b.showAdminLoginURL()) + assert.False(t, b.showClientLoginURL()) + b.HideLoginURL = 2 + assert.False(t, b.showAdminLoginURL()) + assert.True(t, 
b.showClientLoginURL()) +} + +func TestResetCodesCleanup(t *testing.T) { + resetCode := newResetCode(util.GenerateUniqueID(), false) + resetCode.ExpiresAt = time.Now().Add(-1 * time.Minute).UTC() + err := resetCodesMgr.Add(resetCode) + assert.NoError(t, err) + resetCodesMgr.Cleanup() + _, err = resetCodesMgr.Get(resetCode.Code) + assert.Error(t, err) +} + +func TestUserCanResetPassword(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil) + assert.NoError(t, err) + req.RemoteAddr = "172.16.9.2:55080" + + u := dataprovider.User{} + assert.True(t, isUserAllowedToResetPassword(req, &u)) + u.Filters.DeniedProtocols = []string{common.ProtocolHTTP} + assert.False(t, isUserAllowedToResetPassword(req, &u)) + u.Filters.DeniedProtocols = nil + u.Filters.WebClient = []string{sdk.WebClientPasswordResetDisabled} + assert.False(t, isUserAllowedToResetPassword(req, &u)) + u.Filters.WebClient = nil + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} + assert.False(t, isUserAllowedToResetPassword(req, &u)) + u.Filters.DeniedLoginMethods = nil + u.Filters.AllowedIP = []string{"127.0.0.1/8"} + assert.False(t, isUserAllowedToResetPassword(req, &u)) +} + +func TestBrowsableSharePaths(t *testing.T) { + share := dataprovider.Share{ + Paths: []string{"/"}, + Username: defaultAdminUsername, + } + _, err := getUserForShare(share) + if assert.Error(t, err) { + assert.ErrorIs(t, err, util.ErrNotFound) + } + req, err := http.NewRequest(http.MethodGet, "/share", nil) + require.NoError(t, err) + name, err := getBrowsableSharedPath(share.Paths[0], req) + assert.NoError(t, err) + assert.Equal(t, "/", name) + req, err = http.NewRequest(http.MethodGet, "/share?path=abc", nil) + require.NoError(t, err) + name, err = getBrowsableSharedPath(share.Paths[0], req) + assert.NoError(t, err) + assert.Equal(t, "/abc", name) + + share.Paths = []string{"/a/b/c"} + req, err = http.NewRequest(http.MethodGet, "/share?path=abc", nil) + require.NoError(t, err) 
+ name, err = getBrowsableSharedPath(share.Paths[0], req) + assert.NoError(t, err) + assert.Equal(t, "/a/b/c/abc", name) + req, err = http.NewRequest(http.MethodGet, "/share?path=%2Fabc/d", nil) + require.NoError(t, err) + name, err = getBrowsableSharedPath(share.Paths[0], req) + assert.NoError(t, err) + assert.Equal(t, "/a/b/c/abc/d", name) + + req, err = http.NewRequest(http.MethodGet, "/share?path=%2Fabc%2F..%2F..", nil) + require.NoError(t, err) + _, err = getBrowsableSharedPath(share.Paths[0], req) + assert.Error(t, err) + + req, err = http.NewRequest(http.MethodGet, "/share?path=%2Fabc%2F..", nil) + require.NoError(t, err) + name, err = getBrowsableSharedPath(share.Paths[0], req) + assert.NoError(t, err) + assert.Equal(t, "/a/b/c", name) + + share = dataprovider.Share{ + Paths: []string{"/a", "/b"}, + } +} + +func TestSecureMiddlewareIntegration(t *testing.T) { + forwardedHostHeader := "X-Forwarded-Host" + server := httpdServer{ + binding: Binding{ + ProxyAllowed: []string{"192.168.1.0/24"}, + Security: SecurityConf{ + Enabled: true, + AllowedHosts: []string{"*.sftpgo.com"}, + AllowedHostsAreRegex: true, + HostsProxyHeaders: []string{forwardedHostHeader}, + HTTPSProxyHeaders: []HTTPSProxyHeader{ + { + Key: xForwardedProto, + Value: "https", + }, + }, + STSSeconds: 31536000, + STSIncludeSubdomains: true, + STSPreload: true, + ContentTypeNosniff: true, + CacheControl: "private", + CrossOriginOpenerPolicy: "same-origin", + CrossOriginResourcePolicy: "same-site", + CrossOriginEmbedderPolicy: "require-corp", + ReferrerPolicy: "no-referrer", + }, + }, + enableWebAdmin: true, + enableWebClient: true, + enableRESTAPI: true, + } + server.binding.Security.updateProxyHeaders() + err := server.binding.parseAllowedProxy() + assert.NoError(t, err) + assert.Equal(t, []string{forwardedHostHeader, xForwardedProto}, server.binding.Security.proxyHeaders) + assert.Equal(t, map[string]string{xForwardedProto: "https"}, server.binding.Security.getHTTPSProxyHeaders()) + err = 
server.initializeRouter() + require.NoError(t, err) + + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webClientLoginPath, nil) + assert.NoError(t, err) + r.Host = "127.0.0.1" + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.Equal(t, "no-cache, no-store, max-age=0, must-revalidate, private", rr.Header().Get("Cache-Control")) + + rr = httptest.NewRecorder() + r.Header.Set(forwardedHostHeader, "www.sftpgo.com") + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusForbidden, rr.Code) + // the header should be removed + assert.Empty(t, r.Header.Get(forwardedHostHeader)) + + rr = httptest.NewRecorder() + r.Host = "test.sftpgo.com" + r.Header.Set(forwardedHostHeader, "test.example.com") + r.RemoteAddr = "192.168.1.1" + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusForbidden, rr.Code) + assert.NotEmpty(t, r.Header.Get(forwardedHostHeader)) + + rr = httptest.NewRecorder() + r.Header.Set(forwardedHostHeader, "www.sftpgo.com") + r.RemoteAddr = "192.168.1.1" + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code) + assert.NotEmpty(t, r.Header.Get(forwardedHostHeader)) + assert.Empty(t, rr.Header().Get("Strict-Transport-Security")) + assert.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options")) + // now set the X-Forwarded-Proto to https, we should get the Strict-Transport-Security header + rr = httptest.NewRecorder() + r.Host = "test.sftpgo.com" + r.Header.Set(xForwardedProto, "https") + r.RemoteAddr = "192.168.1.3" + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code) + assert.NotEmpty(t, r.Header.Get(forwardedHostHeader)) + assert.Equal(t, "max-age=31536000; includeSubDomains; preload", rr.Header().Get("Strict-Transport-Security")) + assert.Equal(t, "nosniff", rr.Header().Get("X-Content-Type-Options")) + assert.Equal(t, "require-corp", rr.Header().Get("Cross-Origin-Embedder-Policy")) + assert.Equal(t, "same-origin", 
rr.Header().Get("Cross-Origin-Opener-Policy")) + assert.Equal(t, "same-site", rr.Header().Get("Cross-Origin-Resource-Policy")) + assert.Equal(t, "no-referrer", rr.Header().Get("Referrer-Policy")) + + server.binding.Security.Enabled = false + server.binding.Security.updateProxyHeaders() + assert.Len(t, server.binding.Security.proxyHeaders, 0) +} + +func TestGetCompressedFileName(t *testing.T) { + username := "test" + res := getCompressedFileName(username, []string{"single dir"}) + require.Equal(t, fmt.Sprintf("%s-single dir.zip", username), res) + res = getCompressedFileName(username, []string{"file1", "file2"}) + require.Equal(t, fmt.Sprintf("%s-download.zip", username), res) + res = getCompressedFileName(username, []string{"file1.txt"}) + require.Equal(t, fmt.Sprintf("%s-file1.zip", username), res) + // now files with full paths + res = getCompressedFileName(username, []string{"/dir/single dir"}) + require.Equal(t, fmt.Sprintf("%s-single dir.zip", username), res) + res = getCompressedFileName(username, []string{"/adir/file1", "/adir/file2"}) + require.Equal(t, fmt.Sprintf("%s-download.zip", username), res) + res = getCompressedFileName(username, []string{"/sub/dir/file1.txt"}) + require.Equal(t, fmt.Sprintf("%s-file1.zip", username), res) +} + +func TestRESTAPIDisabled(t *testing.T) { + server := httpdServer{ + enableWebAdmin: true, + enableWebClient: true, + enableRESTAPI: false, + } + err := server.initializeRouter() + require.NoError(t, err) + assert.False(t, server.enableRESTAPI) + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, healthzPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code) + + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, tokenPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusNotFound, rr.Code) +} + +func TestWebAdminSetupWithInstallCode(t *testing.T) { + installationCode = "1234" + // delete 
all the admins + admins, err := dataprovider.GetAdmins(100, 0, dataprovider.OrderASC) + assert.NoError(t, err) + for _, admin := range admins { + err = dataprovider.DeleteAdmin(admin.Username, "", "", "") + assert.NoError(t, err) + } + // close the provider and initializes it without creating the default admin + providerConf := dataprovider.GetProviderConfig() + providerConf.CreateDefaultAdmin = false + err = dataprovider.Close() + assert.NoError(t, err) + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + server := httpdServer{ + enableWebAdmin: true, + enableWebClient: true, + enableRESTAPI: true, + } + err = server.initializeRouter() + require.NoError(t, err) + + for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} { + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webURL, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) + } + + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webAdminSetupPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code) + cookie := rr.Header().Get("Set-Cookie") + r.Header.Set("Cookie", cookie) + parsedToken, err := jwt.VerifyRequest(server.csrfTokenAuth, r, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx := r.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + r = r.WithContext(ctx) + + form := make(url.Values) + csrfToken := createCSRFToken(rr, r, server.csrfTokenAuth, "", webBaseAdminPath) + form.Set(csrfFormToken, csrfToken) + form.Set("install_code", installationCode+"5") + form.Set("username", defaultAdminUsername) + form.Set("password", "password") + form.Set("confirm_password", "password") + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, 
bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + r = r.WithContext(ctx) + r.Header.Set("Cookie", cookie) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorSetupInstallCode) + + _, err = dataprovider.AdminExists(defaultAdminUsername) + assert.Error(t, err) + form.Set("install_code", installationCode) + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + r = r.WithContext(ctx) + r.Header.Set("Cookie", cookie) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminMFAPath, rr.Header().Get("Location")) + + _, err = dataprovider.AdminExists(defaultAdminUsername) + assert.NoError(t, err) + + // delete the admin and test the installation code resolver + err = dataprovider.DeleteAdmin(defaultAdminUsername, "", "", "") + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + SetInstallationCodeResolver(func(_ string) string { + return "5678" + }) + + for _, webURL := range []string{"/", webBasePath, webBaseAdminPath, webAdminLoginPath, webClientLoginPath} { + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webURL, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminSetupPath, rr.Header().Get("Location")) + } + + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webAdminSetupPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code) + cookie = rr.Header().Get("Set-Cookie") + r.Header.Set("Cookie", cookie) + 
parsedToken, err = jwt.VerifyRequest(server.csrfTokenAuth, r, jwt.TokenFromCookie) + assert.NoError(t, err) + ctx = r.Context() + ctx = jwt.NewContext(ctx, parsedToken, err) + r = r.WithContext(ctx) + + form = make(url.Values) + csrfToken = createCSRFToken(rr, r, server.csrfTokenAuth, "", webBaseAdminPath) + form.Set(csrfFormToken, csrfToken) + form.Set("install_code", installationCode) + form.Set("username", defaultAdminUsername) + form.Set("password", "password") + form.Set("confirm_password", "password") + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + r = r.WithContext(ctx) + r.Header.Set("Cookie", cookie) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), util.I18nErrorSetupInstallCode) + + _, err = dataprovider.AdminExists(defaultAdminUsername) + assert.Error(t, err) + form.Set("install_code", "5678") + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodPost, webAdminSetupPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + r = r.WithContext(ctx) + r.Header.Set("Cookie", cookie) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminMFAPath, rr.Header().Get("Location")) + + _, err = dataprovider.AdminExists(defaultAdminUsername) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + providerConf.CreateDefaultAdmin = true + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + installationCode = "" + SetInstallationCodeResolver(nil) +} + +func TestDbResetCodeManager(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + mgr := newResetCodeManager(1) 
+ resetCode := newResetCode("admin", true) + err := mgr.Add(resetCode) + assert.NoError(t, err) + codeGet, err := mgr.Get(resetCode.Code) + assert.NoError(t, err) + assert.Equal(t, resetCode, codeGet) + err = mgr.Delete(resetCode.Code) + assert.NoError(t, err) + err = mgr.Delete(resetCode.Code) + if assert.Error(t, err) { + assert.ErrorIs(t, err, util.ErrNotFound) + } + _, err = mgr.Get(resetCode.Code) + assert.ErrorIs(t, err, util.ErrNotFound) + // add an expired reset code + resetCode = newResetCode("user", false) + resetCode.ExpiresAt = time.Now().Add(-24 * time.Hour) + err = mgr.Add(resetCode) + assert.NoError(t, err) + _, err = mgr.Get(resetCode.Code) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "reset code expired") + } + mgr.Cleanup() + _, err = mgr.Get(resetCode.Code) + assert.ErrorIs(t, err, util.ErrNotFound) + + dbMgr, ok := mgr.(*dbResetCodeManager) + if assert.True(t, ok) { + _, err = dbMgr.decodeData("astring") + assert.Error(t, err) + } +} + +func TestEventRoleFilter(t *testing.T) { + defaultVal := "default" + req, err := http.NewRequest(http.MethodGet, fsEventsPath+"?role=role1", nil) + require.NoError(t, err) + role := getRoleFilterForEventSearch(req, defaultVal) + assert.Equal(t, defaultVal, role) + role = getRoleFilterForEventSearch(req, "") + assert.Equal(t, "role1", role) +} + +func TestEventsCSV(t *testing.T) { + e := fsEvent{ + Status: 1, + } + data := e.getCSVData() + assert.Equal(t, "OK", data[5]) + e.Status = 2 + data = e.getCSVData() + assert.Equal(t, "KO", data[5]) + e.Status = 3 + data = e.getCSVData() + assert.Equal(t, "Quota exceeded", data[5]) +} + +func TestConfigsFromProvider(t *testing.T) { + err := dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) + c := Conf{ + Bindings: []Binding{ + { + Port: 1234, + }, + { + Port: 80, + Security: SecurityConf{ + Enabled: true, + HTTPSRedirect: true, + }, + }, + }, + } + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + 
configs := dataprovider.Configs{ + ACME: &dataprovider.ACMEConfigs{ + Domain: "domain.com", + Email: "info@domain.com", + HTTP01Challenge: dataprovider.ACMEHTTP01Challenge{Port: 80}, + Protocols: 1, + }, + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + util.CertsBasePath = "" + // crt and key empty + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + util.CertsBasePath = filepath.Clean(os.TempDir()) + // crt not found + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + keyPairs := c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 0) + crtPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".crt") + err = os.WriteFile(crtPath, nil, 0666) + assert.NoError(t, err) + // key not found + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + keyPairs = c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 0) + keyPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".key") + err = os.WriteFile(keyPath, nil, 0666) + assert.NoError(t, err) + // acme cert used + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Equal(t, configs.ACME.Domain, c.acmeDomain) + keyPairs = c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 1) + assert.True(t, c.Bindings[0].EnableHTTPS) + assert.False(t, c.Bindings[1].EnableHTTPS) + // protocols does not match + configs.ACME.Protocols = 6 + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + c.acmeDomain = "" + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + keyPairs = c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 0) + + err = os.Remove(crtPath) + assert.NoError(t, err) + err = os.Remove(keyPath) + assert.NoError(t, err) + util.CertsBasePath = "" + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) +} + +func TestHTTPSRedirect(t *testing.T) { + 
acmeWebRoot := filepath.Join(os.TempDir(), "acme") + err := os.MkdirAll(acmeWebRoot, os.ModePerm) + assert.NoError(t, err) + tokenName := "token" + err = os.WriteFile(filepath.Join(acmeWebRoot, tokenName), []byte("val"), 0666) + assert.NoError(t, err) + + acmeConfig := acme.Configuration{ + HTTP01Challenge: acme.HTTP01Challenge{WebRoot: acmeWebRoot}, + } + err = acme.Initialize(acmeConfig, configDir, true) + require.NoError(t, err) + + forwardedHostHeader := "X-Forwarded-Host" + server := httpdServer{ + binding: Binding{ + Security: SecurityConf{ + Enabled: true, + HTTPSRedirect: true, + HostsProxyHeaders: []string{forwardedHostHeader}, + }, + }, + } + err = server.initializeRouter() + require.NoError(t, err) + + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, path.Join(acmeChallengeURI, tokenName), nil) + assert.NoError(t, err) + r.Host = "localhost" + r.RequestURI = path.Join(acmeChallengeURI, tokenName) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusOK, rr.Code, rr.Body.String()) + + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + r.RequestURI = webAdminLoginPath + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusTemporaryRedirect, rr.Code, rr.Body.String()) + + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + r.RequestURI = webAdminLoginPath + r.Header.Set(forwardedHostHeader, "sftpgo.com") + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusTemporaryRedirect, rr.Code, rr.Body.String()) + assert.Contains(t, rr.Body.String(), "https://sftpgo.com") + + server.binding.Security.HTTPSHost = "myhost:1044" + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil) + assert.NoError(t, err) + r.RequestURI = webAdminLoginPath + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusTemporaryRedirect, rr.Code, 
rr.Body.String()) + assert.Contains(t, rr.Body.String(), "https://myhost:1044") + + err = os.RemoveAll(acmeWebRoot) + assert.NoError(t, err) +} + +func TestDisabledAdminLoginMethods(t *testing.T) { + server := httpdServer{ + binding: Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: true, + EnableRESTAPI: true, + DisabledLoginMethods: 20, + }, + enableWebAdmin: true, + enableWebClient: true, + enableRESTAPI: true, + } + err := server.initializeRouter() + require.NoError(t, err) + testServer := httptest.NewServer(server.router) + defer testServer.Close() + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, tokenPath, nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, defaultAdminUsername, "forgot-password"), nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, path.Join(adminPath, defaultAdminUsername, "reset-password"), nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, webAdminLoginPath, nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusMethodNotAllowed, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, webAdminResetPwdPath, nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, webAdminForgotPwdPath, nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, 
http.StatusNotFound, rr.Code) +} + +func TestDisabledUserLoginMethods(t *testing.T) { + server := httpdServer{ + binding: Binding{ + Address: "", + Port: 8080, + EnableWebAdmin: true, + EnableWebClient: true, + EnableRESTAPI: true, + DisabledLoginMethods: 40, + }, + enableWebAdmin: true, + enableWebClient: true, + enableRESTAPI: true, + } + err := server.initializeRouter() + require.NoError(t, err) + testServer := httptest.NewServer(server.router) + defer testServer.Close() + + rr := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, userTokenPath, nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, userPath+"/user/forgot-password", nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, userPath+"/user/reset-password", nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, webClientLoginPath, nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusMethodNotAllowed, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, webClientResetPwdPath, nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) + + rr = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, nil) + require.NoError(t, err) + testServer.Config.Handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusNotFound, rr.Code) +} + +func TestGetLogEventString(t *testing.T) { + assert.Equal(t, "Login failed", getLogEventString(notifier.LogEventTypeLoginFailed)) + 
assert.Equal(t, "Login with non-existent user", getLogEventString(notifier.LogEventTypeLoginNoUser)) + assert.Equal(t, "No login tried", getLogEventString(notifier.LogEventTypeNoLoginTried)) + assert.Equal(t, "Algorithm negotiation failed", getLogEventString(notifier.LogEventTypeNotNegotiated)) + assert.Equal(t, "Login succeeded", getLogEventString(notifier.LogEventTypeLoginOK)) + assert.Empty(t, getLogEventString(0)) +} + +func TestUserQuotaUsage(t *testing.T) { + usage := userQuotaUsage{ + QuotaSize: 100, + } + require.True(t, usage.HasQuotaInfo()) + require.NotEmpty(t, usage.GetQuotaSize()) + providerConf := dataprovider.GetProviderConfig() + quotaTracking := dataprovider.GetQuotaTracking() + providerConf.TrackQuota = 0 + err := dataprovider.Close() + assert.NoError(t, err) + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + assert.False(t, usage.HasQuotaInfo()) + providerConf.TrackQuota = quotaTracking + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + usage.QuotaSize = 0 + assert.False(t, usage.HasQuotaInfo()) + assert.Empty(t, usage.GetQuotaSize()) + assert.Equal(t, 0, usage.GetQuotaSizePercentage()) + assert.False(t, usage.IsQuotaSizeLow()) + assert.False(t, usage.IsDiskQuotaLow()) + assert.False(t, usage.IsQuotaLow()) + usage.UsedQuotaSize = 9 + assert.NotEmpty(t, usage.GetQuotaSize()) + usage.QuotaSize = 10 + assert.True(t, usage.IsQuotaSizeLow()) + assert.True(t, usage.IsDiskQuotaLow()) + assert.True(t, usage.IsQuotaLow()) + usage.DownloadDataTransfer = 1 + assert.True(t, usage.HasQuotaInfo()) + assert.True(t, usage.HasTranferQuota()) + assert.Empty(t, usage.GetQuotaFiles()) + assert.Equal(t, 0, usage.GetQuotaFilesPercentage()) + usage.QuotaFiles = 1 + assert.NotEmpty(t, usage.GetQuotaFiles()) + usage.QuotaFiles = 0 + usage.UsedQuotaFiles = 9 + assert.NotEmpty(t, usage.GetQuotaFiles()) + usage.QuotaFiles = 10 + 
usage.DownloadDataTransfer = 0 + assert.True(t, usage.IsQuotaFilesLow()) + assert.True(t, usage.IsDiskQuotaLow()) + assert.False(t, usage.IsTotalTransferQuotaLow()) + assert.False(t, usage.IsUploadTransferQuotaLow()) + assert.False(t, usage.IsDownloadTransferQuotaLow()) + assert.Equal(t, 0, usage.GetTotalTransferQuotaPercentage()) + assert.Equal(t, 0, usage.GetUploadTransferQuotaPercentage()) + assert.Equal(t, 0, usage.GetDownloadTransferQuotaPercentage()) + assert.Empty(t, usage.GetTotalTransferQuota()) + assert.Empty(t, usage.GetUploadTransferQuota()) + assert.Empty(t, usage.GetDownloadTransferQuota()) + usage.TotalDataTransfer = 3 + usage.UsedUploadDataTransfer = 1 * 1048576 + assert.NotEmpty(t, usage.GetTotalTransferQuota()) + usage.TotalDataTransfer = 0 + assert.NotEmpty(t, usage.GetTotalTransferQuota()) + assert.NotEmpty(t, usage.GetUploadTransferQuota()) + usage.UploadDataTransfer = 2 + assert.NotEmpty(t, usage.GetUploadTransferQuota()) + usage.UsedDownloadDataTransfer = 1 * 1048576 + assert.NotEmpty(t, usage.GetDownloadTransferQuota()) + usage.DownloadDataTransfer = 2 + assert.NotEmpty(t, usage.GetDownloadTransferQuota()) + assert.False(t, usage.IsTransferQuotaLow()) + usage.UsedDownloadDataTransfer = 8 * 1048576 + usage.TotalDataTransfer = 10 + assert.True(t, usage.IsTotalTransferQuotaLow()) + assert.True(t, usage.IsTransferQuotaLow()) + usage.TotalDataTransfer = 0 + usage.UploadDataTransfer = 0 + usage.DownloadDataTransfer = 0 + assert.False(t, usage.IsTransferQuotaLow()) + usage.UploadDataTransfer = 10 + usage.UsedUploadDataTransfer = 9 * 1048576 + assert.True(t, usage.IsUploadTransferQuotaLow()) + assert.True(t, usage.IsTransferQuotaLow()) + usage.DownloadDataTransfer = 10 + usage.UsedDownloadDataTransfer = 9 * 1048576 + assert.True(t, usage.IsDownloadTransferQuotaLow()) + assert.True(t, usage.IsTransferQuotaLow()) +} + +func TestShareRedirectURL(t *testing.T) { + shareID := util.GenerateUniqueID() + base := path.Join(webClientPubSharesPath, shareID) + 
next := path.Join(webClientPubSharesPath, shareID, "browse") + ok, res := checkShareRedirectURL(next, base) + assert.True(t, ok) + assert.Equal(t, next, res) + next = path.Join(webClientPubSharesPath, shareID, "browse") + "?a=b" + ok, res = checkShareRedirectURL(next, base) + assert.True(t, ok) + assert.Equal(t, next, res) + next = path.Join(webClientPubSharesPath, shareID) + ok, res = checkShareRedirectURL(next, base) + assert.True(t, ok) + assert.Equal(t, path.Join(base, "download"), res) + next = path.Join(webClientEditFilePath, shareID) + ok, res = checkShareRedirectURL(next, base) + assert.False(t, ok) + assert.Empty(t, res) + next = path.Join(webClientPubSharesPath, shareID) + "?compress=false&a=b" + ok, res = checkShareRedirectURL(next, base) + assert.True(t, ok) + assert.Equal(t, path.Join(base, "download?compress=false&a=b"), res) + next = path.Join(webClientPubSharesPath, shareID) + "?compress=true&b=c" + ok, res = checkShareRedirectURL(next, base) + assert.True(t, ok) + assert.Equal(t, path.Join(base, "download?compress=true&b=c"), res) + ok, res = checkShareRedirectURL("http://foo\x7f.com/ab", "http://foo\x7f.com/") + assert.False(t, ok) + assert.Empty(t, res) + ok, res = checkShareRedirectURL("http://foo.com/?foo\nbar", "http://foo.com") + assert.False(t, ok) + assert.Empty(t, res) +} + +func TestI18NMessages(t *testing.T) { + msg := i18nListDirMsg(http.StatusForbidden) + require.Equal(t, util.I18nErrorDirList403, msg) + msg = i18nListDirMsg(http.StatusInternalServerError) + require.Equal(t, util.I18nErrorDirListGeneric, msg) + msg = i18nFsMsg(http.StatusForbidden) + require.Equal(t, util.I18nError403Message, msg) + msg = i18nFsMsg(http.StatusInternalServerError) + require.Equal(t, util.I18nErrorFsGeneric, msg) +} + +func TestI18NErrors(t *testing.T) { + err := util.NewValidationError("error text") + errI18n := util.NewI18nError(err, util.I18nError500Message) + assert.ErrorIs(t, errI18n, util.ErrValidation) + assert.Equal(t, err.Error(), 
errI18n.Error()) + assert.Equal(t, util.I18nError500Message, getI18NErrorString(errI18n, "")) + assert.Equal(t, util.I18nError500Message, errI18n.Message) + assert.Equal(t, "{}", errI18n.Args()) + var e1 *util.ValidationError + assert.ErrorAs(t, errI18n, &e1) + var e2 *util.I18nError + assert.ErrorAs(t, errI18n, &e2) + err2 := util.NewI18nError(fs.ErrNotExist, util.I18nError500Message) + assert.ErrorIs(t, err2, &util.I18nError{}) + assert.ErrorIs(t, err2, fs.ErrNotExist) + assert.NotErrorIs(t, err2, fs.ErrExist) + assert.Equal(t, util.I18nError403Message, getI18NErrorString(fs.ErrClosed, util.I18nError403Message)) + errorString := getI18NErrorString(nil, util.I18nError500Message) + assert.Equal(t, util.I18nError500Message, errorString) + errI18nWrap := util.NewI18nError(errI18n, util.I18nError404Message) + assert.Equal(t, util.I18nError500Message, errI18nWrap.Message) + errI18n = util.NewI18nError(err, util.I18nError500Message, util.I18nErrorArgs(map[string]any{"a": "b"})) + assert.Equal(t, util.I18nError500Message, errI18n.Message) + assert.Equal(t, `{"a":"b"}`, errI18n.Args()) +} + +func TestConvertEnabledLoginMethods(t *testing.T) { + b := Binding{ + EnabledLoginMethods: 0, + DisabledLoginMethods: 1, + } + b.convertLoginMethods() + assert.Equal(t, 1, b.DisabledLoginMethods) + b.DisabledLoginMethods = 0 + b.EnabledLoginMethods = 1 + b.convertLoginMethods() + assert.Equal(t, 14, b.DisabledLoginMethods) + b.DisabledLoginMethods = 0 + b.EnabledLoginMethods = 2 + b.convertLoginMethods() + assert.Equal(t, 13, b.DisabledLoginMethods) + b.DisabledLoginMethods = 0 + b.EnabledLoginMethods = 3 + b.convertLoginMethods() + assert.Equal(t, 12, b.DisabledLoginMethods) + b.DisabledLoginMethods = 0 + b.EnabledLoginMethods = 4 + b.convertLoginMethods() + assert.Equal(t, 11, b.DisabledLoginMethods) + b.DisabledLoginMethods = 0 + b.EnabledLoginMethods = 7 + b.convertLoginMethods() + assert.Equal(t, 8, b.DisabledLoginMethods) + b.DisabledLoginMethods = 0 + b.EnabledLoginMethods = 15 
+ b.convertLoginMethods() + assert.Equal(t, 0, b.DisabledLoginMethods) +} + +func TestValidateBaseURL(t *testing.T) { + tests := []struct { + name string + inputURL string + expectedURL string + expectErr bool + }{ + { + name: "Valid HTTPS URL", + inputURL: "https://sftp.example.com", + expectedURL: "https://sftp.example.com", + expectErr: false, + }, + { + name: "Remove trailing slash", + inputURL: "https://sftp.example.com/", + expectedURL: "https://sftp.example.com", + expectErr: false, + }, + { + name: "Remove multiple trailing slashes", + inputURL: "http://192.168.1.100:8080///", + expectedURL: "http://192.168.1.100:8080", + expectErr: false, + }, + { + name: "Empty BaseURL (optional case)", + inputURL: "", + expectedURL: "", + expectErr: false, + }, + { + name: "Unsupported scheme (FTP)", + inputURL: "ftp://files.example.com", + expectErr: true, + }, + { + name: "Malformed URL string", + inputURL: "not-a-url", + expectErr: true, + }, + { + name: "Missing Host", + inputURL: "https://", + expectErr: true, + }, + { + name: "Preserve path without trailing slash", + inputURL: "https://example.com/sftp/", + expectedURL: "https://example.com/sftp", + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &Binding{ + BaseURL: tt.inputURL, + } + + err := b.validateBaseURL() + + if (err != nil) != tt.expectErr { + t.Errorf("validateBaseURL() error = %v, expectErr %v", err, tt.expectErr) + return + } + + if !tt.expectErr && b.BaseURL != tt.expectedURL { + t.Errorf("validateBaseURL() got = %v, want %v", b.BaseURL, tt.expectedURL) + } + }) + } +} + +func getCSRFTokenFromBody(body io.Reader) (string, error) { + doc, err := html.Parse(body) + if err != nil { + return "", err + } + + var csrfToken string + var f func(*html.Node) + + f = func(n *html.Node) { + if n.Type == html.ElementNode && n.Data == "input" { + var name, value string + for _, attr := range n.Attr { + if attr.Key == "value" { + value = attr.Val + } + if 
attr.Key == "name" { + name = attr.Val + } + } + if name == csrfFormToken { + csrfToken = value + return + } + } + + for c := n.FirstChild; c != nil; c = c.NextSibling { + f(c) + } + } + + f(doc) + + if csrfToken == "" { + return "", errors.New("CSRF token not found") + } + + return csrfToken, nil +} + +func isSharedProviderSupported() bool { + // SQLite shares the implementation with other SQL-based provider but it makes no sense + // to use it outside test cases + switch dataprovider.GetProviderStatus().Driver { + case dataprovider.MySQLDataProviderName, dataprovider.PGSQLDataProviderName, + dataprovider.CockroachDataProviderName, dataprovider.SQLiteDataProviderName: + return true + default: + return false + } +} diff --git a/internal/httpd/middleware.go b/internal/httpd/middleware.go new file mode 100644 index 00000000..199fa89f --- /dev/null +++ b/internal/httpd/middleware.go @@ -0,0 +1,615 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "errors" + "fmt" + "io/fs" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/rs/xid" + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + forwardedProtoKey = &contextKey{"forwarded proto"} + errInvalidToken = errors.New("invalid JWT token") +) + +type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "context value " + k.name +} + +func validateJWTToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error { + token, err := jwt.FromContext(r.Context()) + + var redirectPath string + if audience == tokenAudienceWebAdmin { + redirectPath = webAdminLoginPath + } else { + redirectPath = webClientLoginPath + if uri := r.RequestURI; strings.HasPrefix(uri, webClientFilesPath) { + redirectPath += "?next=" + url.QueryEscape(uri) //nolint:goconst + } + } + + isAPIToken := (audience == tokenAudienceAPI || audience == tokenAudienceAPIUser) + + doRedirect := func(message string, err error) { + if isAPIToken { + sendAPIResponse(w, r, err, message, http.StatusUnauthorized) + } else { + http.Redirect(w, r, redirectPath, http.StatusFound) + } + } + + if err != nil { + logger.Debug(logSender, "", "error getting jwt token: %v", err) + doRedirect(http.StatusText(http.StatusUnauthorized), err) + return errInvalidToken + } + + if isTokenInvalidated(r) { + logger.Debug(logSender, "", "the token has been invalidated") + doRedirect("Your token is no longer valid", nil) + return errInvalidToken + } + // a user with a partial token will be always redirected to the appropriate two factor auth page + if err := checkPartialAuth(w, r, audience, token.Audience); err != nil { + return err + } + if !token.Audience.Contains(audience) { + logger.Debug(logSender, 
"", "the token is not valid for audience %q", audience) + doRedirect("Your token audience is not valid", nil) + return errInvalidToken + } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := validateIPForToken(token, ipAddr); err != nil { + logger.Debug(logSender, "", "the token with id %q is not valid for the ip address %q", token.ID, ipAddr) + doRedirect("Your token is not valid", nil) + return err + } + if err := checkTokenSignature(r, token); err != nil { + doRedirect("Your token is no longer valid", nil) + return err + } + return nil +} + +func (s *httpdServer) validateJWTPartialToken(w http.ResponseWriter, r *http.Request, audience tokenAudience) error { + token, err := jwt.FromContext(r.Context()) + var notFoundFunc func(w http.ResponseWriter, r *http.Request, err error) + if audience == tokenAudienceWebAdminPartial { + notFoundFunc = s.renderNotFoundPage + } else { + notFoundFunc = s.renderClientNotFoundPage + } + if err != nil { + notFoundFunc(w, r, nil) + return errInvalidToken + } + if isTokenInvalidated(r) { + notFoundFunc(w, r, nil) + return errInvalidToken + } + if !token.Audience.Contains(audience) { + logger.Debug(logSender, "", "the partial token with id %q is not valid for audience %q", token.ID, audience) + notFoundFunc(w, r, nil) + return errInvalidToken + } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := validateIPForToken(token, ipAddr); err != nil { + logger.Debug(logSender, "", "the partial token with id %q is not valid for the ip address %q", token.ID, ipAddr) + notFoundFunc(w, r, nil) + return err + } + + return nil +} + +func (s *httpdServer) jwtAuthenticatorPartial(audience tokenAudience) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := s.validateJWTPartialToken(w, r, audience); err != nil { + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + }) + } +} 
+ +func jwtAuthenticatorAPI(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := validateJWTToken(w, r, tokenAudienceAPI); err != nil { + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + }) +} + +func jwtAuthenticatorAPIUser(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := validateJWTToken(w, r, tokenAudienceAPIUser); err != nil { + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + }) +} + +func jwtAuthenticatorWebAdmin(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := validateJWTToken(w, r, tokenAudienceWebAdmin); err != nil { + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + }) +} + +func jwtAuthenticatorWebClient(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := validateJWTToken(w, r, tokenAudienceWebClient); err != nil { + return + } + + // Token is authenticated, pass it through + next.ServeHTTP(w, r) + }) +} + +func (s *httpdServer) checkHTTPUserPerm(perm string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims, err := jwt.FromContext(r.Context()) + if err != nil { + if isWebRequest(r) { + s.renderClientBadRequestPage(w, r, err) + } else { + sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + } + return + } + // for web client perms are negated and not granted + if claims.HasPerm(perm) { + if isWebRequest(r) { + s.renderClientForbiddenPage(w, r, errors.New("you don't have permission for this action")) + } else { + sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden) + } + return + } + + 
next.ServeHTTP(w, r) + }) + } +} + +// checkAuthRequirements checks if the user must set a second factor auth or change the password +func (s *httpdServer) checkAuthRequirements(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims, err := jwt.FromContext(r.Context()) + if err != nil { + if isWebRequest(r) { + if isWebClientRequest(r) { + s.renderClientBadRequestPage(w, r, err) + } else { + s.renderBadRequestPage(w, r, err) + } + } else { + sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + } + return + } + if claims.MustSetTwoFactorAuth || claims.MustChangePassword { + var err error + if claims.MustSetTwoFactorAuth { + if len(claims.RequiredTwoFactorProtocols) > 0 { + protocols := strings.Join(claims.RequiredTwoFactorProtocols, ", ") + err = util.NewI18nError( + util.NewGenericError( + fmt.Sprintf("Two-factor authentication requirements not met, please configure two-factor authentication for the following protocols: %v", + protocols)), + util.I18nError2FARequired, + util.I18nErrorArgs(map[string]any{ + "val": protocols, + }), + ) + } else { + err = util.NewI18nError( + util.NewGenericError("Two-factor authentication requirements not met, please configure two-factor authentication"), + util.I18nError2FARequiredGeneric, + ) + } + } else { + err = util.NewI18nError( + util.NewGenericError("Password change required. 
Please set a new password to continue to use your account"), + util.I18nErrorChangePwdRequired, + ) + } + if isWebRequest(r) { + if isWebClientRequest(r) { + s.renderClientForbiddenPage(w, r, err) + } else { + s.renderForbiddenPage(w, r, err) + } + } else { + sendAPIResponse(w, r, err, "", http.StatusForbidden) + } + return + } + + next.ServeHTTP(w, r) + }) +} + +func (s *httpdServer) requireBuiltinLogin(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isLoggedInWithOIDC(r) { + err := util.NewI18nError( + util.NewGenericError("This feature is not available if you are logged in with OpenID"), + util.I18nErrorNoOIDCFeature, + ) + if isWebClientRequest(r) { + s.renderClientForbiddenPage(w, r, err) + } else { + s.renderForbiddenPage(w, r, err) + } + return + } + next.ServeHTTP(w, r) + }) +} + +func (s *httpdServer) checkPerms(perms ...string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims, err := jwt.FromContext(r.Context()) + if err != nil { + if isWebRequest(r) { + s.renderBadRequestPage(w, r, err) + } else { + sendAPIResponse(w, r, err, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + } + return + } + + for _, perm := range perms { + if !claims.HasPerm(perm) { + if isWebRequest(r) { + s.renderForbiddenPage(w, r, util.NewI18nError(fs.ErrPermission, util.I18nError403Message)) + } else { + sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden) + } + return + } + } + + next.ServeHTTP(w, r) + }) + } +} + +func (s *httpdServer) verifyCSRFHeader(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tokenString := r.Header.Get(csrfHeaderToken) + token, err := jwt.VerifyToken(s.csrfTokenAuth, tokenString) + if err != nil || token == nil { + logger.Debug(logSender, "", "error validating CSRF 
header: %v", err) + sendAPIResponse(w, r, err, "Invalid token", http.StatusForbidden) + return + } + + if !token.Audience.Contains(tokenAudienceCSRF) { + logger.Debug(logSender, "", "error validating CSRF header token audience") + sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) + return + } + + if err := validateIPForToken(token, util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { + logger.Debug(logSender, "", "error validating CSRF header IP audience") + sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) + return + } + if err := checkCSRFTokenRef(r, token); err != nil { + sendAPIResponse(w, r, errors.New("the token is not valid"), "", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +} + +func checkNodeToken(tokenAuth *jwt.Signer) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bearer := r.Header.Get(dataprovider.NodeTokenHeader) + if bearer == "" { + next.ServeHTTP(w, r) + return + } + const prefix = "Bearer " + if len(bearer) >= len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) { + bearer = bearer[len(prefix):] + } + if invalidatedJWTTokens.Get(bearer) { + logger.Debug(logSender, "", "the node token has been invalidated") + sendAPIResponse(w, r, fmt.Errorf("the provided token is not valid"), "", http.StatusUnauthorized) + return + } + claims, err := dataprovider.AuthenticateNodeToken(bearer) + if err != nil { + logger.Debug(logSender, "", "unable to authenticate node token %q: %v", bearer, err) + sendAPIResponse(w, r, fmt.Errorf("the provided token cannot be authenticated"), "", http.StatusUnauthorized) + return + } + defer invalidatedJWTTokens.Add(bearer, time.Now().Add(2*time.Minute).UTC()) + + c := &jwt.Claims{ + Username: claims.Username, + Permissions: claims.Permissions, + NodeID: dataprovider.GetNodeName(), + Role: claims.Role, + } + + 
token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, util.GetIPFromRemoteAddress(r.RemoteAddr), getTokenDuration(tokenAudienceAPI)) + if err != nil { + sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + resp := c.BuildTokenResponse(token) + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) + + next.ServeHTTP(w, r) + }) + } +} + +func checkAPIKeyAuth(tokenAuth *jwt.Signer, scope dataprovider.APIKeyScope) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiKey := r.Header.Get("X-SFTPGO-API-KEY") + if apiKey == "" { + next.ServeHTTP(w, r) + return + } + keyParams := strings.SplitN(apiKey, ".", 3) + if len(keyParams) < 2 { + logger.Debug(logSender, "", "invalid api key %q", apiKey) + sendAPIResponse(w, r, errors.New("the provided api key is not valid"), "", http.StatusBadRequest) + return + } + keyID := keyParams[0] + key := keyParams[1] + apiUser := "" + if len(keyParams) > 2 { + apiUser = keyParams[2] + } + + k, err := dataprovider.APIKeyExists(keyID) + if err != nil { + handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), util.NewRecordNotFoundError("invalid api key")) //nolint:errcheck + logger.Debug(logSender, "", "invalid api key %q: %v", apiKey, err) + sendAPIResponse(w, r, errors.New("the provided api key is not valid"), "", http.StatusBadRequest) + return + } + if k.Scope != scope { + handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), dataprovider.ErrInvalidCredentials) //nolint:errcheck + logger.Debug(logSender, "", "unable to authenticate api key %q: invalid scope: got %d, wanted: %d", + apiKey, k.Scope, scope) + sendAPIResponse(w, r, fmt.Errorf("the provided api key is invalid for this request"), "", http.StatusForbidden) + return + } + if err := k.Authenticate(key); err != nil { + 
handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), dataprovider.ErrInvalidCredentials) //nolint:errcheck + logger.Debug(logSender, "", "unable to authenticate api key %q: %v", apiKey, err) + sendAPIResponse(w, r, fmt.Errorf("the provided api key cannot be authenticated"), "", http.StatusUnauthorized) + return + } + if scope == dataprovider.APIKeyScopeAdmin { + if k.Admin != "" { + apiUser = k.Admin + } + if err := authenticateAdminWithAPIKey(apiUser, keyID, tokenAuth, r); err != nil { + handleDefenderEventLoginFailed(util.GetIPFromRemoteAddress(r.RemoteAddr), err) //nolint:errcheck + logger.Debug(logSender, "", "unable to authenticate admin %q associated with api key %q: %v", + apiUser, apiKey, err) + sendAPIResponse(w, r, fmt.Errorf("the admin associated with the provided api key cannot be authenticated"), + "", http.StatusUnauthorized) + return + } + common.DelayLogin(nil) + } else { + if k.User != "" { + apiUser = k.User + } + if err := authenticateUserWithAPIKey(apiUser, keyID, tokenAuth, r); err != nil { + logger.Debug(logSender, "", "unable to authenticate user %q associated with api key %q: %v", + apiUser, apiKey, err) + updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: apiUser}}, + dataprovider.LoginMethodPassword, util.GetIPFromRemoteAddress(r.RemoteAddr), err, r) + code := http.StatusUnauthorized + if errors.Is(err, common.ErrInternalFailure) { + code = http.StatusInternalServerError + } + sendAPIResponse(w, r, errors.New("the user associated with the provided api key cannot be authenticated"), + "", code) + return + } + updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: apiUser}}, + dataprovider.LoginMethodPassword, util.GetIPFromRemoteAddress(r.RemoteAddr), nil, r) + } + dataprovider.UpdateAPIKeyLastUse(&k) //nolint:errcheck + + next.ServeHTTP(w, r) + }) + } +} + +func forbidAPIKeyAuthentication(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + if claims.APIKeyID != "" { + sendAPIResponse(w, r, nil, "API key authentication is not allowed", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +} + +func authenticateAdminWithAPIKey(username, keyID string, tokenAuth *jwt.Signer, r *http.Request) error { + if username == "" { + return errors.New("the provided key is not associated with any admin and no username was provided") + } + admin, err := dataprovider.AdminExists(username) + if err != nil { + return err + } + if !admin.Filters.AllowAPIKeyAuth { + return fmt.Errorf("API key authentication disabled for admin %q", admin.Username) + } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := admin.CanLogin(ipAddr); err != nil { + return err + } + c := &jwt.Claims{ + Username: admin.Username, + Permissions: admin.Permissions, + Role: admin.Role, + APIKeyID: keyID, + } + c.Subject = admin.GetSignature() + + token, err := tokenAuth.SignWithParams(c, tokenAudienceAPI, ipAddr, getTokenDuration(tokenAudienceAPI)) + if err != nil { + return err + } + resp := c.BuildTokenResponse(token) + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) + dataprovider.UpdateAdminLastLogin(&admin) + common.DelayLogin(nil) + return nil +} + +func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwt.Signer, r *http.Request) error { + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + protocol := common.ProtocolHTTP + if username == "" { + err := errors.New("the provided key is not associated with any user and no username was provided") + updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, + dataprovider.LoginMethodPassword, ipAddr, err, r) + return err + } + if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil { + return err + } + 
user, err := dataprovider.GetUserWithGroupSettings(username, "") + if err != nil { + updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, + dataprovider.LoginMethodPassword, ipAddr, err, r) + return err + } + if !user.Filters.AllowAPIKeyAuth { + err := fmt.Errorf("API key authentication disabled for user %q", user.Username) + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) + return err + } + if err := user.CheckLoginConditions(); err != nil { + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) + return err + } + connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) + if err := checkHTTPClientUser(&user, r, connectionID, true, false); err != nil { + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) + return err + } + defer user.CloseFs() //nolint:errcheck + err = user.CheckFsRoot(connectionID) + if err != nil { + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) + return common.ErrInternalFailure + } + c := &jwt.Claims{ + Username: user.Username, + Permissions: user.Filters.WebClient, + Role: user.Role, + APIKeyID: keyID, + } + c.Subject = user.GetSignature() + + token, err := tokenAuth.SignWithParams(c, tokenAudienceAPIUser, ipAddr, getTokenDuration(tokenAudienceAPIUser)) + if err != nil { + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) + return err + } + resp := c.BuildTokenResponse(token) + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", resp.Token)) + dataprovider.UpdateLastLogin(&user) + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, nil, r) + + return nil +} + +func checkPartialAuth(w http.ResponseWriter, r *http.Request, audience string, tokenAudience []string) error { + if audience == tokenAudienceWebAdmin && slices.Contains(tokenAudience, tokenAudienceWebAdminPartial) { + http.Redirect(w, r, 
webAdminTwoFactorPath, http.StatusFound) + return errInvalidToken + } + if audience == tokenAudienceWebClient && slices.Contains(tokenAudience, tokenAudienceWebClientPartial) { + http.Redirect(w, r, webClientTwoFactorPath, http.StatusFound) + return errInvalidToken + } + return nil +} + +func cacheControlMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate, private") + next.ServeHTTP(w, r) + }) +} + +func cleanCacheControlMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Del("Cache-Control") + next.ServeHTTP(w, r) + }) +} diff --git a/internal/httpd/oauth2.go b/internal/httpd/oauth2.go new file mode 100644 index 00000000..bb3e7806 --- /dev/null +++ b/internal/httpd/oauth2.go @@ -0,0 +1,170 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "encoding/json" + "errors" + "sync" + "time" + + "golang.org/x/oauth2" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + oauth2Mgr oauth2Manager +) + +func newOAuth2Manager(isShared int) oauth2Manager { + if isShared == 1 { + logger.Info(logSender, "", "using provider OAuth2 manager") + return &dbOAuth2Manager{} + } + logger.Info(logSender, "", "using memory OAuth2 manager") + return &memoryOAuth2Manager{ + pendingAuths: make(map[string]oauth2PendingAuth), + } +} + +type oauth2PendingAuth struct { + State string `json:"state"` + Provider int `json:"provider"` + ClientID string `json:"client_id"` + ClientSecret *kms.Secret `json:"client_secret"` + RedirectURL string `json:"redirect_url"` + IssuedAt int64 `json:"issued_at"` + Verifier string `json:"verifier"` +} + +func newOAuth2PendingAuth(provider int, redirectURL, clientID string, clientSecret *kms.Secret) oauth2PendingAuth { + return oauth2PendingAuth{ + State: util.GenerateOpaqueString(), + Provider: provider, + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + Verifier: oauth2.GenerateVerifier(), + } +} + +type oauth2Manager interface { + addPendingAuth(pendingAuth oauth2PendingAuth) + removePendingAuth(state string) + getPendingAuth(state string) (oauth2PendingAuth, error) + cleanup() +} + +type memoryOAuth2Manager struct { + mu sync.RWMutex + pendingAuths map[string]oauth2PendingAuth +} + +func (o *memoryOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) { + o.mu.Lock() + defer o.mu.Unlock() + + o.pendingAuths[pendingAuth.State] = pendingAuth +} + +func (o *memoryOAuth2Manager) removePendingAuth(state string) { + o.mu.Lock() + defer o.mu.Unlock() + + delete(o.pendingAuths, state) +} + +func (o *memoryOAuth2Manager) 
getPendingAuth(state string) (oauth2PendingAuth, error) { + o.mu.RLock() + defer o.mu.RUnlock() + + authReq, ok := o.pendingAuths[state] + if !ok { + return oauth2PendingAuth{}, errors.New("oauth2: no auth request found for the specified state") + } + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt + if diff > authStateValidity { + return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old") + } + return authReq, nil +} + +func (o *memoryOAuth2Manager) cleanup() { + o.mu.Lock() + defer o.mu.Unlock() + + for k, auth := range o.pendingAuths { + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt + // remove old pending auth requests + if diff < 0 || diff > authStateValidity { + delete(o.pendingAuths, k) + } + } +} + +type dbOAuth2Manager struct{} + +func (o *dbOAuth2Manager) addPendingAuth(pendingAuth oauth2PendingAuth) { + if err := pendingAuth.ClientSecret.Encrypt(); err != nil { + logger.Error(logSender, "", "unable to encrypt oauth2 secret: %v", err) + return + } + session := dataprovider.Session{ + Key: pendingAuth.State, + Data: pendingAuth, + Type: dataprovider.SessionTypeOAuth2Auth, + Timestamp: pendingAuth.IssuedAt + authStateValidity, + } + dataprovider.AddSharedSession(session) //nolint:errcheck +} + +func (o *dbOAuth2Manager) removePendingAuth(state string) { + dataprovider.DeleteSharedSession(state, dataprovider.SessionTypeOAuth2Auth) //nolint:errcheck +} + +func (o *dbOAuth2Manager) getPendingAuth(state string) (oauth2PendingAuth, error) { + session, err := dataprovider.GetSharedSession(state, dataprovider.SessionTypeOAuth2Auth) + if err != nil { + return oauth2PendingAuth{}, errors.New("oauth2: unable to get the auth request for the specified state") + } + if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { + // expired + return oauth2PendingAuth{}, errors.New("oauth2: auth request is too old") + } + return o.decodePendingAuthData(session.Data) +} + +func (o *dbOAuth2Manager) 
decodePendingAuthData(data any) (oauth2PendingAuth, error) { + if val, ok := data.([]byte); ok { + authReq := oauth2PendingAuth{} + err := json.Unmarshal(val, &authReq) + if err != nil { + return authReq, err + } + err = authReq.ClientSecret.TryDecrypt() + return authReq, err + } + logger.Error(logSender, "", "invalid oauth2 auth request data type %T", data) + return oauth2PendingAuth{}, errors.New("oauth2: invalid auth request data") +} + +func (o *dbOAuth2Manager) cleanup() { + dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOAuth2Auth, time.Now()) //nolint:errcheck +} diff --git a/internal/httpd/oauth2_test.go b/internal/httpd/oauth2_test.go new file mode 100644 index 00000000..0e488136 --- /dev/null +++ b/internal/httpd/oauth2_test.go @@ -0,0 +1,135 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "encoding/json" + "testing" + "time" + + "github.com/rs/xid" + sdkkms "github.com/sftpgo/sdk/kms" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func TestMemoryOAuth2Manager(t *testing.T) { + mgr := newOAuth2Manager(0) + m, ok := mgr.(*memoryOAuth2Manager) + require.True(t, ok) + require.Len(t, m.pendingAuths, 0) + _, err := m.getPendingAuth(xid.New().String()) + require.Error(t, err) + assert.Contains(t, err.Error(), "no auth request found") + auth := newOAuth2PendingAuth(1, "https://...", "cid", kms.NewPlainSecret("mysecret")) + m.addPendingAuth(auth) + require.Len(t, m.pendingAuths, 1) + a, err := m.getPendingAuth(auth.State) + assert.NoError(t, err) + assert.Equal(t, auth.State, a.State) + assert.Equal(t, sdkkms.SecretStatusPlain, a.ClientSecret.GetStatus()) + m.removePendingAuth(auth.State) + _, err = m.getPendingAuth(auth.State) + require.Error(t, err) + assert.Contains(t, err.Error(), "no auth request found") + require.Len(t, m.pendingAuths, 0) + state := xid.New().String() + auth = oauth2PendingAuth{ + State: state, + Provider: 1, + IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + } + m.addPendingAuth(auth) + auth = oauth2PendingAuth{ + State: xid.New().String(), + Provider: 1, + IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)), + } + m.addPendingAuth(auth) + require.Len(t, m.pendingAuths, 2) + _, err = m.getPendingAuth(auth.State) + require.Error(t, err) + assert.Contains(t, err.Error(), "auth request is too old") + m.cleanup() + require.Len(t, m.pendingAuths, 1) + m.removePendingAuth(state) + require.Len(t, m.pendingAuths, 0) +} + +func TestDbOAuth2Manager(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + mgr := newOAuth2Manager(1) + m, ok := 
mgr.(*dbOAuth2Manager) + require.True(t, ok) + _, err := m.getPendingAuth(xid.New().String()) + require.Error(t, err) + auth := newOAuth2PendingAuth(1, "https://...", "client_id", kms.NewPlainSecret("my db secret")) + m.addPendingAuth(auth) + a, err := m.getPendingAuth(auth.State) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusPlain, a.ClientSecret.GetStatus()) + session, err := dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth) + assert.NoError(t, err) + authReq := oauth2PendingAuth{} + err = json.Unmarshal(session.Data.([]byte), &authReq) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, authReq.ClientSecret.GetStatus()) + m.cleanup() + _, err = m.getPendingAuth(auth.State) + assert.NoError(t, err) + m.removePendingAuth(auth.State) + _, err = m.getPendingAuth(auth.State) + assert.Error(t, err) + auth = oauth2PendingAuth{ + State: xid.New().String(), + Provider: 1, + IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)), + ClientSecret: kms.NewPlainSecret("db secret"), + } + m.addPendingAuth(auth) + _, err = m.getPendingAuth(auth.State) + assert.Error(t, err) + _, err = dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth) + assert.NoError(t, err) + m.cleanup() + _, err = dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth) + assert.Error(t, err) + _, err = m.decodePendingAuthData("not a byte array") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid auth request data") + _, err = m.decodePendingAuthData([]byte("{not a json")) + require.Error(t, err) + // adding a request with a non plain secret will fail + auth = oauth2PendingAuth{ + State: xid.New().String(), + Provider: 1, + IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)), + ClientSecret: kms.NewPlainSecret("db secret"), + } + auth.ClientSecret.SetStatus(sdkkms.SecretStatusSecretBox) + m.addPendingAuth(auth) + _, err = 
dataprovider.GetSharedSession(auth.State, dataprovider.SessionTypeOAuth2Auth) + assert.Error(t, err) + asJSON, err := json.Marshal(auth) + assert.NoError(t, err) + _, err = m.decodePendingAuthData(asJSON) + assert.Error(t, err) +} diff --git a/internal/httpd/oidc.go b/internal/httpd/oidc.go new file mode 100644 index 00000000..7c43fac5 --- /dev/null +++ b/internal/httpd/oidc.go @@ -0,0 +1,859 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/rs/xid" + "golang.org/x/oauth2" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + oidcCookieKey = "oidc" + adminRoleFieldValue = "admin" + authStateValidity = 2 * 60 * 1000 // 2 minutes + tokenUpdateInterval = 3 * 60 * 1000 // 3 minutes + tokenDeleteInterval = 2 * 3600 * 1000 // 2 hours +) + +var ( + oidcTokenKey = &contextKey{"OIDC token key"} + oidcGeneratedToken = &contextKey{"OIDC generated token"} +) + +// OAuth2Config defines an interface for OAuth2 methods, so we can mock them +type OAuth2Config interface { + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource +} + +// OIDCTokenVerifier defines an interface for OpenID token verifier, so we can mock them +type OIDCTokenVerifier interface { + Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error) +} + +// OIDC defines the OpenID Connect configuration +type OIDC struct { + // ClientID is the application's ID + ClientID string `json:"client_id" mapstructure:"client_id"` + // ClientSecret is the application's secret + ClientSecret string `json:"client_secret" mapstructure:"client_secret"` + ClientSecretFile string `json:"client_secret_file" mapstructure:"client_secret_file"` + // ConfigURL is the identifier for the service. 
+ // SFTPGo will try to retrieve the provider configuration on startup and then + // will refuse to start if it fails to connect to the specified URL + ConfigURL string `json:"config_url" mapstructure:"config_url"` + // RedirectBaseURL is the base URL to redirect to after OpenID authentication. + // The suffix "/web/oidc/redirect" will be added to this base URL, adding also the + // "web_root" if configured + RedirectBaseURL string `json:"redirect_base_url" mapstructure:"redirect_base_url"` + // ID token claims field to map to the SFTPGo username + UsernameField string `json:"username_field" mapstructure:"username_field"` + // Optional ID token claims field to map to a SFTPGo role. + // If the defined ID token claims field is set to "admin" the authenticated user + // is mapped to an SFTPGo admin. + // You don't need to specify this field if you want to use OpenID only for the + // Web Client UI + RoleField string `json:"role_field" mapstructure:"role_field"` + // If set, the `RoleField` is ignored and the SFTPGo role is assumed based on + // the login link used + ImplicitRoles bool `json:"implicit_roles" mapstructure:"implicit_roles"` + // Scopes required by the OAuth provider to retrieve information about the authenticated user. + // The "openid" scope is required. + // Refer to your OAuth provider documentation for more information about this + Scopes []string `json:"scopes" mapstructure:"scopes"` + // Custom token claims fields to pass to the pre-login hook + CustomFields []string `json:"custom_fields" mapstructure:"custom_fields"` + // InsecureSkipSignatureCheck causes SFTPGo to skip JWT signature validation. + // It's intended for special cases where providers, such as Azure, use the "none" + // algorithm. Skipping the signature validation can cause security issues + InsecureSkipSignatureCheck bool `json:"insecure_skip_signature_check" mapstructure:"insecure_skip_signature_check"` + // Debug enables the OIDC debug mode. 
In debug mode, the received id_token will be logged + // at the debug level + Debug bool `json:"debug" mapstructure:"debug"` + provider *oidc.Provider + verifier OIDCTokenVerifier + providerLogoutURL string + oauth2Config OAuth2Config +} + +func (o *OIDC) isEnabled() bool { + return o.provider != nil +} + +func (o *OIDC) hasRoles() bool { + return o.isEnabled() && (o.RoleField != "" || o.ImplicitRoles) +} + +func (o *OIDC) getForcedRole(audience string) string { + if !o.ImplicitRoles { + return "" + } + if audience == tokenAudienceWebAdmin { + return adminRoleFieldValue + } + return "" +} + +func (o *OIDC) getRedirectURL() string { + url := o.RedirectBaseURL + if strings.HasSuffix(o.RedirectBaseURL, "/") { + url = strings.TrimSuffix(o.RedirectBaseURL, "/") + } + url += webOIDCRedirectPath + logger.Debug(logSender, "", "oidc redirect URL: %q", url) + return url +} + +func (o *OIDC) initialize() error { + if o.ConfigURL == "" { + return nil + } + if o.UsernameField == "" { + return errors.New("oidc: username field cannot be empty") + } + if o.RedirectBaseURL == "" { + return errors.New("oidc: redirect base URL cannot be empty") + } + if !slices.Contains(o.Scopes, oidc.ScopeOpenID) { + return fmt.Errorf("oidc: required scope %q is not set", oidc.ScopeOpenID) + } + if o.ClientSecretFile != "" { + secret, err := util.ReadConfigFromFile(o.ClientSecretFile, configurationDir) + if err != nil { + return err + } + o.ClientSecret = secret + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + provider, err := oidc.NewProvider(ctx, o.ConfigURL) + if err != nil { + return fmt.Errorf("oidc: unable to initialize provider for URL %q: %w", o.ConfigURL, err) + } + claims := make(map[string]any) + // we cannot get an error here because the response body was already parsed as JSON + // on provider creation + provider.Claims(&claims) //nolint:errcheck + endSessionEndPoint, ok := claims["end_session_endpoint"] + if ok { + if val, ok := 
endSessionEndPoint.(string); ok { + o.providerLogoutURL = val + logger.Debug(logSender, "", "oidc end session endpoint %q", o.providerLogoutURL) + } + } + o.provider = provider + o.verifier = nil + o.oauth2Config = &oauth2.Config{ + ClientID: o.ClientID, + ClientSecret: o.ClientSecret, + Endpoint: o.provider.Endpoint(), + RedirectURL: o.getRedirectURL(), + Scopes: o.Scopes, + } + + return nil +} + +func (o *OIDC) getVerifier(ctx context.Context) OIDCTokenVerifier { + if o.verifier != nil { + return o.verifier + } + return o.provider.VerifierContext(ctx, &oidc.Config{ + ClientID: o.ClientID, + InsecureSkipSignatureCheck: o.InsecureSkipSignatureCheck, + }) +} + +type oidcPendingAuth struct { + State string `json:"state"` + Nonce string `json:"nonce"` + Audience tokenAudience `json:"audience"` + IssuedAt int64 `json:"issued_at"` + Verifier string `json:"verifier"` +} + +func newOIDCPendingAuth(audience tokenAudience) oidcPendingAuth { + return oidcPendingAuth{ + State: util.GenerateOpaqueString(), + Nonce: util.GenerateOpaqueString(), + Audience: audience, + IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now()), + Verifier: oauth2.GenerateVerifier(), + } +} + +type oidcToken struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + SessionID string `json:"session_id"` + IDToken string `json:"id_token"` + Nonce string `json:"nonce"` + Username string `json:"username"` + Permissions []string `json:"permissions"` + HideUserPageSections int `json:"hide_user_page_sections,omitempty"` + MustSetTwoFactorAuth bool `json:"must_set_2fa,omitempty"` + MustChangePassword bool `json:"must_change_password,omitempty"` + RequiredTwoFactorProtocols []string `json:"required_two_factor_protocols,omitempty"` + TokenRole string `json:"token_role,omitempty"` // SFTPGo role name + Role any `json:"role"` // oidc user role: SFTPGo user or admin + 
CustomFields *map[string]any `json:"custom_fields,omitempty"` + Cookie string `json:"cookie"` + UsedAt int64 `json:"used_at"` +} + +func (t *oidcToken) parseClaims(claims map[string]any, usernameField, roleField string, customFields []string, + forcedRole string, +) error { + getClaimsFields := func() []string { + keys := make([]string, 0, len(claims)) + for k := range claims { + keys = append(keys, k) + } + return keys + } + + var username string + val, ok := getOIDCFieldFromClaims(claims, usernameField) + if ok { + username, ok = val.(string) + } + if !ok || username == "" { + logger.Warn(logSender, "", "username field %q not found, empty or not a string, claims fields: %+v", + usernameField, getClaimsFields()) + return errors.New("no username field") + } + t.Username = username + if forcedRole != "" { + t.Role = forcedRole + } else { + t.getRoleFromField(claims, roleField) + } + t.CustomFields = nil + if len(customFields) > 0 { + for _, field := range customFields { + if val, ok := getOIDCFieldFromClaims(claims, field); ok { + if t.CustomFields == nil { + customFields := make(map[string]any) + t.CustomFields = &customFields + } + logger.Debug(logSender, "", "custom field %q found in token claims", field) + (*t.CustomFields)[field] = val + } else { + logger.Info(logSender, "", "custom field %q not found in token claims", field) + } + } + } + sid, ok := claims["sid"].(string) + if ok { + t.SessionID = sid + } + return nil +} + +func (t *oidcToken) getRoleFromField(claims map[string]any, roleField string) { + role, ok := getOIDCFieldFromClaims(claims, roleField) + if ok { + t.Role = role + } +} + +func (t *oidcToken) isAdmin() bool { + switch v := t.Role.(type) { + case string: + return v == adminRoleFieldValue + case []any: + for _, s := range v { + if val, ok := s.(string); ok && val == adminRoleFieldValue { + return true + } + } + return false + default: + return false + } +} + +func (t *oidcToken) isExpired() bool { + if t.ExpiresAt == 0 { + return false + } + 
return t.ExpiresAt < util.GetTimeAsMsSinceEpoch(time.Now()) +} + +func (t *oidcToken) refresh(ctx context.Context, config OAuth2Config, verifier OIDCTokenVerifier, r *http.Request) error { + if t.RefreshToken == "" { + logger.Debug(logSender, "", "refresh token not set, unable to refresh cookie %q", t.Cookie) + return errors.New("refresh token not set") + } + oauth2Token := oauth2.Token{ + AccessToken: t.AccessToken, + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + } + if t.ExpiresAt > 0 { + oauth2Token.Expiry = util.GetTimeFromMsecSinceEpoch(t.ExpiresAt) + } + + newToken, err := config.TokenSource(ctx, &oauth2Token).Token() + if err != nil { + logger.Debug(logSender, "", "unable to refresh token for cookie %q: %v", t.Cookie, err) + return err + } + rawIDToken, ok := newToken.Extra("id_token").(string) + if !ok { + logger.Debug(logSender, "", "the refreshed token has no id token, cookie %q", t.Cookie) + return errors.New("the refreshed token has no id token") + } + + t.AccessToken = newToken.AccessToken + t.TokenType = newToken.TokenType + t.RefreshToken = newToken.RefreshToken + t.IDToken = rawIDToken + if !newToken.Expiry.IsZero() { + t.ExpiresAt = util.GetTimeAsMsSinceEpoch(newToken.Expiry) + } else { + t.ExpiresAt = 0 + } + idToken, err := verifier.Verify(ctx, rawIDToken) + if err != nil { + logger.Debug(logSender, "", "unable to verify refreshed id token for cookie %q: %v", t.Cookie, err) + return err + } + if idToken.Nonce != "" && idToken.Nonce != t.Nonce { + logger.Warn(logSender, "", "unable to verify refreshed id token for cookie %q: nonce mismatch, expected: %q, actual: %q", + t.Cookie, t.Nonce, idToken.Nonce) + return errors.New("the refreshed token nonce mismatch") + } + claims := make(map[string]any) + err = idToken.Claims(&claims) + if err != nil { + logger.Warn(logSender, "", "unable to get refreshed id token claims for cookie %q: %v", t.Cookie, err) + return err + } + sid, ok := claims["sid"].(string) + if ok { + t.SessionID = sid + } + 
err = t.refreshUser(r) + if err != nil { + logger.Debug(logSender, "", "unable to refresh user after token refresh for cookie %q: %v", t.Cookie, err) + return err + } + logger.Debug(logSender, "", "oidc token refreshed for user %q, cookie %q", t.Username, t.Cookie) + oidcMgr.addToken(*t) + + return nil +} + +func (t *oidcToken) refreshUser(r *http.Request) error { + if t.isAdmin() { + admin, err := dataprovider.AdminExists(t.Username) + if err != nil { + return err + } + if err := admin.CanLogin(util.GetIPFromRemoteAddress(r.RemoteAddr)); err != nil { + return err + } + t.Permissions = admin.Permissions + t.TokenRole = admin.Role + t.HideUserPageSections = admin.Filters.Preferences.HideUserPageSections + return nil + } + user, err := dataprovider.GetUserWithGroupSettings(t.Username, "") + if err != nil { + return err + } + if err := user.CheckLoginConditions(); err != nil { + return err + } + if err := checkHTTPClientUser(&user, r, xid.New().String(), true, false); err != nil { + return err + } + t.Permissions = user.Filters.WebClient + t.TokenRole = user.Role + t.MustSetTwoFactorAuth = user.MustSetSecondFactor() + t.MustChangePassword = user.MustChangePassword() + t.RequiredTwoFactorProtocols = user.Filters.TwoFactorAuthProtocols + return nil +} + +func (t *oidcToken) getUser(r *http.Request) error { + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + params := common.EventParams{ + Name: t.Username, + IP: ipAddr, + Protocol: common.ProtocolOIDC, + Timestamp: time.Now(), + Status: 1, + } + if t.isAdmin() { + params.Event = common.IDPLoginAdmin + _, admin, err := common.HandleIDPLoginEvent(params, t.CustomFields) + if err != nil { + return err + } + if admin == nil { + a, err := dataprovider.AdminExists(t.Username) + if err != nil { + return err + } + admin = &a + } + if err := admin.CanLogin(ipAddr); err != nil { + return err + } + t.Permissions = admin.Permissions + t.TokenRole = admin.Role + t.HideUserPageSections = 
admin.Filters.Preferences.HideUserPageSections + dataprovider.UpdateAdminLastLogin(admin) + common.DelayLogin(nil) + return nil + } + params.Event = common.IDPLoginUser + user, _, err := common.HandleIDPLoginEvent(params, t.CustomFields) + if err != nil { + return err + } + if user == nil { + u, err := dataprovider.GetUserAfterIDPAuth(t.Username, ipAddr, common.ProtocolOIDC, t.CustomFields) + if err != nil { + return err + } + user = &u + } + if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolOIDC); err != nil { + updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r) + return fmt.Errorf("access denied: %w", err) + } + if err := user.CheckLoginConditions(); err != nil { + updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r) + return err + } + connectionID := fmt.Sprintf("%s_%s", common.ProtocolOIDC, xid.New().String()) + if err := checkHTTPClientUser(user, r, connectionID, true, true); err != nil { + updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r) + return err + } + defer user.CloseFs() //nolint:errcheck + err = user.CheckFsRoot(connectionID) + if err != nil { + logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) + updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, common.ErrInternalFailure, r) + return err + } + updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, nil, r) + dataprovider.UpdateLastLogin(user) + t.Permissions = user.Filters.WebClient + t.TokenRole = user.Role + t.MustSetTwoFactorAuth = user.MustSetSecondFactor() + t.MustChangePassword = user.MustChangePassword() + t.RequiredTwoFactorProtocols = user.Filters.TwoFactorAuthProtocols + return nil +} + +func (s *httpdServer) validateOIDCToken(w http.ResponseWriter, r *http.Request, isAdmin bool) (oidcToken, error) { + doRedirect := func() { + removeOIDCCookie(w, r) + if isAdmin { + http.Redirect(w, r, webAdminLoginPath, http.StatusFound) + return + } + http.Redirect(w, r, webClientLoginPath, 
http.StatusFound) + } + + cookie, err := r.Cookie(oidcCookieKey) + if err != nil { + logger.Debug(logSender, "", "no oidc cookie, redirecting to login page") + doRedirect() + return oidcToken{}, errInvalidToken + } + token, err := oidcMgr.getToken(cookie.Value) + if err != nil { + logger.Debug(logSender, "", "error getting oidc token associated with cookie %q: %v", cookie.Value, err) + doRedirect() + return oidcToken{}, errInvalidToken + } + if token.isExpired() { + logger.Debug(logSender, "", "oidc token associated with cookie %q is expired", token.Cookie) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + if err = token.refresh(ctx, s.binding.OIDC.oauth2Config, s.binding.OIDC.getVerifier(ctx), r); err != nil { + setFlashMessage(w, r, newFlashMessage("Your OpenID token is expired, please log-in again", util.I18nOIDCTokenExpired)) + doRedirect() + return oidcToken{}, errInvalidToken + } + } else { + oidcMgr.updateTokenUsage(token) + } + if isAdmin { + if !token.isAdmin() { + logger.Debug(logSender, "", "oidc token associated with cookie %q is not valid for admin users", token.Cookie) + setFlashMessage(w, r, newFlashMessage( + "Your OpenID token is not valid for the SFTPGo Web Admin UI. Please logout from your OpenID server and log-in as an SFTPGo admin", + util.I18nOIDCTokenInvalidAdmin, + )) + doRedirect() + return oidcToken{}, errInvalidToken + } + return token, nil + } + if token.isAdmin() { + logger.Debug(logSender, "", "oidc token associated with cookie %q is valid for admin users", token.Cookie) + setFlashMessage(w, r, newFlashMessage( + "Your OpenID token is not valid for the SFTPGo Web Client UI. 
Please logout from your OpenID server and log-in as an SFTPGo user", + util.I18nOIDCTokenInvalidUser, + )) + doRedirect() + return oidcToken{}, errInvalidToken + } + return token, nil +} + +func (s *httpdServer) oidcTokenAuthenticator(audience tokenAudience) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if canSkipOIDCValidation(r) { + next.ServeHTTP(w, r) + return + } + token, err := s.validateOIDCToken(w, r, audience == tokenAudienceWebAdmin) + if err != nil { + return + } + claims := jwt.Claims{ + Username: dataprovider.ConvertName(token.Username), + Permissions: token.Permissions, + Role: token.TokenRole, + HideUserPageSections: token.HideUserPageSections, + } + claims.ID = token.Cookie + if audience == tokenAudienceWebClient { + claims.MustSetTwoFactorAuth = token.MustSetTwoFactorAuth + claims.MustChangePassword = token.MustChangePassword + claims.RequiredTwoFactorProtocols = token.RequiredTwoFactorProtocols + } + tokenString, err := s.tokenAuth.SignWithParams(&claims, audience, util.GetIPFromRemoteAddress(r.RemoteAddr), + getTokenDuration(audience)) + if err != nil { + setFlashMessage(w, r, newFlashMessage("Unable to create cookie", util.I18nError500Message)) + if audience == tokenAudienceWebAdmin { + http.Redirect(w, r, webAdminLoginPath, http.StatusFound) + } else { + http.Redirect(w, r, webClientLoginPath, http.StatusFound) + } + return + } + ctx := context.WithValue(r.Context(), oidcTokenKey, token.Cookie) + ctx = context.WithValue(ctx, oidcGeneratedToken, tokenString) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func (s *httpdServer) handleWebAdminOIDCLogin(w http.ResponseWriter, r *http.Request) { + s.oidcLoginRedirect(w, r, tokenAudienceWebAdmin) +} + +func (s *httpdServer) handleWebClientOIDCLogin(w http.ResponseWriter, r *http.Request) { + s.oidcLoginRedirect(w, r, tokenAudienceWebClient) +} + +func (s *httpdServer) 
oidcLoginRedirect(w http.ResponseWriter, r *http.Request, audience tokenAudience) {
	// register a pending auth request (state, nonce, PKCE verifier) and send the
	// browser to the OP's authorization endpoint with nonce + S256 PKCE challenge
	pendingAuth := newOIDCPendingAuth(audience)
	oidcMgr.addPendingAuth(pendingAuth)
	http.Redirect(w, r, s.binding.OIDC.oauth2Config.AuthCodeURL(pendingAuth.State,
		oidc.Nonce(pendingAuth.Nonce), oauth2.S256ChallengeOption(pendingAuth.Verifier)), http.StatusFound)
}

// debugTokenClaims logs the raw ID token (and its parsed claims, when available)
// if OIDC debug logging is enabled in the binding configuration.
func (s *httpdServer) debugTokenClaims(claims map[string]any, rawIDToken string) {
	if s.binding.OIDC.Debug {
		if claims == nil {
			logger.Debug(logSender, "", "raw id token %q", rawIDToken)
		} else {
			logger.Debug(logSender, "", "raw id token %q, parsed claims %+v", rawIDToken, claims)
		}
	}
}

// handleOIDCRedirect is the OP callback endpoint for the authorization code flow:
// it matches the state against a pending auth request, exchanges the code (with
// the stored PKCE verifier), verifies the ID token and its nonce, maps the claims
// to an SFTPGo admin/user according to the requested audience and finally logs
// the identity in. On any failure a flash message is set, the browser is sent
// back to the appropriate login page and, once an ID token exists, an OP logout
// is triggered so a broken session is not left behind at the provider.
func (s *httpdServer) handleOIDCRedirect(w http.ResponseWriter, r *http.Request) {
	state := r.URL.Query().Get("state")
	authReq, err := oidcMgr.getPendingAuth(state)
	if err != nil {
		logger.Debug(logSender, "", "oidc authentication state did not match")
		oidcMgr.removePendingAuth(state)
		s.renderClientMessagePage(w, r, util.I18nInvalidAuthReqTitle, http.StatusBadRequest,
			util.NewI18nError(err, util.I18nInvalidAuth), "")
		return
	}
	// the pending auth is single use regardless of the outcome
	oidcMgr.removePendingAuth(state)

	doRedirect := func() {
		if authReq.Audience == tokenAudienceWebAdmin {
			http.Redirect(w, r, webAdminLoginPath, http.StatusFound)
			return
		}
		http.Redirect(w, r, webClientLoginPath, http.StatusFound)
	}
	doLogout := func(rawIDToken string) {
		s.logoutFromOIDCOP(rawIDToken)
	}

	ctx, cancel := context.WithTimeout(r.Context(), 20*time.Second)
	defer cancel()

	// exchange the authorization code, proving possession of the PKCE verifier
	oauth2Token, err := s.binding.OIDC.oauth2Config.Exchange(ctx, r.URL.Query().Get("code"),
		oauth2.VerifierOption(authReq.Verifier))
	if err != nil {
		logger.Debug(logSender, "", "failed to exchange oidc token: %v", err)
		setFlashMessage(w, r, newFlashMessage("Failed to exchange OpenID token", util.I18nOIDCErrTokenExchange))
		doRedirect()
		return
	}
	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
	if !ok {
		logger.Debug(logSender, "", "no id_token field in OAuth2 OpenID token")
		setFlashMessage(w, r, newFlashMessage("No id_token field in OAuth2 OpenID token", util.I18nOIDCTokenInvalid))
		doRedirect()
		return
	}
	s.debugTokenClaims(nil, rawIDToken)
	idToken, err := s.binding.OIDC.getVerifier(ctx).Verify(ctx, rawIDToken)
	if err != nil {
		logger.Debug(logSender, "", "failed to verify oidc token: %v", err)
		setFlashMessage(w, r, newFlashMessage("Failed to verify OpenID token", util.I18nOIDCTokenInvalid))
		doRedirect()
		doLogout(rawIDToken)
		return
	}
	// the nonce must match the one we sent with the authorization request
	if idToken.Nonce != authReq.Nonce {
		logger.Debug(logSender, "", "oidc authentication nonce did not match")
		setFlashMessage(w, r, newFlashMessage("OpenID authentication nonce did not match", util.I18nOIDCTokenInvalid))
		doRedirect()
		doLogout(rawIDToken)
		return
	}

	claims := make(map[string]any)
	err = idToken.Claims(&claims)
	if err != nil {
		logger.Debug(logSender, "", "unable to get oidc token claims: %v", err)
		setFlashMessage(w, r, newFlashMessage("Unable to get OpenID token claims", util.I18nOIDCTokenInvalid))
		doRedirect()
		doLogout(rawIDToken)
		return
	}
	s.debugTokenClaims(claims, rawIDToken)
	token := oidcToken{
		AccessToken:  oauth2Token.AccessToken,
		TokenType:    oauth2Token.TokenType,
		RefreshToken: oauth2Token.RefreshToken,
		IDToken:      rawIDToken,
		Nonce:        idToken.Nonce,
		Cookie:       util.GenerateOpaqueString(),
	}
	if !oauth2Token.Expiry.IsZero() {
		token.ExpiresAt = util.GetTimeAsMsSinceEpoch(oauth2Token.Expiry)
	}
	err = token.parseClaims(claims, s.binding.OIDC.UsernameField, s.binding.OIDC.RoleField,
		s.binding.OIDC.CustomFields, s.binding.OIDC.getForcedRole(authReq.Audience))
	if err != nil {
		logger.Debug(logSender, "", "unable to parse oidc token claims: %v", err)
		setFlashMessage(w, r, newFlashMessage(fmt.Sprintf("Unable to parse OpenID token claims: %v", err), util.I18nOIDCTokenInvalid))
		doRedirect()
		doLogout(rawIDToken)
		return
	}
	// the mapped role must match the UI that initiated the flow:
	// admin tokens only for the admin UI, non-admin tokens only for the client UI
	switch authReq.Audience {
	case tokenAudienceWebAdmin:
		if !token.isAdmin() {
			logger.Debug(logSender, "", "wrong oidc token role, the mapped user is not an SFTPGo admin")
			setFlashMessage(w, r, newFlashMessage(
				"Wrong OpenID role, the logged in user is not an SFTPGo admin",
				util.I18nOIDCTokenInvalidRoleAdmin))
			doRedirect()
			doLogout(rawIDToken)
			return
		}
	case tokenAudienceWebClient:
		if token.isAdmin() {
			logger.Debug(logSender, "", "wrong oidc token role, the mapped user is an SFTPGo admin")
			setFlashMessage(w, r, newFlashMessage(
				"Wrong OpenID role, the logged in user is an SFTPGo admin",
				util.I18nOIDCTokenInvalidRoleUser,
			))
			doRedirect()
			doLogout(rawIDToken)
			return
		}
	}
	err = token.getUser(r)
	if err != nil {
		logger.Debug(logSender, "", "unable to get the sftpgo user associated with oidc token: %v", err)
		setFlashMessage(w, r, newFlashMessage("Unable to get the user associated with the OpenID token", util.I18nOIDCErrGetUser))
		doRedirect()
		doLogout(rawIDToken)
		return
	}

	loginOIDCUser(w, r, token)
}

// loginOIDCUser stores the OIDC token in the manager, sets the OIDC session
// cookie and redirects to the landing page for the mapped identity
// (users page for admins, files page for end users).
func loginOIDCUser(w http.ResponseWriter, r *http.Request, token oidcToken) {
	oidcMgr.addToken(token)

	cookie := http.Cookie{
		Name:     oidcCookieKey,
		Value:    token.Cookie,
		Path:     "/",
		HttpOnly: true,
		Secure:   isTLS(r),
		SameSite: http.SameSiteLaxMode,
	}
	// we don't set a cookie expiration so we can refresh the token without setting a new cookie
	// the cookie will be invalidated on browser close
	http.SetCookie(w, &cookie)
	w.Header().Add("Cache-Control", `no-cache="Set-Cookie"`)
	if token.isAdmin() {
		http.Redirect(w, r, webUsersPath, http.StatusFound)
		return
	}
	http.Redirect(w, r, webClientFilesPath, http.StatusFound)
}

// logoutOIDCUser clears the OIDC session, if any: it removes the cookie,
// triggers a logout at the OP using the stored ID token and drops the
// token from the manager.
func (s *httpdServer) logoutOIDCUser(w http.ResponseWriter, r *http.Request) {
	if oidcKey, ok := r.Context().Value(oidcTokenKey).(string); ok {
		removeOIDCCookie(w, r)
		token, err := oidcMgr.getToken(oidcKey)
		if err == nil {
			s.logoutFromOIDCOP(token.IDToken)
		}
		oidcMgr.removeToken(oidcKey)
	}
}

func (s
*httpdServer) logoutFromOIDCOP(idToken string) {
	// best-effort, asynchronous logout at the OpenID Provider; a no-op when
	// no provider logout URL is configured
	if s.binding.OIDC.providerLogoutURL == "" {
		logger.Debug(logSender, "", "oidc: provider logout URL not set, unable to logout from the OP")
		return
	}
	go s.doOIDCFromLogout(idToken)
}

// doOIDCFromLogout calls the configured provider logout URL, adding the
// ID token as id_token_hint when available. Errors are only logged: this
// runs in a goroutine with no caller to report to.
func (s *httpdServer) doOIDCFromLogout(idToken string) {
	logoutURL, err := url.Parse(s.binding.OIDC.providerLogoutURL)
	if err != nil {
		logger.Warn(logSender, "", "oidc: unable to parse logout URL: %v", err)
		return
	}
	query := logoutURL.Query()
	if idToken != "" {
		query.Set("id_token_hint", idToken)
	}
	logoutURL.RawQuery = query.Encode()
	resp, err := httpclient.RetryableGet(logoutURL.String())
	if err != nil {
		logger.Warn(logSender, "", "oidc: error calling logout URL %q: %v", logoutURL.String(), err)
		return
	}
	defer resp.Body.Close()
	logger.Debug(logSender, "", "oidc: logout url response code %v", resp.StatusCode)
}

// removeOIDCCookie expires the OIDC session cookie on the client
// (empty value, MaxAge -1, Expires in the past).
func removeOIDCCookie(w http.ResponseWriter, r *http.Request) {
	http.SetCookie(w, &http.Cookie{
		Name:     oidcCookieKey,
		Value:    "",
		Path:     "/",
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   isTLS(r),
		SameSite: http.SameSiteLaxMode,
	})
}

// canSkipOIDCValidation returns true if there is no OIDC cookie but a jwt cookie is set
// and so we check if the user is logged in using a built-in user
func canSkipOIDCValidation(r *http.Request) bool {
	_, err := r.Cookie(oidcCookieKey)
	if err != nil {
		_, err = r.Cookie(jwt.CookieKey)
		return err == nil
	}
	return false
}

// isLoggedInWithOIDC reports whether the request context carries an OIDC
// cookie key, i.e. it passed through the OIDC authenticator middleware.
func isLoggedInWithOIDC(r *http.Request) bool {
	_, ok := r.Context().Value(oidcTokenKey).(string)
	return ok
}

// getOIDCFieldFromClaims looks up fieldName in the claims map. An exact
// top-level key always wins, even if it contains dots; otherwise a dotted
// fieldName is walked as a path through nested map[string]any values.
// The boolean result reports whether the field was found.
func getOIDCFieldFromClaims(claims map[string]any, fieldName string) (any, bool) {
	if fieldName == "" {
		return nil, false
	}
	val, ok := claims[fieldName]
	if ok {
		return val, true
	}
	if !strings.Contains(fieldName, ".") {
		return nil, false
	}

	// descend one level; only JSON-object values (map[string]any) can be traversed
	getStructValue := func(outer any, field string) (any, bool) {
		switch v := outer.(type) {
		case map[string]any:
			res, ok := v[field]
			return res, ok
		}
		return nil, false
	}

	for idx, field := range strings.Split(fieldName, ".") {
		if idx == 0 {
			val, ok = getStructValue(claims, field)
		} else {
			val, ok = getStructValue(val, field)
		}
		if !ok {
			return nil, false
		}
	}

	return val, ok
}
diff --git a/internal/httpd/oidc_test.go b/internal/httpd/oidc_test.go
new file mode 100644
index 00000000..4a14c76c
--- /dev/null
+++ b/internal/httpd/oidc_test.go
@@ -0,0 +1,1784 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.

package httpd

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io/fs"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"reflect"
	"runtime"
	"testing"
	"time"
	"unsafe"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/rs/xid"
	"github.com/sftpgo/sdk"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/oauth2"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

const (
	// address used for the mock OIDC provider in these tests
	oidcMockAddr = "127.0.0.1:11111"
)

// mockTokenSource is a canned oauth2.TokenSource: it returns the configured
// token/error pair on every call.
type mockTokenSource struct {
	token *oauth2.Token
	err   error
}

func (t *mockTokenSource) Token() (*oauth2.Token, error) {
	return t.token, t.err
}

// mockOAuth2Config stands in for the real OAuth2 config: AuthCodeURL,
// Exchange and TokenSource all return the canned values below so tests can
// drive every success/failure branch without a live provider.
type mockOAuth2Config struct {
	tokenSource *mockTokenSource
	authCodeURL string
	token       *oauth2.Token
	err         error
}

func (c *mockOAuth2Config) AuthCodeURL(_ string, _ ...oauth2.AuthCodeOption) string {
	return c.authCodeURL
}

func (c *mockOAuth2Config) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
	return c.token, c.err
}

func (c *mockOAuth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource {
	return c.tokenSource
}

// mockOIDCVerifier is a canned ID token verifier: Verify returns the
// configured token/error pair regardless of the raw token passed in.
type mockOIDCVerifier struct {
	token *oidc.IDToken
	err   error
}

func (v *mockOIDCVerifier) Verify(_ context.Context, _ string) (*oidc.IDToken, error) {
	return v.token, v.err
}

// hack because the field is unexported
func setIDTokenClaims(idToken *oidc.IDToken, claims []byte) {
	pointerVal := reflect.ValueOf(idToken)
	val := reflect.Indirect(pointerVal)
	member := val.FieldByName("claims")
	ptr := unsafe.Pointer(member.UnsafeAddr())
	realPtr := (*[]byte)(ptr)
	*realPtr = claims
}

// TestOIDCInitialization covers OIDC config validation: missing openid scope,
// missing/valid client secret file and redirect URL construction.
func TestOIDCInitialization(t *testing.T) {
	config := OIDC{}
	err
:= config.initialize()
	assert.NoError(t, err)
	secret := "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c"
	config = OIDC{
		ClientID:        "sftpgo-client",
		ClientSecret:    util.GenerateUniqueID(),
		ConfigURL:       fmt.Sprintf("http://%v/", oidcMockAddr),
		RedirectBaseURL: "http://127.0.0.1:8081/",
		UsernameField:   "preferred_username",
		RoleField:       "sftpgo_role",
	}
	err = config.initialize()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "oidc: required scope \"openid\" is not set")
	}
	config.Scopes = []string{oidc.ScopeOpenID}
	config.ClientSecretFile = "missing file"
	err = config.initialize()
	assert.ErrorIs(t, err, fs.ErrNotExist)
	secretFile := filepath.Join(os.TempDir(), util.GenerateUniqueID())
	defer os.Remove(secretFile)
	err = os.WriteFile(secretFile, []byte(secret), 0600)
	assert.NoError(t, err)
	config.ClientSecretFile = secretFile
	err = config.initialize()
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "oidc: unable to initialize provider")
	}
	// the secret must have been loaded from the file
	assert.Equal(t, secret, config.ClientSecret)
	config.ConfigURL = fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr)
	err = config.initialize()
	assert.NoError(t, err)
	assert.Equal(t, "http://127.0.0.1:8081"+webOIDCRedirectPath, config.getRedirectURL())
}

// TestOIDCLoginLogout drives the full authorization-code flow against mocked
// OAuth2/OIDC components: bad/expired states, exchange and verification
// failures, claim mapping errors, audience/role mismatches, and finally a
// successful admin login/logout followed by a user login/logout.
func TestOIDCLoginLogout(t *testing.T) {
	tokenValidationMode = 2

	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	server := getTestOIDCServer()
	err := server.binding.OIDC.initialize()
	assert.NoError(t, err)
	err = server.initializeRouter()
	require.NoError(t, err)

	// callback without a pending auth request
	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusBadRequest, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nInvalidAuth)

	// callback for an expired pending auth request
	expiredAuthReq := oidcPendingAuth{
		State:    util.GenerateOpaqueString(),
		Nonce:    util.GenerateOpaqueString(),
		Audience: tokenAudienceWebClient,
		IssuedAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-10 * time.Minute)),
	}
	oidcMgr.addPendingAuth(expiredAuthReq)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+expiredAuthReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusBadRequest, rr.Code)
	assert.Contains(t, rr.Body.String(), util.I18nInvalidAuth)
	oidcMgr.removePendingAuth(expiredAuthReq.State)

	// make the code exchange and the token verification fail
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{},
		authCodeURL: webOIDCRedirectPath,
		err:         common.ErrGenericFailure,
	}
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err: common.ErrGenericFailure,
	}

	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webAdminOIDCLoginPath, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 1)
	var state string
	for k := range oidcMgr.pendingAuths {
		state = k
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webAdminLoginPath, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusOK, rr.Code)
	// now the same for the web client
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientOIDCLoginPath, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webOIDCRedirectPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 1)
	for k := range oidcMgr.pendingAuths {
		state = k
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+state, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientLoginPath, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusOK, rr.Code)
	// now return an OAuth2 token without the id_token
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{},
		authCodeURL: webOIDCRedirectPath,
		token: &oauth2.Token{
			AccessToken: "123",
			Expiry:      time.Now().Add(5 * time.Minute),
		},
		err: nil,
	}
	authReq := newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	// now fail to verify the id token
	token := &oauth2.Token{
		AccessToken: "123",
		Expiry:      time.Now().Add(5 * time.Minute),
	}
	token = token.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	server.binding.OIDC.oauth2Config = &mockOAuth2Config{
		tokenSource: &mockTokenSource{},
		authCodeURL: webOIDCRedirectPath,
		token:       token,
		err:         nil,
	}
	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	// id token nonce does not match
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: &oidc.IDToken{},
	}
	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	// null id token claims
	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err: nil,
		token: &oidc.IDToken{
			Nonce: authReq.Nonce,
		},
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	// invalid id token claims: no username
	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	idToken := &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	// invalid id token claims: username not a string
	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"aud": "my_client_id","preferred_username": 1}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	// invalid audience
	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	// invalid audience
	authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"test"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	// mapped user not found
	authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"test","sftpgo_role":"admin"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	// admin login ok
	authReq = newOIDCPendingAuth(tokenAudienceWebAdmin)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sftpgo_role":"admin","sid":"sid123"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webUsersPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 1)
	// admin profile is not available
	var tokenCookie string
	for k := range oidcMgr.tokens {
		tokenCookie = k
	}
	oidcToken, err := oidcMgr.getToken(tokenCookie)
	assert.NoError(t, err)
	assert.Equal(t, "sid123", oidcToken.SessionID)
	assert.True(t, oidcToken.isAdmin())
	assert.False(t, oidcToken.isExpired())
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webAdminProfilePath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusForbidden, rr.Code)
	// the admin can access the allowed pages
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusOK, rr.Code)
	// try with an invalid cookie
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	// Web Client is not available with an admin token
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	// logout the admin user
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 0)
	// now login and logout a user
	username := "test_oidc_user"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Password: "pwd",
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		Filters: dataprovider.UserFilters{
			BaseUserFilters: sdk.BaseUserFilters{
				WebClient: []string{sdk.WebClientSharesDisabled},
			},
		},
	}
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)

	authReq = newOIDCPendingAuth(tokenAudienceWebClient)
	oidcMgr.addPendingAuth(authReq)
	idToken = &oidc.IDToken{
		Nonce:  authReq.Nonce,
		Expiry: time.Now().Add(5 * time.Minute),
	}
	setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_user"}`))
	server.binding.OIDC.verifier = &mockOIDCVerifier{
		err:   nil,
		token: idToken,
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil)
	assert.NoError(t, err)
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientFilesPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 1)
	// user profile page is available for OIDC users (unlike the admin profile above)
	for k := range oidcMgr.tokens {
		tokenCookie = k
	}
	oidcToken, err = oidcMgr.getToken(tokenCookie)
	assert.NoError(t, err)
	assert.Empty(t, oidcToken.SessionID)
	assert.False(t, oidcToken.isAdmin())
	assert.False(t, oidcToken.isExpired())
	if assert.Len(t, oidcToken.Permissions, 1) {
		assert.Equal(t, sdk.WebClientSharesDisabled, oidcToken.Permissions[0])
	}
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientProfilePath, nil)
	assert.NoError(t, err)
	r.RequestURI = webClientProfilePath
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusOK, rr.Code)
	// the user can access the allowed pages
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t,
http.StatusOK, rr.Code)
	// try with an invalid cookie
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, xid.New().String()))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	// Web Admin is not available with a client cookie
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location"))
	// logout the user
	rr = httptest.NewRecorder()
	r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie))
	server.router.ServeHTTP(rr, r)
	assert.Equal(t, http.StatusFound, rr.Code)
	assert.Equal(t, webClientLoginPath, rr.Header().Get("Location"))
	require.Len(t, oidcMgr.pendingAuths, 0)
	require.Len(t, oidcMgr.tokens, 0)

	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)

	tokenValidationMode = 0
}

// TestOIDCRefreshToken exercises oidcToken.refresh with mocked token sources
// and verifiers: missing refresh token, token source failure, refreshed token
// without an id_token, verification failure, nonce mismatch, missing claims
// and finally a successful refresh followed by a failure for a missing user.
func TestOIDCRefreshToken(t *testing.T) {
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	r, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	token := oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: xid.New().String(),
		TokenType:   "Bearer",
		ExpiresAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(-1 * time.Minute)),
		Nonce:       xid.New().String(),
		Role:        adminRoleFieldValue,
		Username:    defaultAdminUsername,
	}
	config := mockOAuth2Config{
		tokenSource: &mockTokenSource{
			err: common.ErrGenericFailure,
		},
	}
	verifier := mockOIDCVerifier{
		err: common.ErrGenericFailure,
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "refresh token not set")
	}
	token.RefreshToken = xid.New().String()
	err = token.refresh(context.Background(), &config, &verifier, r)
	assert.ErrorIs(t, err, common.ErrGenericFailure)

	newToken := &oauth2.Token{
		AccessToken:  xid.New().String(),
		RefreshToken: xid.New().String(),
		Expiry:       time.Now().Add(5 * time.Minute),
	}
	config = mockOAuth2Config{
		tokenSource: &mockTokenSource{
			token: newToken,
		},
	}
	verifier = mockOIDCVerifier{
		token: &oidc.IDToken{},
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "the refreshed token has no id token")
	}
	newToken = newToken.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	newToken.Expiry = time.Time{}
	config = mockOAuth2Config{
		tokenSource: &mockTokenSource{
			token: newToken,
		},
	}
	verifier = mockOIDCVerifier{
		err: common.ErrGenericFailure,
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	assert.ErrorIs(t, err, common.ErrGenericFailure)

	newToken = newToken.WithExtra(map[string]any{
		"id_token": "id_token_val",
	})
	newToken.Expiry = time.Now().Add(5 * time.Minute)
	config = mockOAuth2Config{
		tokenSource: &mockTokenSource{
			token: newToken,
		},
	}
	verifier = mockOIDCVerifier{
		token: &oidc.IDToken{
			Nonce: xid.New().String(), // nonce is different from the expected one
		},
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "the refreshed token nonce mismatch")
	}
	verifier = mockOIDCVerifier{
		token: &oidc.IDToken{
			Nonce: "", // an empty nonce is accepted on refresh, but this token has no claims set
		},
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "oidc: claims not set")
	}
	idToken := &oidc.IDToken{
		Nonce: token.Nonce,
	}
	setIDTokenClaims(idToken, []byte(`{"sid":"id_token_sid"}`))
	verifier = mockOIDCVerifier{
		token: idToken,
	}
	err = token.refresh(context.Background(), &config, &verifier, r)
	assert.NoError(t, err)
	assert.Len(t, token.Permissions, 1)
	token.Role = nil
	// user does not exist
	err = token.refresh(context.Background(), &config, &verifier, r)
	assert.Error(t, err)
	require.Len(t, oidcMgr.tokens, 1)
	oidcMgr.removeToken(token.Cookie)
	require.Len(t, oidcMgr.tokens, 0)
}

// TestOIDCRefreshUser verifies oidcToken.refreshUser against the data provider
// for both admins and users: missing identity, disabled account, denied
// protocol, and that permissions/preferences are re-read on success.
func TestOIDCRefreshUser(t *testing.T) {
	token := oidcToken{
		Cookie:      util.GenerateOpaqueString(),
		AccessToken: xid.New().String(),
		TokenType:   "Bearer",
		ExpiresAt:   util.GetTimeAsMsSinceEpoch(time.Now().Add(1 * time.Minute)),
		Nonce:       xid.New().String(),
		Role:        adminRoleFieldValue,
		Username:    "missing username",
	}
	r, err := http.NewRequest(http.MethodGet, webUsersPath, nil)
	assert.NoError(t, err)
	err = token.refreshUser(r)
	assert.Error(t, err)
	admin := dataprovider.Admin{
		Username:    "test_oidc_admin_refresh",
		Password:    "p",
		Permissions: []string{dataprovider.PermAdminAny},
		Status:      0,
		Filters: dataprovider.AdminFilters{
			Preferences: dataprovider.AdminPreferences{
				HideUserPageSections: 1 + 2 + 4,
			},
		},
	}
	err = dataprovider.AddAdmin(&admin, "", "", "")
	assert.NoError(t, err)

	token.Username = admin.Username
	err = token.refreshUser(r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}

	admin.Status = 1
	err = dataprovider.UpdateAdmin(&admin, "", "", "")
	assert.NoError(t, err)
	err = token.refreshUser(r)
	assert.NoError(t, err)
	assert.Equal(t, admin.Permissions, token.Permissions)
	assert.Equal(t, admin.Filters.Preferences.HideUserPageSections, token.HideUserPageSections)

	err = dataprovider.DeleteAdmin(admin.Username, "", "", "")
	assert.NoError(t, err)

	username := "test_oidc_user_refresh_token"
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username: username,
			Password: "p",
			HomeDir:  filepath.Join(os.TempDir(), username),
			Status:   0,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		},
		Filters: dataprovider.UserFilters{
			BaseUserFilters: sdk.BaseUserFilters{
				DeniedProtocols: []string{common.ProtocolHTTP},
				WebClient:       []string{sdk.WebClientSharesDisabled, sdk.WebClientWriteDisabled},
			},
		},
	}
	err = dataprovider.AddUser(&user, "", "", "")
	assert.NoError(t, err)

	r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil)
	assert.NoError(t, err)
	token.Role = nil
	token.Username = username
	assert.False(t, token.isAdmin())
	err = token.refreshUser(r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "is disabled")
	}
	user, err = dataprovider.UserExists(username, "")
	assert.NoError(t, err)
	user.Status = 1
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	err = token.refreshUser(r)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "protocol HTTP is not allowed")
	}

	user.Filters.DeniedProtocols = []string{common.ProtocolFTP}
	err = dataprovider.UpdateUser(&user, "", "", "")
	assert.NoError(t, err)
	err = token.refreshUser(r)
	assert.NoError(t, err)
	assert.Equal(t, user.Filters.WebClient, token.Permissions)

	err = dataprovider.DeleteUser(username, "", "", "")
	assert.NoError(t, err)
}

// TestValidateOIDCToken covers validateOIDCToken error paths starting from a
// request without an OIDC cookie.
func TestValidateOIDCToken(t *testing.T) {
	oidcMgr, ok := oidcMgr.(*memoryOIDCManager)
	require.True(t, ok)
	server := getTestOIDCServer()
	err := server.binding.OIDC.initialize()
	assert.NoError(t, err)
	err = server.initializeRouter()
	require.NoError(t, err)

	rr := httptest.NewRecorder()
	r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil)
	assert.NoError(t, err)
	_, err = server.validateOIDCToken(rr, r, false)
	assert.ErrorIs(t, err,
errInvalidToken) + // expired token and refresh error + server.binding.OIDC.oauth2Config = &mockOAuth2Config{ + tokenSource: &mockTokenSource{ + err: common.ErrGenericFailure, + }, + } + token := oidcToken{ + Cookie: util.GenerateOpaqueString(), + AccessToken: xid.New().String(), + ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)), + } + oidcMgr.addToken(token) + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil) + assert.NoError(t, err) + r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie)) + _, err = server.validateOIDCToken(rr, r, false) + assert.ErrorIs(t, err, errInvalidToken) + oidcMgr.removeToken(token.Cookie) + assert.Len(t, oidcMgr.tokens, 0) + + server.tokenAuth.SetSigner(&failingJoseSigner{}) + token = oidcToken{ + Cookie: util.GenerateOpaqueString(), + AccessToken: util.GenerateUniqueID(), + } + oidcMgr.addToken(token) + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil) + assert.NoError(t, err) + r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie)) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + oidcMgr.removeToken(token.Cookie) + assert.Len(t, oidcMgr.tokens, 0) + + token = oidcToken{ + Cookie: util.GenerateOpaqueString(), + AccessToken: xid.New().String(), + Role: "admin", + } + oidcMgr.addToken(token) + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil) + assert.NoError(t, err) + r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, token.Cookie)) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) + oidcMgr.removeToken(token.Cookie) + assert.Len(t, oidcMgr.tokens, 0) +} + +func TestSkipOIDCAuth(t *testing.T) { + server := getTestOIDCServer() + err := 
server.binding.OIDC.initialize() + assert.NoError(t, err) + err = server.initializeRouter() + require.NoError(t, err) + + claims := jwt.NewClaims(tokenAudienceWebClient, "", getTokenDuration(tokenAudienceWebClient)) + claims.Username = "user" + tokenString, err := server.tokenAuth.Sign(claims) + assert.NoError(t, err) + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webClientLogoutPath, nil) + assert.NoError(t, err) + r.Header.Set("Cookie", fmt.Sprintf("%v=%v", jwt.CookieKey, tokenString)) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) +} + +func TestOIDCLogoutErrors(t *testing.T) { + server := getTestOIDCServer() + assert.Empty(t, server.binding.OIDC.providerLogoutURL) + server.logoutFromOIDCOP("") + server.binding.OIDC.providerLogoutURL = "http://foo\x7f.com/" + server.doOIDCFromLogout("") + server.binding.OIDC.providerLogoutURL = "http://127.0.0.1:11234" + server.doOIDCFromLogout("") +} + +func TestOIDCToken(t *testing.T) { + admin := dataprovider.Admin{ + Username: "test_oidc_admin", + Password: "p", + Permissions: []string{dataprovider.PermAdminAny}, + Status: 0, + } + err := dataprovider.AddAdmin(&admin, "", "", "") + assert.NoError(t, err) + + token := oidcToken{ + Username: admin.Username, + } + // role not initialized, user with the specified username does not exist + req, err := http.NewRequest(http.MethodGet, webUsersPath, nil) + assert.NoError(t, err) + err = token.getUser(req) + assert.ErrorIs(t, err, util.ErrNotFound) + token.Role = "admin" + req, err = http.NewRequest(http.MethodGet, webUsersPath, nil) + assert.NoError(t, err) + err = token.getUser(req) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is disabled") + } + err = dataprovider.DeleteAdmin(admin.Username, "", "", "") + assert.NoError(t, err) + + username := "test_oidc_user" + token.Username = username + token.Role = "" + err = token.getUser(req) + if 
assert.Error(t, err) { + assert.ErrorIs(t, err, util.ErrNotFound) + } + + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: "p", + HomeDir: filepath.Join(os.TempDir(), username), + Status: 0, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + Filters: dataprovider.UserFilters{ + BaseUserFilters: sdk.BaseUserFilters{ + DeniedProtocols: []string{common.ProtocolHTTP}, + DeniedLoginMethods: []string{dataprovider.LoginMethodPassword}, + }, + }, + } + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + err = token.getUser(req) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is disabled") + } + user, err = dataprovider.UserExists(username, "") + assert.NoError(t, err) + user.Status = 1 + user.Password = "np" + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + + err = token.getUser(req) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "protocol HTTP is not allowed") + } + + user.Filters.DeniedProtocols = nil + user.FsConfig.Provider = sdk.SFTPFilesystemProvider + user.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: "127.0.0.1:8022", + Username: username, + }, + Password: kms.NewPlainSecret("np"), + } + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + err = token.getUser(req) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SFTP loop") + } + + common.Config.PostConnectHook = fmt.Sprintf("http://%v/404", oidcMockAddr) + + err = token.getUser(req) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "access denied") + } + + common.Config.PostConnectHook = "" + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) +} + +func TestOIDCImplicitRoles(t *testing.T) { + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) + + server := 
getTestOIDCServer() + server.binding.OIDC.ImplicitRoles = true + err := server.binding.OIDC.initialize() + assert.NoError(t, err) + err = server.initializeRouter() + require.NoError(t, err) + + authReq := newOIDCPendingAuth(tokenAudienceWebAdmin) + oidcMgr.addPendingAuth(authReq) + token := &oauth2.Token{ + AccessToken: "1234", + Expiry: time.Now().Add(5 * time.Minute), + } + token = token.WithExtra(map[string]any{ + "id_token": "id_token_val", + }) + server.binding.OIDC.oauth2Config = &mockOAuth2Config{ + tokenSource: &mockTokenSource{}, + authCodeURL: webOIDCRedirectPath, + token: token, + } + idToken := &oidc.IDToken{ + Nonce: authReq.Nonce, + Expiry: time.Now().Add(5 * time.Minute), + } + setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sid":"sid456"}`)) + server.binding.OIDC.verifier = &mockOIDCVerifier{ + err: nil, + token: idToken, + } + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webUsersPath, rr.Header().Get("Location")) + require.Len(t, oidcMgr.pendingAuths, 0) + require.Len(t, oidcMgr.tokens, 1) + var tokenCookie string + for k := range oidcMgr.tokens { + tokenCookie = k + } + // Web Client is not available with an admin token + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webClientFilesPath, nil) + assert.NoError(t, err) + r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + // logout the admin user + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webLogoutPath, nil) + assert.NoError(t, err) + r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) + server.router.ServeHTTP(rr, r) + assert.Equal(t, 
http.StatusFound, rr.Code) + assert.Equal(t, webAdminLoginPath, rr.Header().Get("Location")) + require.Len(t, oidcMgr.pendingAuths, 0) + require.Len(t, oidcMgr.tokens, 0) + // now login and logout a user + username := "test_oidc_implicit_user" + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: "pwd", + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + Filters: dataprovider.UserFilters{ + BaseUserFilters: sdk.BaseUserFilters{ + WebClient: []string{sdk.WebClientSharesDisabled}, + }, + }, + } + err = dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + + authReq = newOIDCPendingAuth(tokenAudienceWebClient) + oidcMgr.addPendingAuth(authReq) + idToken = &oidc.IDToken{ + Nonce: authReq.Nonce, + Expiry: time.Now().Add(5 * time.Minute), + } + setIDTokenClaims(idToken, []byte(`{"preferred_username":"test_oidc_implicit_user"}`)) + server.binding.OIDC.verifier = &mockOIDCVerifier{ + err: nil, + token: idToken, + } + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientFilesPath, rr.Header().Get("Location")) + require.Len(t, oidcMgr.pendingAuths, 0) + require.Len(t, oidcMgr.tokens, 1) + for k := range oidcMgr.tokens { + tokenCookie = k + } + + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webClientLogoutPath, nil) + assert.NoError(t, err) + r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + require.Len(t, oidcMgr.pendingAuths, 0) + require.Len(t, oidcMgr.tokens, 0) + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = 
dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) +} + +func TestMemoryOIDCManager(t *testing.T) { + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) + require.Len(t, oidcMgr.pendingAuths, 0) + authReq := newOIDCPendingAuth(tokenAudienceWebAdmin) + oidcMgr.addPendingAuth(authReq) + require.Len(t, oidcMgr.pendingAuths, 1) + _, err := oidcMgr.getPendingAuth(authReq.State) + assert.NoError(t, err) + oidcMgr.removePendingAuth(authReq.State) + require.Len(t, oidcMgr.pendingAuths, 0) + authReq.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-600 * time.Second)) + oidcMgr.addPendingAuth(authReq) + require.Len(t, oidcMgr.pendingAuths, 1) + _, err = oidcMgr.getPendingAuth(authReq.State) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "too old") + } + oidcMgr.cleanup() + require.Len(t, oidcMgr.pendingAuths, 0) + + token := oidcToken{ + AccessToken: xid.New().String(), + Nonce: xid.New().String(), + SessionID: xid.New().String(), + Cookie: util.GenerateOpaqueString(), + Username: xid.New().String(), + Role: "admin", + Permissions: []string{dataprovider.PermAdminAny}, + } + require.Len(t, oidcMgr.tokens, 0) + oidcMgr.addToken(token) + require.Len(t, oidcMgr.tokens, 1) + _, err = oidcMgr.getToken(xid.New().String()) + assert.Error(t, err) + storedToken, err := oidcMgr.getToken(token.Cookie) + assert.NoError(t, err) + token.UsedAt = 0 // ensure we don't modify the stored token + assert.Greater(t, storedToken.UsedAt, int64(0)) + token.UsedAt = storedToken.UsedAt + assert.Equal(t, token, storedToken) + // the usage will not be updated, it is recent + oidcMgr.updateTokenUsage(storedToken) + storedToken, err = oidcMgr.getToken(token.Cookie) + assert.NoError(t, err) + assert.Equal(t, token, storedToken) + usedAt := util.GetTimeAsMsSinceEpoch(time.Now().Add(-5 * time.Minute)) + storedToken.UsedAt = usedAt + oidcMgr.tokens[token.Cookie] = storedToken + storedToken, err = oidcMgr.getToken(token.Cookie) + assert.NoError(t, err) + 
assert.Equal(t, usedAt, storedToken.UsedAt) + token.UsedAt = storedToken.UsedAt + assert.Equal(t, token, storedToken) + oidcMgr.updateTokenUsage(storedToken) + storedToken, err = oidcMgr.getToken(token.Cookie) + assert.NoError(t, err) + assert.Greater(t, storedToken.UsedAt, usedAt) + token.UsedAt = storedToken.UsedAt + assert.Equal(t, token, storedToken) + storedToken.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) - tokenDeleteInterval - 1 + oidcMgr.tokens[token.Cookie] = storedToken + storedToken, err = oidcMgr.getToken(token.Cookie) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "token is too old") + } + oidcMgr.removeToken(xid.New().String()) + require.Len(t, oidcMgr.tokens, 1) + oidcMgr.removeToken(token.Cookie) + require.Len(t, oidcMgr.tokens, 0) + oidcMgr.addToken(token) + usedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-6 * time.Hour)) + token.UsedAt = usedAt + oidcMgr.tokens[token.Cookie] = token + newToken := oidcToken{ + Cookie: util.GenerateOpaqueString(), + } + oidcMgr.addToken(newToken) + oidcMgr.cleanup() + require.Len(t, oidcMgr.tokens, 1) + _, err = oidcMgr.getToken(token.Cookie) + assert.Error(t, err) + _, err = oidcMgr.getToken(newToken.Cookie) + assert.NoError(t, err) + oidcMgr.removeToken(newToken.Cookie) + require.Len(t, oidcMgr.tokens, 0) +} + +func TestOIDCEvMgrIntegration(t *testing.T) { + providerConf := dataprovider.GetProviderConfig() + err := dataprovider.Close() + assert.NoError(t, err) + newProviderConf := providerConf + newProviderConf.NamingRules = 5 + err = dataprovider.Initialize(newProviderConf, configDir, true) + assert.NoError(t, err) + // add a special chars to check json replacer + username := `test_'oidc_eventmanager` + u := map[string]any{ + "username": "{{.Name}}", + "status": 1, + "home_dir": filepath.Join(os.TempDir(), "{{.IDPFieldcustom1.sub}}"), + "permissions": map[string][]string{ + "/": {dataprovider.PermAny}, + }, + "description": "{{.IDPFieldcustom2}}", + } + userTmpl, err := json.Marshal(u) + 
require.NoError(t, err) + a := map[string]any{ + "username": "{{.Name}}", + "status": 1, + "permissions": []string{dataprovider.PermAdminAny}, + } + adminTmpl, err := json.Marshal(a) + require.NoError(t, err) + + action := &dataprovider.BaseEventAction{ + Name: "a", + Type: dataprovider.ActionTypeIDPAccountCheck, + Options: dataprovider.BaseEventActionOptions{ + IDPConfig: dataprovider.EventActionIDPAccountCheck{ + Mode: 0, + TemplateUser: string(userTmpl), + TemplateAdmin: string(adminTmpl), + }, + }, + } + err = dataprovider.AddEventAction(action, "", "", "") + assert.NoError(t, err) + rule := &dataprovider.EventRule{ + Name: "r", + Status: 1, + Trigger: dataprovider.EventTriggerIDPLogin, + Conditions: dataprovider.EventConditions{ + IDPLoginEvent: 0, + }, + Actions: []dataprovider.EventAction{ + { + BaseEventAction: dataprovider.BaseEventAction{ + Name: action.Name, + }, + Options: dataprovider.EventActionOptions{ + ExecuteSync: true, + }, + }, + }, + } + err = dataprovider.AddEventRule(rule, "", "", "") + assert.NoError(t, err) + + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) + server := getTestOIDCServer() + server.binding.OIDC.ImplicitRoles = true + server.binding.OIDC.CustomFields = []string{"custom1.sub", "custom2"} + err = server.binding.OIDC.initialize() + assert.NoError(t, err) + err = server.initializeRouter() + require.NoError(t, err) + // login a user with OIDC + _, err = dataprovider.UserExists(username, "") + assert.ErrorIs(t, err, util.ErrNotFound) + authReq := newOIDCPendingAuth(tokenAudienceWebClient) + oidcMgr.addPendingAuth(authReq) + token := &oauth2.Token{ + AccessToken: "1234", + Expiry: time.Now().Add(5 * time.Minute), + } + token = token.WithExtra(map[string]any{ + "id_token": "id_token_val", + }) + server.binding.OIDC.oauth2Config = &mockOAuth2Config{ + tokenSource: &mockTokenSource{}, + authCodeURL: webOIDCRedirectPath, + token: token, + } + idToken := &oidc.IDToken{ + Nonce: authReq.Nonce, + Expiry: time.Now().Add(5 
* time.Minute), + } + setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`","custom1":{"sub":"val1"},"custom2":"desc"}`)) //nolint:goconst + server.binding.OIDC.verifier = &mockOIDCVerifier{ + err: nil, + token: idToken, + } + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientFilesPath, rr.Header().Get("Location")) + user, err := dataprovider.UserExists(username, "") + assert.NoError(t, err) + assert.Equal(t, filepath.Join(os.TempDir(), "val1"), user.GetHomeDir()) + assert.Equal(t, "desc", user.Description) + + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + // login an admin with OIDC + _, err = dataprovider.AdminExists(username) + assert.ErrorIs(t, err, util.ErrNotFound) + authReq = newOIDCPendingAuth(tokenAudienceWebAdmin) + oidcMgr.addPendingAuth(authReq) + idToken = &oidc.IDToken{ + Nonce: authReq.Nonce, + Expiry: time.Now().Add(5 * time.Minute), + } + setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`"}`)) + server.binding.OIDC.verifier = &mockOIDCVerifier{ + err: nil, + token: idToken, + } + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webUsersPath, rr.Header().Get("Location")) + + _, err = dataprovider.AdminExists(username) + assert.NoError(t, err) + err = dataprovider.DeleteAdmin(username, "", "", "") + assert.NoError(t, err) + // set invalid templates and try again + action.Options.IDPConfig.TemplateUser = `{}` + action.Options.IDPConfig.TemplateAdmin = `{}` + err = 
dataprovider.UpdateEventAction(action, "", "", "") + assert.NoError(t, err) + + for _, audience := range []string{tokenAudienceWebAdmin, tokenAudienceWebClient} { + authReq = newOIDCPendingAuth(audience) + oidcMgr.addPendingAuth(authReq) + idToken = &oidc.IDToken{ + Nonce: authReq.Nonce, + Expiry: time.Now().Add(5 * time.Minute), + } + setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+util.JSONEscape(username)+`"}`)) + server.binding.OIDC.verifier = &mockOIDCVerifier{ + err: nil, + token: idToken, + } + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + } + for k := range oidcMgr.tokens { + oidcMgr.removeToken(k) + } + + err = dataprovider.DeleteEventRule(rule.Name, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteEventAction(action.Name, "", "", "") + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestOIDCPreLoginHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) + username := "test_oidc_user_prelogin" + u := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + preLoginPath := filepath.Join(os.TempDir(), "prelogin.sh") + providerConf := dataprovider.GetProviderConfig() + err := dataprovider.Close() + assert.NoError(t, err) + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + newProviderConf := providerConf + newProviderConf.PreLoginHook = preLoginPath + err = dataprovider.Initialize(newProviderConf, 
configDir, true) + assert.NoError(t, err) + server := getTestOIDCServer() + server.binding.OIDC.CustomFields = []string{"field1", "field2"} + err = server.binding.OIDC.initialize() + assert.NoError(t, err) + err = server.initializeRouter() + require.NoError(t, err) + + _, err = dataprovider.UserExists(username, "") + assert.ErrorIs(t, err, util.ErrNotFound) + // now login with OIDC + authReq := newOIDCPendingAuth(tokenAudienceWebClient) + oidcMgr.addPendingAuth(authReq) + token := &oauth2.Token{ + AccessToken: "1234", + Expiry: time.Now().Add(5 * time.Minute), + } + token = token.WithExtra(map[string]any{ + "id_token": "id_token_val", + }) + server.binding.OIDC.oauth2Config = &mockOAuth2Config{ + tokenSource: &mockTokenSource{}, + authCodeURL: webOIDCRedirectPath, + token: token, + } + idToken := &oidc.IDToken{ + Nonce: authReq.Nonce, + Expiry: time.Now().Add(5 * time.Minute), + } + setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`"}`)) + server.binding.OIDC.verifier = &mockOIDCVerifier{ + err: nil, + token: idToken, + } + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientFilesPath, rr.Header().Get("Location")) + _, err = dataprovider.UserExists(username, "") + assert.NoError(t, err) + + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(u.HomeDir) + assert.NoError(t, err) + + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, true), os.ModePerm) + assert.NoError(t, err) + + authReq = newOIDCPendingAuth(tokenAudienceWebClient) + oidcMgr.addPendingAuth(authReq) + idToken = &oidc.IDToken{ + Nonce: authReq.Nonce, + Expiry: time.Now().Add(5 * time.Minute), + } + setIDTokenClaims(idToken, []byte(`{"preferred_username":"`+username+`","field1":"value1","field2":"value2","field3":"value3"}`)) + 
server.binding.OIDC.verifier = &mockOIDCVerifier{ + err: nil, + token: idToken, + } + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webClientLoginPath, rr.Header().Get("Location")) + _, err = dataprovider.UserExists(username, "") + assert.ErrorIs(t, err, util.ErrNotFound) + if assert.Len(t, oidcMgr.tokens, 1) { + for k := range oidcMgr.tokens { + oidcMgr.removeToken(k) + } + } + require.Len(t, oidcMgr.pendingAuths, 0) + require.Len(t, oidcMgr.tokens, 0) + + err = dataprovider.Close() + assert.NoError(t, err) + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(preLoginPath) + assert.NoError(t, err) +} + +func TestOIDCIsAdmin(t *testing.T) { + type test struct { + input any + want bool + } + + emptySlice := make([]any, 0) + + tests := []test{ + {input: "admin", want: true}, + {input: append(emptySlice, "admin"), want: true}, + {input: append(emptySlice, "user", "admin"), want: true}, + {input: "user", want: false}, + {input: emptySlice, want: false}, + {input: append(emptySlice, 1), want: false}, + {input: 1, want: false}, + {input: nil, want: false}, + {input: map[string]string{"admin": "admin"}, want: false}, + } + for _, tc := range tests { + token := oidcToken{ + Role: tc.input, + } + assert.Equal(t, tc.want, token.isAdmin(), "%v should return %t", tc.input, tc.want) + } +} + +func TestParseAdminRole(t *testing.T) { + claims := make(map[string]any) + rawClaims := []byte(`{ + "sub": "35666371", + "email": "example@example.com", + "preferred_username": "Sally", + "name": "Sally Tyler", + "updated_at": "2018-04-13T22:08:45Z", + "given_name": "Sally", + "family_name": "Tyler", + "params": { + "sftpgo_role": "admin", + "subparams": { + "sftpgo_role": "admin", + "inner": { + "sftpgo_role": ["user","admin"] + } + } + }, + 
"at_hash": "lPLhxI2wjEndc-WfyroDZA", + "rt_hash": "mCmxPtA04N-55AxlEUbq-A", + "aud": "78d1d040-20c9-0136-5146-067351775fae92920", + "exp": 1523664997, + "iat": 1523657797 + }`) + err := json.Unmarshal(rawClaims, &claims) + assert.NoError(t, err) + + type test struct { + input string + want bool + val any + } + + tests := []test{ + {input: "", want: false}, + {input: "sftpgo_role", want: false}, + {input: "params.sftpgo_role", want: true, val: "admin"}, + {input: "params.subparams.sftpgo_role", want: true, val: "admin"}, + {input: "params.subparams.inner.sftpgo_role", want: true, val: []any{"user", "admin"}}, + {input: "email", want: false}, + {input: "missing", want: false}, + {input: "params.email", want: false}, + {input: "missing.sftpgo_role", want: false}, + {input: "params", want: false}, + {input: "params.subparams.inner.sftpgo_role.missing", want: false}, + } + + for _, tc := range tests { + token := oidcToken{} + token.getRoleFromField(claims, tc.input) + assert.Equal(t, tc.want, token.isAdmin(), "%q should return %t", tc.input, tc.want) + if tc.want { + assert.Equal(t, tc.val, token.Role) + } + } +} + +func TestOIDCWithLoginFormsDisabled(t *testing.T) { + oidcMgr, ok := oidcMgr.(*memoryOIDCManager) + require.True(t, ok) + + server := getTestOIDCServer() + server.binding.OIDC.ImplicitRoles = true + server.binding.DisabledLoginMethods = 12 + server.binding.EnableWebAdmin = true + server.binding.EnableWebClient = true + err := server.binding.OIDC.initialize() + assert.NoError(t, err) + err = server.initializeRouter() + require.NoError(t, err) + // login with an admin user + authReq := newOIDCPendingAuth(tokenAudienceWebAdmin) + oidcMgr.addPendingAuth(authReq) + token := &oauth2.Token{ + AccessToken: "1234", + Expiry: time.Now().Add(5 * time.Minute), + } + token = token.WithExtra(map[string]any{ + "id_token": "id_token_val", + }) + server.binding.OIDC.oauth2Config = &mockOAuth2Config{ + tokenSource: &mockTokenSource{}, + authCodeURL: webOIDCRedirectPath, + 
token: token, + } + idToken := &oidc.IDToken{ + Nonce: authReq.Nonce, + Expiry: time.Now().Add(5 * time.Minute), + } + setIDTokenClaims(idToken, []byte(`{"preferred_username":"admin","sid":"sid456"}`)) + server.binding.OIDC.verifier = &mockOIDCVerifier{ + err: nil, + token: idToken, + } + rr := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, webOIDCRedirectPath+"?state="+authReq.State, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusFound, rr.Code) + assert.Equal(t, webUsersPath, rr.Header().Get("Location")) + var tokenCookie string + for k := range oidcMgr.tokens { + tokenCookie = k + } + // we should be able to create admins without setting a password + adminUsername := "testAdmin" + form := make(url.Values) + form.Set(csrfFormToken, createCSRFToken(rr, r, server.csrfTokenAuth, tokenCookie, webBaseAdminPath)) + form.Set("username", adminUsername) + form.Set("password", "") + form.Set("status", "1") + form.Set("permissions", "*") + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodPost, webAdminPath, bytes.NewBuffer([]byte(form.Encode()))) + assert.NoError(t, err) + r.Header.Set("Cookie", fmt.Sprintf("%v=%v", oidcCookieKey, tokenCookie)) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusSeeOther, rr.Code) + _, err = dataprovider.AdminExists(adminUsername) + assert.NoError(t, err) + err = dataprovider.DeleteAdmin(adminUsername, "", "", "") + assert.NoError(t, err) + // login and password related routes are disabled + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodPost, webAdminLoginPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusMethodNotAllowed, rr.Code) + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodPost, webAdminTwoFactorPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, 
http.StatusNotFound, rr.Code) + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodPost, webClientLoginPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusMethodNotAllowed, rr.Code) + rr = httptest.NewRecorder() + r, err = http.NewRequest(http.MethodPost, webClientForgotPwdPath, nil) + assert.NoError(t, err) + server.router.ServeHTTP(rr, r) + assert.Equal(t, http.StatusNotFound, rr.Code) +} + +func TestDbOIDCManager(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + mgr := newOIDCManager(1) + pendingAuth := newOIDCPendingAuth(tokenAudienceWebAdmin) + mgr.addPendingAuth(pendingAuth) + authReq, err := mgr.getPendingAuth(pendingAuth.State) + assert.NoError(t, err) + assert.Equal(t, pendingAuth, authReq) + pendingAuth.IssuedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)) + mgr.addPendingAuth(pendingAuth) + _, err = mgr.getPendingAuth(pendingAuth.State) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "auth request is too old") + } + mgr.removePendingAuth(pendingAuth.State) + _, err = mgr.getPendingAuth(pendingAuth.State) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to get the auth request for the specified state") + } + mgr.addPendingAuth(pendingAuth) + _, err = mgr.getPendingAuth(pendingAuth.State) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "auth request is too old") + } + mgr.cleanup() + _, err = mgr.getPendingAuth(pendingAuth.State) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to get the auth request for the specified state") + } + + token := oidcToken{ + Cookie: util.GenerateOpaqueString(), + AccessToken: xid.New().String(), + TokenType: "Bearer", + RefreshToken: xid.New().String(), + ExpiresAt: util.GetTimeAsMsSinceEpoch(time.Now().Add(-2 * time.Minute)), + SessionID: xid.New().String(), + IDToken: xid.New().String(), + Nonce: 
xid.New().String(), + Username: xid.New().String(), + Permissions: []string{dataprovider.PermAdminAny}, + Role: "admin", + } + mgr.addToken(token) + tokenGet, err := mgr.getToken(token.Cookie) + assert.NoError(t, err) + assert.Greater(t, tokenGet.UsedAt, int64(0)) + token.UsedAt = tokenGet.UsedAt + assert.Equal(t, token, tokenGet) + time.Sleep(100 * time.Millisecond) + mgr.updateTokenUsage(token) + // no change + tokenGet, err = mgr.getToken(token.Cookie) + assert.NoError(t, err) + assert.Equal(t, token.UsedAt, tokenGet.UsedAt) + tokenGet.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)) + tokenGet.RefreshToken = xid.New().String() + mgr.updateTokenUsage(tokenGet) + tokenGet, err = mgr.getToken(token.Cookie) + assert.NoError(t, err) + assert.NotEmpty(t, tokenGet.RefreshToken) + assert.NotEqual(t, token.RefreshToken, tokenGet.RefreshToken) + assert.Greater(t, tokenGet.UsedAt, token.UsedAt) + mgr.removeToken(token.Cookie) + tokenGet, err = mgr.getToken(token.Cookie) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to get the token for the specified session") + } + // add an expired token + token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(-24 * time.Hour)) + session := dataprovider.Session{ + Key: token.Cookie, + Data: token, + Type: dataprovider.SessionTypeOIDCToken, + Timestamp: token.UsedAt + tokenDeleteInterval, + } + err = dataprovider.AddSharedSession(session) + assert.NoError(t, err) + _, err = mgr.getToken(token.Cookie) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "token is too old") + } + mgr.cleanup() + _, err = mgr.getToken(token.Cookie) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to get the token for the specified session") + } + // adding a session without a key should fail + session.Key = "" + err = dataprovider.AddSharedSession(session) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to save a session with an empty key") + } + session.Key = 
xid.New().String() + session.Type = 1000 + err = dataprovider.AddSharedSession(session) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "invalid session type") + } + + dbMgr, ok := mgr.(*dbOIDCManager) + if assert.True(t, ok) { + _, err = dbMgr.decodePendingAuthData(2) + assert.Error(t, err) + _, err = dbMgr.decodeTokenData(true) + assert.Error(t, err) + } +} + +func getTestOIDCServer() *httpdServer { + return &httpdServer{ + binding: Binding{ + OIDC: OIDC{ + ClientID: "sftpgo-client", + ClientSecret: "jRsmE0SWnuZjP7djBqNq0mrf8QN77j2c", + ConfigURL: fmt.Sprintf("http://%v/auth/realms/sftpgo", oidcMockAddr), + RedirectBaseURL: "http://127.0.0.1:8081/", + UsernameField: "preferred_username", + RoleField: "sftpgo_role", + ImplicitRoles: false, + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + CustomFields: nil, + Debug: true, + }, + }, + enableWebAdmin: true, + enableWebClient: true, + } +} + +func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte { + content := []byte("#!/bin/sh\n\n") + if nonJSONResponse { + content = append(content, []byte("echo 'text response'\n")...) + return content + } + if len(user.Username) > 0 { + u, _ := json.Marshal(user) + content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...) + } + return content +} diff --git a/internal/httpd/oidcmanager.go b/internal/httpd/oidcmanager.go new file mode 100644 index 00000000..79336748 --- /dev/null +++ b/internal/httpd/oidcmanager.go @@ -0,0 +1,242 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package httpd + +import ( + "encoding/json" + "errors" + "sync" + "time" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + oidcMgr oidcManager +) + +func newOIDCManager(isShared int) oidcManager { + if isShared == 1 { + logger.Info(logSender, "", "using provider OIDC manager") + return &dbOIDCManager{} + } + logger.Info(logSender, "", "using memory OIDC manager") + return &memoryOIDCManager{ + pendingAuths: make(map[string]oidcPendingAuth), + tokens: make(map[string]oidcToken), + } +} + +type oidcManager interface { + addPendingAuth(pendingAuth oidcPendingAuth) + removePendingAuth(state string) + getPendingAuth(state string) (oidcPendingAuth, error) + addToken(token oidcToken) + getToken(cookie string) (oidcToken, error) + removeToken(cookie string) + updateTokenUsage(token oidcToken) + cleanup() +} + +type memoryOIDCManager struct { + authMutex sync.RWMutex + pendingAuths map[string]oidcPendingAuth + tokenMutex sync.RWMutex + tokens map[string]oidcToken +} + +func (o *memoryOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) { + o.authMutex.Lock() + o.pendingAuths[pendingAuth.State] = pendingAuth + o.authMutex.Unlock() +} + +func (o *memoryOIDCManager) removePendingAuth(state string) { + o.authMutex.Lock() + defer o.authMutex.Unlock() + + delete(o.pendingAuths, state) +} + +func (o *memoryOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) { + o.authMutex.RLock() + defer o.authMutex.RUnlock() + + authReq, ok := o.pendingAuths[state] + if !ok { + return oidcPendingAuth{}, errors.New("oidc: no auth request found for the specified state") + } + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - authReq.IssuedAt + if diff > authStateValidity { + 
return oidcPendingAuth{}, errors.New("oidc: auth request is too old") + } + return authReq, nil +} + +func (o *memoryOIDCManager) addToken(token oidcToken) { + o.tokenMutex.Lock() + token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + o.tokens[token.Cookie] = token + o.tokenMutex.Unlock() +} + +func (o *memoryOIDCManager) getToken(cookie string) (oidcToken, error) { + o.tokenMutex.RLock() + defer o.tokenMutex.RUnlock() + + token, ok := o.tokens[cookie] + if !ok { + return oidcToken{}, errors.New("oidc: no token found for the specified session") + } + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt + if diff > tokenDeleteInterval { + return oidcToken{}, errors.New("oidc: token is too old") + } + return token, nil +} + +func (o *memoryOIDCManager) removeToken(cookie string) { + o.tokenMutex.Lock() + defer o.tokenMutex.Unlock() + + delete(o.tokens, cookie) +} + +func (o *memoryOIDCManager) updateTokenUsage(token oidcToken) { + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt + if diff > tokenUpdateInterval { + o.addToken(token) + } +} + +func (o *memoryOIDCManager) cleanup() { + o.cleanupAuthRequests() + o.cleanupTokens() +} + +func (o *memoryOIDCManager) cleanupAuthRequests() { + o.authMutex.Lock() + defer o.authMutex.Unlock() + + for k, auth := range o.pendingAuths { + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - auth.IssuedAt + // remove old pending auth requests + if diff < 0 || diff > authStateValidity { + delete(o.pendingAuths, k) + } + } +} + +func (o *memoryOIDCManager) cleanupTokens() { + o.tokenMutex.Lock() + defer o.tokenMutex.Unlock() + + for k, token := range o.tokens { + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt + // remove tokens unused from more than tokenDeleteInterval + if diff > tokenDeleteInterval { + delete(o.tokens, k) + } + } +} + +type dbOIDCManager struct{} + +func (o *dbOIDCManager) addPendingAuth(pendingAuth oidcPendingAuth) { + session := dataprovider.Session{ + Key: pendingAuth.State, + 
Data: pendingAuth, + Type: dataprovider.SessionTypeOIDCAuth, + Timestamp: pendingAuth.IssuedAt + authStateValidity, + } + dataprovider.AddSharedSession(session) //nolint:errcheck +} + +func (o *dbOIDCManager) removePendingAuth(state string) { + dataprovider.DeleteSharedSession(state, dataprovider.SessionTypeOIDCAuth) //nolint:errcheck +} + +func (o *dbOIDCManager) getPendingAuth(state string) (oidcPendingAuth, error) { + session, err := dataprovider.GetSharedSession(state, dataprovider.SessionTypeOIDCAuth) + if err != nil { + return oidcPendingAuth{}, errors.New("oidc: unable to get the auth request for the specified state") + } + if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { + // expired + return oidcPendingAuth{}, errors.New("oidc: auth request is too old") + } + return o.decodePendingAuthData(session.Data) +} + +func (o *dbOIDCManager) decodePendingAuthData(data any) (oidcPendingAuth, error) { + if val, ok := data.([]byte); ok { + authReq := oidcPendingAuth{} + err := json.Unmarshal(val, &authReq) + return authReq, err + } + logger.Error(logSender, "", "invalid oidc auth request data type %T", data) + return oidcPendingAuth{}, errors.New("oidc: invalid auth request data") +} + +func (o *dbOIDCManager) addToken(token oidcToken) { + token.UsedAt = util.GetTimeAsMsSinceEpoch(time.Now()) + session := dataprovider.Session{ + Key: token.Cookie, + Data: token, + Type: dataprovider.SessionTypeOIDCToken, + Timestamp: token.UsedAt + tokenDeleteInterval, + } + dataprovider.AddSharedSession(session) //nolint:errcheck +} + +func (o *dbOIDCManager) removeToken(cookie string) { + dataprovider.DeleteSharedSession(cookie, dataprovider.SessionTypeOIDCToken) //nolint:errcheck +} + +func (o *dbOIDCManager) updateTokenUsage(token oidcToken) { + diff := util.GetTimeAsMsSinceEpoch(time.Now()) - token.UsedAt + if diff > tokenUpdateInterval { + o.addToken(token) + } +} + +func (o *dbOIDCManager) getToken(cookie string) (oidcToken, error) { + session, err := 
dataprovider.GetSharedSession(cookie, dataprovider.SessionTypeOIDCToken) + if err != nil { + return oidcToken{}, errors.New("oidc: unable to get the token for the specified session") + } + if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { + // expired + return oidcToken{}, errors.New("oidc: token is too old") + } + return o.decodeTokenData(session.Data) +} + +func (o *dbOIDCManager) decodeTokenData(data any) (oidcToken, error) { + if val, ok := data.([]byte); ok { + token := oidcToken{} + err := json.Unmarshal(val, &token) + return token, err + } + logger.Error(logSender, "", "invalid oidc token data type %T", data) + return oidcToken{}, errors.New("oidc: invalid token data") +} + +func (o *dbOIDCManager) cleanup() { + dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCAuth, time.Now()) //nolint:errcheck + dataprovider.CleanupSharedSessions(dataprovider.SessionTypeOIDCToken, time.Now()) //nolint:errcheck +} diff --git a/internal/httpd/resetcode.go b/internal/httpd/resetcode.go new file mode 100644 index 00000000..0be7d890 --- /dev/null +++ b/internal/httpd/resetcode.go @@ -0,0 +1,140 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "encoding/json" + "sync" + "time" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + resetCodeLifespan = 10 * time.Minute + resetCodesMgr resetCodeManager +) + +type resetCodeManager interface { + Add(code *resetCode) error + Get(code string) (*resetCode, error) + Delete(code string) error + Cleanup() +} + +func newResetCodeManager(isShared int) resetCodeManager { + if isShared == 1 { + logger.Info(logSender, "", "using provider reset code manager") + return &dbResetCodeManager{} + } + logger.Info(logSender, "", "using memory reset code manager") + return &memoryResetCodeManager{} +} + +type resetCode struct { + Code string `json:"code"` + Username string `json:"username"` + IsAdmin bool `json:"is_admin"` + ExpiresAt time.Time `json:"expires_at"` +} + +func newResetCode(username string, isAdmin bool) *resetCode { + return &resetCode{ + Code: util.GenerateUniqueID(), + Username: username, + IsAdmin: isAdmin, + ExpiresAt: time.Now().Add(resetCodeLifespan).UTC(), + } +} + +func (c *resetCode) isExpired() bool { + return c.ExpiresAt.Before(time.Now().UTC()) +} + +type memoryResetCodeManager struct { + resetCodes sync.Map +} + +func (m *memoryResetCodeManager) Add(code *resetCode) error { + m.resetCodes.Store(code.Code, code) + return nil +} + +func (m *memoryResetCodeManager) Get(code string) (*resetCode, error) { + c, ok := m.resetCodes.Load(code) + if !ok { + return nil, util.NewRecordNotFoundError("reset code not found") + } + return c.(*resetCode), nil +} + +func (m *memoryResetCodeManager) Delete(code string) error { + m.resetCodes.Delete(code) + return nil +} + +func (m *memoryResetCodeManager) Cleanup() { + m.resetCodes.Range(func(key, value any) bool { + c, ok := value.(*resetCode) + if !ok || c.isExpired() { + m.resetCodes.Delete(key) + } + return true + }) +} + +type dbResetCodeManager struct{} + +func (m 
*dbResetCodeManager) Add(code *resetCode) error { + session := dataprovider.Session{ + Key: code.Code, + Data: code, + Type: dataprovider.SessionTypeResetCode, + Timestamp: util.GetTimeAsMsSinceEpoch(code.ExpiresAt), + } + return dataprovider.AddSharedSession(session) +} + +func (m *dbResetCodeManager) Get(code string) (*resetCode, error) { + session, err := dataprovider.GetSharedSession(code, dataprovider.SessionTypeResetCode) + if err != nil { + return nil, err + } + if session.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now()) { + // expired + return nil, util.NewRecordNotFoundError("reset code expired") + } + return m.decodeData(session.Data) +} + +func (m *dbResetCodeManager) decodeData(data any) (*resetCode, error) { + if val, ok := data.([]byte); ok { + c := &resetCode{} + err := json.Unmarshal(val, c) + return c, err + } + logger.Error(logSender, "", "invalid reset code data type %T", data) + return nil, util.NewRecordNotFoundError("invalid reset code") +} + +func (m *dbResetCodeManager) Delete(code string) error { + return dataprovider.DeleteSharedSession(code, dataprovider.SessionTypeResetCode) +} + +func (m *dbResetCodeManager) Cleanup() { + dataprovider.CleanupSharedSessions(dataprovider.SessionTypeResetCode, time.Now()) //nolint:errcheck +} diff --git a/internal/httpd/resources.go b/internal/httpd/resources.go new file mode 100644 index 00000000..54bc58c9 --- /dev/null +++ b/internal/httpd/resources.go @@ -0,0 +1,27 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. 
+// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build !bundle + +package httpd + +import ( + "net/http" + + "github.com/go-chi/chi/v5" +) + +func serveStaticDir(router chi.Router, path, fsDirPath string, disableDirectoryIndex bool) { + fileServer(router, path, http.Dir(fsDirPath), disableDirectoryIndex) +} diff --git a/internal/httpd/resources_embedded.go b/internal/httpd/resources_embedded.go new file mode 100644 index 00000000..e15bc985 --- /dev/null +++ b/internal/httpd/resources_embedded.go @@ -0,0 +1,36 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +//go:build bundle + +package httpd + +import ( + "net/http" + + "github.com/go-chi/chi/v5" + + "github.com/drakkan/sftpgo/v2/internal/bundle" +) + +func serveStaticDir(router chi.Router, path, fsDirPath string, disableDirectoryIndex bool) { + switch path { + case webStaticFilesPath: + fileServer(router, path, bundle.GetStaticFs(), disableDirectoryIndex) + case webOpenAPIPath: + fileServer(router, path, bundle.GetOpenAPIFs(), disableDirectoryIndex) + default: + fileServer(router, path, http.Dir(fsDirPath), disableDirectoryIndex) + } +} diff --git a/internal/httpd/server.go b/internal/httpd/server.go new file mode 100644 index 00000000..3fa54c1b --- /dev/null +++ b/internal/httpd/server.go @@ -0,0 +1,1920 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "context" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "log" + "net" + "net/http" + "net/url" + "path" + "path/filepath" + "slices" + "strings" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/render" + "github.com/go-jose/go-jose/v4" + "github.com/rs/cors" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + "github.com/unrolled/secure" + + "github.com/drakkan/sftpgo/v2/internal/acme" + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" +) + +const ( + jsonAPISuffix = "/json" +) + +var ( + compressor = middleware.NewCompressor(5) + xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") +) + +type httpdServer struct { + binding Binding + staticFilesPath string + openAPIPath string + enableWebAdmin bool + enableWebClient bool + enableRESTAPI bool + renderOpenAPI bool + isShared int + router *chi.Mux + tokenAuth *jwt.Signer + csrfTokenAuth *jwt.Signer + signingPassphrase string + cors CorsConfig +} + +func newHttpdServer(b Binding, staticFilesPath, signingPassphrase string, cors CorsConfig, + openAPIPath string, +) *httpdServer { + if openAPIPath == "" { + b.RenderOpenAPI = false + } + return &httpdServer{ + binding: b, + staticFilesPath: staticFilesPath, + openAPIPath: openAPIPath, + enableWebAdmin: b.EnableWebAdmin, + enableWebClient: b.EnableWebClient, + enableRESTAPI: b.EnableRESTAPI, + renderOpenAPI: b.RenderOpenAPI, + signingPassphrase: signingPassphrase, + cors: cors, + } +} + +func (s *httpdServer) setShared(value int) { + s.isShared = value +} + +func (s *httpdServer) listenAndServe() error { + if err := 
s.initializeRouter(); err != nil { + return err + } + httpServer := &http.Server{ + Handler: s.router, + ReadHeaderTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + MaxHeaderBytes: 1 << 16, // 64KB + ErrorLog: log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0), + } + if certMgr != nil && s.binding.EnableHTTPS { + certID := common.DefaultTLSKeyPaidID + if getConfigPath(s.binding.CertificateFile, "") != "" && getConfigPath(s.binding.CertificateKeyFile, "") != "" { + certID = s.binding.GetAddress() + } + config := &tls.Config{ + GetCertificate: certMgr.GetCertificateFunc(certID), + MinVersion: util.GetTLSVersion(s.binding.MinTLSVersion), + NextProtos: util.GetALPNProtocols(s.binding.Protocols), + CipherSuites: util.GetTLSCiphersFromNames(s.binding.TLSCipherSuites), + } + httpServer.TLSConfig = config + logger.Debug(logSender, "", "configured TLS cipher suites for binding %q: %v, certID: %v", + s.binding.GetAddress(), httpServer.TLSConfig.CipherSuites, certID) + if s.binding.isMutualTLSEnabled() { + httpServer.TLSConfig.ClientCAs = certMgr.GetRootCAs() + httpServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + httpServer.TLSConfig.VerifyConnection = s.verifyTLSConnection + } + return util.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, true, + s.binding.listenerWrapper(), logSender) + } + return util.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, false, + s.binding.listenerWrapper(), logSender) +} + +func (s *httpdServer) verifyTLSConnection(state tls.ConnectionState) error { + if certMgr != nil { + var clientCrt *x509.Certificate + var clientCrtName string + if len(state.PeerCertificates) > 0 { + clientCrt = state.PeerCertificates[0] + clientCrtName = clientCrt.Subject.String() + } + if len(state.VerifiedChains) == 0 { + logger.Warn(logSender, "", "TLS connection cannot be verified: unable to get verification chain") + return errors.New("TLS connection cannot be verified: unable to get verification 
chain") + } + for _, verifiedChain := range state.VerifiedChains { + var caCrt *x509.Certificate + if len(verifiedChain) > 0 { + caCrt = verifiedChain[len(verifiedChain)-1] + } + if certMgr.IsRevoked(clientCrt, caCrt) { + logger.Debug(logSender, "", "tls handshake error, client certificate %q has been revoked", clientCrtName) + return common.ErrCrtRevoked + } + } + } + + return nil +} + +func (s *httpdServer) refreshCookie(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.checkCookieExpiration(w, r) + next.ServeHTTP(w, r) + }) +} + +func (s *httpdServer) renderClientLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { + data := loginPage{ + commonBasePage: getCommonBasePage(r), + Title: util.I18nLoginTitle, + CurrentURL: webClientLoginPath, + Error: err, + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath), + Branding: s.binding.webClientBranding(), + Languages: s.binding.languages(), + FormDisabled: s.binding.isWebClientLoginFormDisabled(), + CheckRedirect: true, + } + if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) { + data.CurrentURL += "?next=" + url.QueryEscape(next) + } + if s.binding.showAdminLoginURL() { + data.AltLoginURL = webAdminLoginPath + data.AltLoginName = s.binding.webAdminBranding().ShortName + } + if smtp.IsEnabled() && !data.FormDisabled { + data.ForgotPwdURL = webClientForgotPwdPath + } + if s.binding.OIDC.isEnabled() && !s.binding.isWebClientOIDCLoginDisabled() { + data.OpenIDLoginURL = webClientOIDCLoginPath + } + renderClientTemplate(w, templateCommonLogin, data) +} + +func (s *httpdServer) handleWebClientLogout(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + removeCookie(w, r, webBaseClientPath) + s.logoutOIDCUser(w, r) + + http.Redirect(w, r, webClientLoginPath, http.StatusFound) +} + +func (s *httpdServer) handleWebClientChangePwdPost(w 
http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + if err := r.ParseForm(); err != nil { + s.renderClientChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) + return + } + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) + return + } + err := doChangeUserPassword(r, strings.TrimSpace(r.Form.Get("current_password")), + strings.TrimSpace(r.Form.Get("new_password1")), strings.TrimSpace(r.Form.Get("new_password2"))) + if err != nil { + s.renderClientChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric)) + return + } + s.handleWebClientLogout(w, r) +} + +func (s *httpdServer) handleClientWebLogin(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + if !dataprovider.HasAdmin() { + http.Redirect(w, r, webAdminSetupPath, http.StatusFound) + return + } + msg := getFlashMessage(w, r) + s.renderClientLoginPage(w, r, msg.getI18nError()) +} + +func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := r.ParseForm(); err != nil { + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) + return + } + protocol := common.ProtocolHTTP + username := strings.TrimSpace(r.Form.Get("username")) + password := r.Form.Get("password") + if username == "" || password == "" { + updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, + dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials, r) + s.renderClientLoginPage(w, r, + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) + return + } + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { + 
updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, + dataprovider.LoginMethodPassword, ipAddr, err, r) + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) + } + + if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil { + updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, + dataprovider.LoginMethodPassword, ipAddr, err, r) + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message)) + return + } + + user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, protocol) + if err != nil { + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) + s.renderClientLoginPage(w, r, + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) + return + } + connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) + if err := checkHTTPClientUser(&user, r, connectionID, true, false); err != nil { + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message)) + return + } + + defer user.CloseFs() //nolint:errcheck + err = user.CheckFsRoot(connectionID) + if err != nil { + logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) + s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nErrorFsGeneric)) + return + } + s.loginUser(w, r, &user, connectionID, ipAddr, false, s.renderClientLoginPage) +} + +func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + err := r.ParseForm() + if err != nil { + s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) + return + } + if err := 
verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) + return + } + newPassword := strings.TrimSpace(r.Form.Get("password")) + confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password")) + _, user, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")), + newPassword, confirmPassword, false) + if err != nil { + s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric)) + return + } + connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String()) + if err := checkHTTPClientUser(user, r, connectionID, true, false); err != nil { + s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset)) + return + } + + defer user.CloseFs() //nolint:errcheck + err = user.CheckFsRoot(connectionID) + if err != nil { + logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) + s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset)) + return + } + s.loginUser(w, r, user, connectionID, ipAddr, false, s.renderClientResetPwdPage) +} + +func (s *httpdServer) handleWebClientTwoFactorRecoveryPost(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + claims, err := jwt.FromContext(r.Context()) + if err != nil { + s.renderNotFoundPage(w, r, nil) + return + } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := r.ParseForm(); err != nil { + s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) + return + } + username := claims.Username + recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code")) + if username == "" || recoveryCode == "" { + s.renderClientTwoFactorRecoveryPage(w, r, + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) + return + } + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { + 
s.renderClientTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) + return + } + user, userMerged, err := dataprovider.GetUserVariants(username, "") + if err != nil { + if errors.Is(err, util.ErrNotFound) { + handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck + } + s.renderClientTwoFactorRecoveryPage(w, r, + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) + return + } + if !userMerged.Filters.TOTPConfig.Enabled || !slices.Contains(userMerged.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { + s.renderClientTwoFactorPage(w, r, util.NewI18nError( + util.NewValidationError("two factory authentication is not enabled"), util.I18n2FADisabled)) + return + } + for idx, code := range user.Filters.RecoveryCodes { + if err := code.Secret.Decrypt(); err != nil { + s.renderClientInternalServerErrorPage(w, r, fmt.Errorf("unable to decrypt recovery code: %w", err)) + return + } + if code.Secret.GetPayload() == recoveryCode { + if code.Used { + s.renderClientTwoFactorRecoveryPage(w, r, + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) + return + } + user.Filters.RecoveryCodes[idx].Used = true + err = dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, ipAddr, user.Role) + if err != nil { + logger.Warn(logSender, "", "unable to set the recovery code %q as used: %v", recoveryCode, err) + s.renderClientInternalServerErrorPage(w, r, errors.New("unable to set the recovery code as used")) + return + } + connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String()) + s.loginUser(w, r, &userMerged, connectionID, ipAddr, true, + s.renderClientTwoFactorRecoveryPage) + return + } + } + handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck + s.renderClientTwoFactorRecoveryPage(w, r, + util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) +} + +func (s *httpdServer) 
handleWebClientTwoFactorPost(w http.ResponseWriter, r *http.Request) {
	// Validates the TOTP passcode submitted after a successful password login
	// (partial token) and, on success, upgrades the session to a full WebClient login.
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil {
		s.renderNotFoundPage(w, r, nil)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := r.ParseForm(); err != nil {
		s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	// The username comes from the partial JWT, not from the form.
	username := claims.Username
	passcode := strings.TrimSpace(r.Form.Get("passcode"))
	if username == "" || passcode == "" {
		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials, r)
		s.renderClientTwoFactorPage(w, r,
			util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
			dataprovider.LoginMethodPassword, ipAddr, err, r)
		s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	user, err := dataprovider.GetUserWithGroupSettings(username, "")
	if err != nil {
		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
			dataprovider.LoginMethodPassword, ipAddr, err, r)
		s.renderClientTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
		return
	}
	if !user.Filters.TOTPConfig.Enabled || !slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) {
		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
		s.renderClientTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled))
		return
	}
	err = user.Filters.TOTPConfig.Secret.Decrypt()
	if err != nil {
		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
		s.renderClientInternalServerErrorPage(w, r, err)
		return
	}
	match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, passcode,
		user.Filters.TOTPConfig.Secret.GetPayload())
	if !match || err != nil {
		updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials, r)
		s.renderClientTwoFactorPage(w, r,
			util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	connectionID := fmt.Sprintf("%s_%s", getProtocolFromRequest(r), xid.New().String())
	// isSecondFactorAuth=true: issue the full (non-partial) WebClient token.
	s.loginUser(w, r, &user, connectionID, ipAddr, true, s.renderClientTwoFactorPage)
}

// handleWebAdminTwoFactorRecoveryPost consumes a one-time recovery code for an
// admin whose TOTP device is unavailable. A matching unused code is marked as
// used before the login completes, so it cannot be replayed.
func (s *httpdServer) handleWebAdminTwoFactorRecoveryPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)

	claims, err := jwt.FromContext(r.Context())
	if err != nil {
		s.renderNotFoundPage(w, r, nil)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := r.ParseForm(); err != nil {
		s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	username := claims.Username
	recoveryCode := strings.TrimSpace(r.Form.Get("recovery_code"))
	if username == "" || recoveryCode == "" {
		s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	admin, err := dataprovider.AdminExists(username)
	if err != nil {
		if errors.Is(err, util.ErrNotFound) {
			handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
		}
		// Generic error: do not reveal whether the admin exists.
		s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	if !admin.Filters.TOTPConfig.Enabled {
		// fixed typo: "two factory" -> "two factor"
		s.renderTwoFactorRecoveryPage(w, r,
			util.NewI18nError(util.NewValidationError("two factor authentication is not enabled"), util.I18n2FADisabled))
		return
	}
	for idx, code := range admin.Filters.RecoveryCodes {
		if err := code.Secret.Decrypt(); err != nil {
			s.renderInternalServerErrorPage(w, r, fmt.Errorf("unable to decrypt recovery code: %w", err))
			return
		}
		if code.Secret.GetPayload() == recoveryCode {
			if code.Used {
				// Recovery codes are single use.
				s.renderTwoFactorRecoveryPage(w, r,
					util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
				return
			}
			admin.Filters.RecoveryCodes[idx].Used = true
			// Persist the "used" flag before logging the admin in.
			err = dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, ipAddr, admin.Role)
			if err != nil {
				logger.Warn(logSender, "", "unable to set the recovery code %q as used: %v", recoveryCode, err)
				s.renderInternalServerErrorPage(w, r, errors.New("unable to set the recovery code as used"))
				return
			}
			s.loginAdmin(w, r, &admin, true, s.renderTwoFactorRecoveryPage, ipAddr)
			return
		}
	}
	handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
	s.renderTwoFactorRecoveryPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
}

// handleWebAdminTwoFactorPost validates the TOTP passcode for an admin after a
// successful password login and completes the WebAdmin session on success.
func (s *httpdServer) handleWebAdminTwoFactorPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil {
		s.renderNotFoundPage(w, r, nil)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := r.ParseForm(); err != nil {
		s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	username := claims.Username
	passcode := strings.TrimSpace(r.Form.Get("passcode"))
	if username == "" || passcode == "" {
		s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
		s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	admin, err := dataprovider.AdminExists(username)
	if err != nil {
		if errors.Is(err, util.ErrNotFound) {
			handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
		}
		s.renderTwoFactorPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials))
		return
	}
	if !admin.Filters.TOTPConfig.Enabled {
		s.renderTwoFactorPage(w, r, util.NewI18nError(common.ErrInternalFailure, util.I18n2FADisabled))
		return
	}
	err = admin.Filters.TOTPConfig.Secret.Decrypt()
	if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	match, err := mfa.ValidateTOTPPasscode(admin.Filters.TOTPConfig.ConfigName, passcode,
		admin.Filters.TOTPConfig.Secret.GetPayload())
	if !match || err != nil {
		// A wrong passcode counts as a failed login for the defender.
		handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) //nolint:errcheck
		s.renderTwoFactorPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	s.loginAdmin(w, r, &admin, true, s.renderTwoFactorPage, ipAddr)
}

// handleWebAdminLoginPost handles the username/password form post for the
// WebAdmin UI. On success loginAdmin decides between a full and a partial
// (2FA pending) session.
func (s *httpdServer) handleWebAdminLoginPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)

	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := r.ParseForm(); err != nil {
		s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	username := strings.TrimSpace(r.Form.Get("username"))
	password := strings.TrimSpace(r.Form.Get("password"))
	if username == "" || password == "" {
		s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderAdminLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	admin, err := dataprovider.CheckAdminAndPass(username, password, ipAddr)
	if err != nil {
		handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck
		// Generic message: do not leak whether the username or the password was wrong.
		s.renderAdminLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials))
		return
	}
	s.loginAdmin(w, r, &admin, false, s.renderAdminLoginPage, ipAddr)
}

// renderAdminLoginPage renders the WebAdmin login form, enabling the optional
// links (client login, forgot password, OIDC) according to the configuration.
func (s *httpdServer) renderAdminLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	data := loginPage{
		commonBasePage: getCommonBasePage(r),
		Title:          util.I18nLoginTitle,
		CurrentURL:     webAdminLoginPath,
		Error:          err,
		CSRFToken:      createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath),
		Branding:       s.binding.webAdminBranding(),
		Languages:      s.binding.languages(),
		FormDisabled:   s.binding.isWebAdminLoginFormDisabled(),
		CheckRedirect:  false,
	}
	if s.binding.showClientLoginURL() {
		data.AltLoginURL = webClientLoginPath
		data.AltLoginName = s.binding.webClientBranding().ShortName
	}
	// The password-reset link requires a working SMTP configuration.
	if smtp.IsEnabled() && !data.FormDisabled {
		data.ForgotPwdURL = webAdminForgotPwdPath
	}
	if s.binding.OIDC.hasRoles() && !s.binding.isWebAdminOIDCLoginDisabled() {
		data.OpenIDLoginURL = webAdminOIDCLoginPath
	}
	renderAdminTemplate(w, templateCommonLogin, data)
}

// handleWebAdminLogin shows the admin login page, or the initial setup page if
// no admin exists yet.
func (s *httpdServer) handleWebAdminLogin(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	if !dataprovider.HasAdmin() {
		http.Redirect(w, r, webAdminSetupPath, http.StatusFound)
		return
	}
	msg := getFlashMessage(w, r)
	s.renderAdminLoginPage(w, r, msg.getI18nError())
}

// handleWebAdminLogout clears the admin session cookie (and any OIDC session)
// and redirects to the login page.
func (s *httpdServer) handleWebAdminLogout(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	removeCookie(w, r, webBaseAdminPath)
	s.logoutOIDCUser(w, r)

	http.Redirect(w, r, webAdminLoginPath, http.StatusFound)
}

// handleWebAdminChangePwdPost changes the logged-in admin's password and then
// forces a logout so the admin must authenticate again with the new password.
func (s *httpdServer) handleWebAdminChangePwdPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	err := r.ParseForm()
	if err != nil {
		s.renderChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	err = doChangeAdminPassword(r, strings.TrimSpace(r.Form.Get("current_password")),
		strings.TrimSpace(r.Form.Get("new_password1")), strings.TrimSpace(r.Form.Get("new_password2")))
	if err != nil {
		s.renderChangePasswordPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric))
		return
	}
	s.handleWebAdminLogout(w, r)
}

// handleWebAdminPasswordResetPost completes a "forgot password" flow: it
// validates the emailed reset code, sets the new password and logs the admin in.
func (s *httpdServer) handleWebAdminPasswordResetPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)

	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	err := r.ParseForm()
	if err != nil {
		s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	newPassword := strings.TrimSpace(r.Form.Get("password"))
	confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password"))
	admin, _, err := handleResetPassword(r, strings.TrimSpace(r.Form.Get("code")),
		newPassword, confirmPassword, true)
	if err != nil {
		s.renderResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorChangePwdGeneric))
		return
	}

	s.loginAdmin(w, r, admin, false, s.renderResetPwdPage, ipAddr)
}

// handleWebAdminSetupPost creates the first admin account. It is only reachable
// while no admin exists and may be protected by an installation code.
func (s *httpdServer) handleWebAdminSetupPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	if dataprovider.HasAdmin() {
		s.renderBadRequestPage(w, r, errors.New("an admin user already exists"))
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	err := r.ParseForm()
	if err != nil {
		s.renderAdminSetupPage(w, r, "", util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	username := strings.TrimSpace(r.Form.Get("username"))
	password := strings.TrimSpace(r.Form.Get("password"))
	confirmPassword := strings.TrimSpace(r.Form.Get("confirm_password"))
	installCode := strings.TrimSpace(r.Form.Get("install_code"))
	// If an installation code is configured it must match before the first
	// admin can be created.
	if installationCode != "" && installCode != resolveInstallationCode() {
		s.renderAdminSetupPage(w, r, username,
			util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("%v mismatch", installationCodeHint)),
				util.I18nErrorSetupInstallCode),
		)
		return
	}
	if username == "" {
		s.renderAdminSetupPage(w, r, username,
			util.NewI18nError(util.NewValidationError("please set a username"), util.I18nError500Message))
		return
	}
	if password == "" {
		s.renderAdminSetupPage(w, r, username,
			util.NewI18nError(util.NewValidationError("please set a password"), util.I18nError500Message))
		return
	}
	if password != confirmPassword {
		s.renderAdminSetupPage(w, r, username,
			util.NewI18nError(errors.New("the two password fields do not match"), util.I18nErrorChangePwdNoMatch))
		return
	}
	// The first admin gets full permissions.
	admin := dataprovider.Admin{
		Username:    username,
		Password:    password,
		Status:      1,
		Permissions: []string{dataprovider.PermAdminAny},
	}
	err = dataprovider.AddAdmin(&admin, username, ipAddr, "")
	if err != nil {
		s.renderAdminSetupPage(w, r, username, util.NewI18nError(err, util.I18nError500Message))
		return
	}
	s.loginAdmin(w, r, &admin, false, nil, ipAddr)
}

// loginUser issues the WebClient session cookie for an authenticated user.
// If TOTP is enabled for HTTP and this is not already the second-factor step,
// a partial token is issued and the user is redirected to the 2FA page.
// errorFunc renders the originating page on cookie-creation failure.
func (s *httpdServer) loginUser(
	w http.ResponseWriter, r *http.Request, user *dataprovider.User, connectionID, ipAddr string,
	isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError),
) {
	c := &jwt.Claims{
		Username:                   user.Username,
		Permissions:                user.Filters.WebClient,
		Role:                       user.Role,
		MustSetTwoFactorAuth:       user.MustSetSecondFactor(),
		MustChangePassword:         user.MustChangePassword(),
		RequiredTwoFactorProtocols: user.Filters.TwoFactorAuthProtocols,
	}
	c.Subject = user.GetSignature()

	audience := tokenAudienceWebClient
	if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) &&
		user.CanManageMFA() && !isSecondFactorAuth {
		// Password verified but TOTP still pending: issue a partial token only.
		audience = tokenAudienceWebClientPartial
	}

	err := createAndSetCookie(w, r, c, s.tokenAuth, audience, ipAddr)
	if err != nil {
		logger.Warn(logSender, connectionID, "unable to set user login cookie %v", err)
		updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r)
		errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message))
		return
	}
	invalidateToken(r)
	if audience == tokenAudienceWebClientPartial {
		redirectPath := webClientTwoFactorPath
		// Preserve the requested destination across the 2FA step, but only
		// for in-app files paths to avoid open redirects.
		if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
			redirectPath += "?next=" + url.QueryEscape(next)
		}
		http.Redirect(w, r, redirectPath, http.StatusFound)
		return
	}
	updateLoginMetrics(user, dataprovider.LoginMethodPassword, ipAddr, err, r)
	dataprovider.UpdateLastLogin(user)
	if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) {
		http.Redirect(w, r, next, http.StatusFound)
		return
	}
	http.Redirect(w, r, webClientFilesPath, http.StatusFound)
}

// loginAdmin issues the WebAdmin session cookie for an authenticated admin,
// switching to a partial token when TOTP is enabled and still pending.
// A nil errorFunc means the initial-setup flow (errors go to the setup page
// and the post-login redirect targets the MFA page).
func (s *httpdServer) loginAdmin(
	w http.ResponseWriter, r *http.Request, admin *dataprovider.Admin,
	isSecondFactorAuth bool, errorFunc func(w http.ResponseWriter, r *http.Request, err *util.I18nError),
	ipAddr string,
) {
	c := &jwt.Claims{
		Username:             admin.Username,
		Permissions:          admin.Permissions,
		Role:                 admin.Role,
		HideUserPageSections: admin.Filters.Preferences.HideUserPageSections,
		MustSetTwoFactorAuth: admin.Filters.RequireTwoFactor && !admin.Filters.TOTPConfig.Enabled,
		MustChangePassword:   admin.Filters.RequirePasswordChange,
	}
	c.Subject = admin.GetSignature()

	audience := tokenAudienceWebAdmin
	if admin.Filters.TOTPConfig.Enabled && admin.CanManageMFA() && !isSecondFactorAuth {
		audience = tokenAudienceWebAdminPartial
	}

	err := createAndSetCookie(w, r, c, s.tokenAuth, audience, ipAddr)
	if err != nil {
		logger.Warn(logSender, "", "unable to set admin login cookie %v", err)
		if errorFunc == nil {
			s.renderAdminSetupPage(w, r, admin.Username, util.NewI18nError(err, util.I18nError500Message))
			return
		}
		errorFunc(w, r, util.NewI18nError(err, util.I18nError500Message))
		return
	}
	invalidateToken(r)
	if audience == tokenAudienceWebAdminPartial {
		http.Redirect(w, r, webAdminTwoFactorPath, http.StatusFound)
		return
	}
	dataprovider.UpdateAdminLastLogin(admin)
	common.DelayLogin(nil)
	redirectURL := webUsersPath
	if errorFunc == nil {
		redirectURL = webAdminMFAPath
	}
	http.Redirect(w, r, redirectURL, http.StatusFound)
}

// logout invalidates the caller's API token.
func (s *httpdServer) logout(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	invalidateToken(r)
	sendAPIResponse(w, r, nil, "Your token has been invalidated", http.StatusOK)
}

// getUserToken authenticates a user via HTTP Basic auth (plus an optional TOTP
// passcode header) and returns a user-scoped API JWT.
func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	username, password, ok := r.BasicAuth()
	protocol := common.ProtocolHTTP
	if !ok {
		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
			dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials, r)
		w.Header().Set(common.HTTPAuthenticationHeader, basicRealm)
		sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
		return
	}
	if username == "" || strings.TrimSpace(password) == "" {
		updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}},
dataprovider.LoginMethodPassword, ipAddr, common.ErrNoCredentials, r) + w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) + sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + if err := common.Config.ExecutePostConnectHook(ipAddr, protocol); err != nil { + updateLoginMetrics(&dataprovider.User{BaseUser: sdk.BaseUser{Username: username}}, + dataprovider.LoginMethodPassword, ipAddr, err, r) + sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return + } + user, err := dataprovider.CheckUserAndPass(username, password, ipAddr, protocol) + if err != nil { + w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) + sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), + http.StatusUnauthorized) + return + } + connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String()) + if err := checkHTTPClientUser(&user, r, connectionID, true, false); err != nil { + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) + sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return + } + + if user.Filters.TOTPConfig.Enabled && slices.Contains(user.Filters.TOTPConfig.Protocols, common.ProtocolHTTP) { + passcode := r.Header.Get(otpHeaderCode) + if passcode == "" { + logger.Debug(logSender, "", "TOTP enabled for user %q and not passcode provided, authentication refused", user.Username) + w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials, r) + sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), + http.StatusUnauthorized) + return + } + err = user.Filters.TOTPConfig.Secret.Decrypt() + if err != nil { + updateLoginMetrics(&user, 
dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) + sendAPIResponse(w, r, fmt.Errorf("unable to decrypt TOTP secret: %w", err), http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + match, err := mfa.ValidateTOTPPasscode(user.Filters.TOTPConfig.ConfigName, passcode, + user.Filters.TOTPConfig.Secret.GetPayload()) + if !match || err != nil { + logger.Debug(logSender, "invalid passcode for user %q, match? %v, err: %v", user.Username, match, err) + w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, dataprovider.ErrInvalidCredentials, r) + sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), + http.StatusUnauthorized) + return + } + } + + defer user.CloseFs() //nolint:errcheck + err = user.CheckFsRoot(connectionID) + if err != nil { + logger.Warn(logSender, connectionID, "unable to check fs root: %v", err) + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) + sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + s.generateAndSendUserToken(w, r, ipAddr, user) +} + +func (s *httpdServer) generateAndSendUserToken(w http.ResponseWriter, r *http.Request, ipAddr string, user dataprovider.User) { + c := &jwt.Claims{ + Username: user.Username, + Permissions: user.Filters.WebClient, + Role: user.Role, + MustSetTwoFactorAuth: user.MustSetSecondFactor(), + MustChangePassword: user.MustChangePassword(), + RequiredTwoFactorProtocols: user.Filters.TwoFactorAuthProtocols, + } + c.Subject = user.GetSignature() + + token, err := s.tokenAuth.SignWithParams(c, tokenAudienceAPIUser, ipAddr, getTokenDuration(tokenAudienceAPIUser)) + if err != nil { + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, common.ErrInternalFailure, r) + sendAPIResponse(w, r, err, 
http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r) + dataprovider.UpdateLastLogin(&user) + + render.JSON(w, r, c.BuildTokenResponse(token)) +} + +func (s *httpdServer) getToken(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok { + w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) + sendAPIResponse(w, r, nil, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + admin, err := dataprovider.CheckAdminAndPass(username, password, ipAddr) + if err != nil { + handleDefenderEventLoginFailed(ipAddr, err) //nolint:errcheck + w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) + sendAPIResponse(w, r, dataprovider.ErrInvalidCredentials, http.StatusText(http.StatusUnauthorized), + http.StatusUnauthorized) + return + } + if admin.Filters.TOTPConfig.Enabled { + passcode := r.Header.Get(otpHeaderCode) + if passcode == "" { + logger.Debug(logSender, "", "TOTP enabled for admin %q and not passcode provided, authentication refused", admin.Username) + w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) + err = handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) + sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + err = admin.Filters.TOTPConfig.Secret.Decrypt() + if err != nil { + sendAPIResponse(w, r, fmt.Errorf("unable to decrypt TOTP secret: %w", err), + http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + match, err := mfa.ValidateTOTPPasscode(admin.Filters.TOTPConfig.ConfigName, passcode, + admin.Filters.TOTPConfig.Secret.GetPayload()) + if !match || err != nil { + logger.Debug(logSender, "invalid passcode for admin %q, match? 
%v, err: %v", admin.Username, match, err) + w.Header().Set(common.HTTPAuthenticationHeader, basicRealm) + err = handleDefenderEventLoginFailed(ipAddr, dataprovider.ErrInvalidCredentials) + sendAPIResponse(w, r, err, http.StatusText(http.StatusUnauthorized), + http.StatusUnauthorized) + return + } + } + + s.generateAndSendToken(w, r, admin, ipAddr) +} + +func (s *httpdServer) generateAndSendToken(w http.ResponseWriter, r *http.Request, admin dataprovider.Admin, ip string) { + c := &jwt.Claims{ + Username: admin.Username, + Permissions: admin.Permissions, + Role: admin.Role, + MustSetTwoFactorAuth: admin.Filters.RequireTwoFactor && !admin.Filters.TOTPConfig.Enabled, + MustChangePassword: admin.Filters.RequirePasswordChange, + } + c.Subject = admin.GetSignature() + + token, err := s.tokenAuth.SignWithParams(c, tokenAudienceAPI, ip, getTokenDuration(tokenAudienceAPI)) + if err != nil { + sendAPIResponse(w, r, err, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + dataprovider.UpdateAdminLastLogin(&admin) + common.DelayLogin(nil) + render.JSON(w, r, c.BuildTokenResponse(token)) +} + +func (s *httpdServer) checkCookieExpiration(w http.ResponseWriter, r *http.Request) { + if _, ok := r.Context().Value(oidcTokenKey).(string); ok { + return + } + claims, err := jwt.FromContext(r.Context()) + if err != nil { + return + } + if claims.Username == "" || claims.Subject == "" { + return + } + if time.Until(claims.Expiry.Time()) > cookieRefreshThreshold { + return + } + if (time.Since(claims.IssuedAt.Time()) + cookieTokenDuration) > maxTokenDuration { + return + } + if claims.Audience.Contains(tokenAudienceWebClient) { + s.refreshClientToken(w, r, claims) + } else { + s.refreshAdminToken(w, r, claims) + } +} + +func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwt.Claims) { + user, err := dataprovider.GetUserWithGroupSettings(tokenClaims.Username, "") + if err != nil { + return + } + if 
 user.GetSignature() != tokenClaims.Subject {
		// The signature changes when credentials change: force a new login.
		logger.Debug(logSender, "", "signature mismatch for user %q, unable to refresh cookie", user.Username)
		return
	}
	if err := user.CheckLoginConditions(); err != nil {
		logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err)
		return
	}
	if err := checkHTTPClientUser(&user, r, xid.New().String(), true, false); err != nil {
		logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err)
		return
	}

	// Refresh permissions/role from the data provider before re-signing.
	tokenClaims.Permissions = user.Filters.WebClient
	tokenClaims.Role = user.Role
	logger.Debug(logSender, "", "cookie refreshed for user %q", user.Username)
	createAndSetCookie(w, r, tokenClaims, s.tokenAuth, tokenAudienceWebClient, util.GetIPFromRemoteAddress(r.RemoteAddr)) //nolint:errcheck
}

// refreshAdminToken re-issues the WebAdmin cookie after re-validating the
// admin's signature and login permissions for the current IP.
func (s *httpdServer) refreshAdminToken(w http.ResponseWriter, r *http.Request, tokenClaims *jwt.Claims) {
	admin, err := dataprovider.AdminExists(tokenClaims.Username)
	if err != nil {
		return
	}
	if admin.GetSignature() != tokenClaims.Subject {
		logger.Debug(logSender, "", "signature mismatch for admin %q, unable to refresh cookie", admin.Username)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := admin.CanLogin(ipAddr); err != nil {
		logger.Debug(logSender, "", "unable to refresh cookie for admin %q, err: %v", admin.Username, err)
		return
	}
	tokenClaims.Permissions = admin.Permissions
	tokenClaims.Role = admin.Role
	tokenClaims.HideUserPageSections = admin.Filters.Preferences.HideUserPageSections
	logger.Debug(logSender, "", "cookie refreshed for admin %q", admin.Username)
	createAndSetCookie(w, r, tokenClaims, s.tokenAuth, tokenAudienceWebAdmin, ipAddr) //nolint:errcheck
}

// updateContextFromCookie ensures a JWT context is available: if the request
// context has no token but a session cookie exists, the cookie is verified and
// the result (token or error) is stored in the context.
func (s *httpdServer) updateContextFromCookie(r *http.Request) *http.Request {
	_, err := jwt.FromContext(r.Context())
	if err != nil {
		_, err = r.Cookie(jwt.CookieKey)
		if err != nil {
			return r
		}
		token, err := jwt.VerifyRequest(s.tokenAuth, r, jwt.TokenFromCookie)
		ctx := jwt.NewContext(r.Context(), token, err)
		return r.WithContext(ctx)
	}
	return r
}

// parseHeaders is a middleware that sets I/O deadlines and the Server header,
// and honors proxy headers (real client IP, X-Forwarded-Proto) only when the
// direct peer is in the allowed proxies list; otherwise those headers are
// stripped so handlers never see spoofed values.
func (s *httpdServer) parseHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		responseControllerDeadlines(
			http.NewResponseController(w),
			time.Now().Add(60*time.Second),
			time.Now().Add(60*time.Second),
		)
		w.Header().Set("Server", version.GetServerVersion("/", false))
		ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
		var ip net.IP
		// An absolute address means a UNIX-domain socket binding.
		isUnixSocket := filepath.IsAbs(s.binding.Address)
		if !isUnixSocket {
			ip = net.ParseIP(ipAddr)
		}
		areHeadersAllowed := false
		if isUnixSocket || ip != nil {
			for _, allow := range s.binding.allowHeadersFrom {
				if allow(ip) {
					parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth)
					if parsedIP != "" {
						ipAddr = parsedIP
						r.RemoteAddr = ipAddr
					}
					if forwardedProto := r.Header.Get(xForwardedProto); forwardedProto != "" {
						ctx := context.WithValue(r.Context(), forwardedProtoKey, forwardedProto)
						r = r.WithContext(ctx)
					}
					areHeadersAllowed = true
					break
				}
			}
		}
		if !areHeadersAllowed {
			for idx := range s.binding.Security.proxyHeaders {
				r.Header.Del(s.binding.Security.proxyHeaders[idx])
			}
		}

		next.ServeHTTP(w, r)
	})
}

// checkConnection is a middleware that tracks the client connection and
// enforces the allow/deny lists, ban list and rate limits for HTTP.
func (s *httpdServer) checkConnection(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
		common.Connections.AddClientConnection(ipAddr)
		defer common.Connections.RemoveClientConnection(ipAddr)

		if err := common.Connections.IsNewConnectionAllowed(ipAddr, common.ProtocolHTTP); err != nil {
			logger.Log(logger.LevelDebug, common.ProtocolHTTP, "", "connection not allowed from ip %q: %v", ipAddr, err)
			s.sendForbiddenResponse(w, r, util.NewI18nError(err, util.I18nErrorConnectionForbidden))
			return
		}
		if
 common.IsBanned(ipAddr, common.ProtocolHTTP) {
			s.sendForbiddenResponse(w, r, util.NewI18nError(
				util.NewGenericError("your IP address is blocked"),
				util.I18nErrorIPForbidden),
			)
			return
		}
		if delay, err := common.LimitRate(common.ProtocolHTTP, ipAddr); err != nil {
			// Round the advertised delay up so Retry-After is never zero seconds.
			delay += 499999999 * time.Nanosecond
			w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
			w.Header().Set("X-Retry-In", delay.String())
			s.sendTooManyRequestResponse(w, r, err)
			return
		}

		next.ServeHTTP(w, r)
	})
}

// sendTooManyRequestResponse renders a 429 page for web requests and a JSON
// 429 response for API requests.
func (s *httpdServer) sendTooManyRequestResponse(w http.ResponseWriter, r *http.Request, err error) {
	if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
		r = s.updateContextFromCookie(r)
		if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
			s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
				util.NewI18nError(errors.New(http.StatusText(http.StatusTooManyRequests)), util.I18nError429Message), "")
			return
		}
		s.renderMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(errors.New(http.StatusText(http.StatusTooManyRequests)), util.I18nError429Message), "")
		return
	}
	sendAPIResponse(w, r, err, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}

// sendForbiddenResponse renders a 403 page for web requests and a JSON 403
// response for API requests.
func (s *httpdServer) sendForbiddenResponse(w http.ResponseWriter, r *http.Request, err error) {
	if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
		r = s.updateContextFromCookie(r)
		if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
			s.renderClientForbiddenPage(w, r, err)
			return
		}
		s.renderForbiddenPage(w, r, err)
		return
	}
	sendAPIResponse(w, r, err, "", http.StatusForbidden)
}

// badHostHandler rejects requests whose Host header (or configured host proxy
// header) is not in the allowed hosts list.
func (s *httpdServer) badHostHandler(w http.ResponseWriter, r *http.Request) {
	host := r.Host
	for _, header := range s.binding.Security.HostsProxyHeaders {
		if h := r.Header.Get(header); h != "" {
			host = h
			break
		}
	}
	logger.Debug(logSender, "", "the host %q is not allowed", host)
	s.sendForbiddenResponse(w, r, util.NewI18nError(
		util.NewGenericError(http.StatusText(http.StatusForbidden)),
		util.I18nErrorConnectionForbidden,
	))
}

// notFoundHandler renders a 404 page for web requests and a JSON 404 response
// for API requests.
func (s *httpdServer) notFoundHandler(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	if (s.enableWebAdmin || s.enableWebClient) && isWebRequest(r) {
		r = s.updateContextFromCookie(r)
		if s.enableWebClient && (isWebClientRequest(r) || !s.enableWebAdmin) {
			s.renderClientNotFoundPage(w, r, nil)
			return
		}
		s.renderNotFoundPage(w, r, nil)
		return
	}
	sendAPIResponse(w, r, nil, http.StatusText(http.StatusNotFound), http.StatusNotFound)
}

// redirectToWebPath sends the client to webPath, or to the initial setup page
// when no admin exists yet and the WebAdmin UI is enabled.
func (s *httpdServer) redirectToWebPath(w http.ResponseWriter, r *http.Request, webPath string) {
	if dataprovider.HasAdmin() {
		http.Redirect(w, r, webPath, http.StatusFound)
		return
	}
	if s.enableWebAdmin {
		http.Redirect(w, r, webAdminSetupPath, http.StatusFound)
	}
}

// The StripSlashes causes infinite redirects at the root path if used with http.FileServer.
// We also don't strip paths with more than one trailing slash, see #1434
func (s *httpdServer) mustStripSlash(r *http.Request) bool {
	urlPath := getURLPath(r)
	return !strings.HasSuffix(urlPath, "//") && !strings.HasPrefix(urlPath, webOpenAPIPath) &&
		!strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI)
}

// mustCheckPath reports whether the connection checks middleware applies to
// this request; static assets and ACME challenges are exempt.
func (s *httpdServer) mustCheckPath(r *http.Request) bool {
	urlPath := getURLPath(r)
	return !strings.HasPrefix(urlPath, webStaticFilesPath) && !strings.HasPrefix(urlPath, acmeChallengeURI)
}

func (s *httpdServer) initializeRouter() error {
	signer, err := jwt.NewSigner(jose.HS256, getSigningKey(s.signingPassphrase))
	if err != nil {
		return err
	}
	csrfSigner, err := jwt.NewSigner(jose.HS256, getSigningKey(s.signingPassphrase))
	if err != nil {
		return err
	}
	var hasHTTPSRedirect bool
	s.tokenAuth = signer
	s.csrfTokenAuth = csrfSigner
	s.router = chi.NewRouter()

	s.router.Use(middleware.RequestID)
	s.router.Use(s.parseHeaders)
	s.router.Use(logger.NewStructuredLogger(logger.GetLogger()))
	s.router.Use(middleware.Recoverer)
	if s.binding.Security.Enabled {
		secureMiddleware := secure.New(secure.Options{
			AllowedHosts:              s.binding.Security.AllowedHosts,
			AllowedHostsAreRegex:      s.binding.Security.AllowedHostsAreRegex,
			HostsProxyHeaders:         s.binding.Security.HostsProxyHeaders,
			SSLProxyHeaders:           s.binding.Security.getHTTPSProxyHeaders(),
			STSSeconds:                s.binding.Security.STSSeconds,
			STSIncludeSubdomains:      s.binding.Security.STSIncludeSubdomains,
			STSPreload:                s.binding.Security.STSPreload,
			ContentTypeNosniff:        s.binding.Security.ContentTypeNosniff,
			ContentSecurityPolicy:     s.binding.Security.ContentSecurityPolicy,
			PermissionsPolicy:         s.binding.Security.PermissionsPolicy,
			CrossOriginOpenerPolicy:   s.binding.Security.CrossOriginOpenerPolicy,
			CrossOriginResourcePolicy: s.binding.Security.CrossOriginResourcePolicy,
			CrossOriginEmbedderPolicy:
s.binding.Security.CrossOriginEmbedderPolicy, + ReferrerPolicy: s.binding.Security.ReferrerPolicy, + }) + secureMiddleware.SetBadHostHandler(http.HandlerFunc(s.badHostHandler)) + if s.binding.Security.CacheControl == "private" { + s.router.Use(cacheControlMiddleware) + } + s.router.Use(secureMiddleware.Handler) + if s.binding.Security.HTTPSRedirect { + s.router.Use(s.binding.Security.redirectHandler) + hasHTTPSRedirect = true + } + } + if s.cors.Enabled { + c := cors.New(cors.Options{ + AllowedOrigins: util.RemoveDuplicates(s.cors.AllowedOrigins, true), + AllowedMethods: util.RemoveDuplicates(s.cors.AllowedMethods, true), + AllowedHeaders: util.RemoveDuplicates(s.cors.AllowedHeaders, true), + ExposedHeaders: util.RemoveDuplicates(s.cors.ExposedHeaders, true), + MaxAge: s.cors.MaxAge, + AllowCredentials: s.cors.AllowCredentials, + OptionsPassthrough: s.cors.OptionsPassthrough, + OptionsSuccessStatus: s.cors.OptionsSuccessStatus, + AllowPrivateNetwork: s.cors.AllowPrivateNetwork, + }) + s.router.Use(c.Handler) + } + s.router.Use(middleware.Maybe(s.checkConnection, s.mustCheckPath)) + s.router.Use(middleware.GetHead) + s.router.Use(middleware.Maybe(middleware.StripSlashes, s.mustStripSlash)) + + s.router.NotFound(s.notFoundHandler) + + s.router.Get(healthzPath, func(w http.ResponseWriter, r *http.Request) { + render.PlainText(w, r, "ok") + }) + + if hasHTTPSRedirect { + if p := acme.GetHTTP01WebRoot(); p != "" { + serveStaticDir(s.router, acmeChallengeURI, p, true) + } + } + + s.setupRESTAPIRoutes() + + if s.enableWebAdmin || s.enableWebClient { + s.router.Group(func(router chi.Router) { + router.Use(cleanCacheControlMiddleware) + router.Use(compressor.Handler) + serveStaticDir(router, webStaticFilesPath, s.staticFilesPath, true) + }) + if s.binding.OIDC.isEnabled() { + s.router.Get(webOIDCRedirectPath, s.handleOIDCRedirect) + } + if s.enableWebClient { + s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, 
maxRequestSize) + s.redirectToWebPath(w, r, webClientLoginPath) + }) + s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + s.redirectToWebPath(w, r, webClientLoginPath) + }) + } else { + s.router.Get(webRootPath, func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + s.redirectToWebPath(w, r, webAdminLoginPath) + }) + s.router.Get(webBasePath, func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + s.redirectToWebPath(w, r, webAdminLoginPath) + }) + } + } + + s.setupWebClientRoutes() + s.setupWebAdminRoutes() + return nil +} + +func (s *httpdServer) setupRESTAPIRoutes() { + if s.enableRESTAPI { + if !s.binding.isAdminTokenEndpointDisabled() { + s.router.Get(tokenPath, s.getToken) + s.router.Post(adminPath+"/{username}/forgot-password", forgotAdminPassword) + s.router.Post(adminPath+"/{username}/reset-password", resetAdminPassword) + } + + s.router.Group(func(router chi.Router) { + router.Use(checkNodeToken(s.tokenAuth)) + if !s.binding.isAdminAPIKeyAuthDisabled() { + router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeAdmin)) + } + router.Use(jwt.Verify(s.tokenAuth, jwt.TokenFromHeader)) + router.Use(jwtAuthenticatorAPI) + + router.Get(versionPath, func(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + render.JSON(w, r, version.Get()) + }) + + router.With(forbidAPIKeyAuthentication).Get(logoutPath, s.logout) + router.With(forbidAPIKeyAuthentication).Get(adminProfilePath, getAdminProfile) + router.With(forbidAPIKeyAuthentication, s.checkAuthRequirements).Put(adminProfilePath, updateAdminProfile) + router.With(forbidAPIKeyAuthentication).Put(adminPwdPath, changeAdminPassword) + // admin TOTP APIs + router.With(forbidAPIKeyAuthentication).Get(adminTOTPConfigsPath, getTOTPConfigs) + 
// NOTE(review): continuation of the admin API group started above.
			router.With(forbidAPIKeyAuthentication).Post(adminTOTPGeneratePath, generateTOTPSecret)
			router.With(forbidAPIKeyAuthentication).Post(adminTOTPValidatePath, validateTOTPPasscode)
			router.With(forbidAPIKeyAuthentication).Post(adminTOTPSavePath, saveTOTPConfig)
			router.With(forbidAPIKeyAuthentication).Get(admin2FARecoveryCodesPath, getRecoveryCodes)
			router.With(forbidAPIKeyAuthentication).Post(admin2FARecoveryCodesPath, generateRecoveryCodes)

			// API key management requires a real login (no API key auth).
			router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)).
				Get(apiKeysPath, getAPIKeys)
			router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)).
				Post(apiKeysPath, addAPIKey)
			router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)).
				Get(apiKeysPath+"/{id}", getAPIKeyByID)
			router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)).
				Put(apiKeysPath+"/{id}", updateAPIKey)
			router.With(forbidAPIKeyAuthentication, s.checkPerms(dataprovider.PermAdminAny)).
				Delete(apiKeysPath+"/{id}", deleteAPIKey)

			// Routes below additionally enforce the configured auth
			// requirements; each one is gated by a specific admin permission.
			router.Group(func(router chi.Router) {
				router.Use(s.checkAuthRequirements)

				router.With(s.checkPerms(dataprovider.PermAdminViewServerStatus)).
					Get(serverStatusPath, func(w http.ResponseWriter, r *http.Request) {
						r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
						render.JSON(w, r, getServicesStatus())
					})

				router.With(s.checkPerms(dataprovider.PermAdminViewConnections)).Get(activeConnectionsPath, getActiveConnections)
				router.With(s.checkPerms(dataprovider.PermAdminCloseConnections)).
					Delete(activeConnectionsPath+"/{connectionID}", handleCloseConnection)
				router.With(s.checkPerms(dataprovider.PermAdminQuotaScans)).Get(quotasBasePath+"/users/scans", getUsersQuotaScans)
				router.With(s.checkPerms(dataprovider.PermAdminQuotaScans)).Post(quotasBasePath+"/users/{username}/scan", startUserQuotaScan)
				router.With(s.checkPerms(dataprovider.PermAdminQuotaScans)).Get(quotasBasePath+"/folders/scans", getFoldersQuotaScans)
				router.With(s.checkPerms(dataprovider.PermAdminQuotaScans)).Post(quotasBasePath+"/folders/{name}/scan", startFolderQuotaScan)
				router.With(s.checkPerms(dataprovider.PermAdminViewUsers)).Get(userPath, getUsers)
				router.With(s.checkPerms(dataprovider.PermAdminAddUsers)).Post(userPath, addUser)
				router.With(s.checkPerms(dataprovider.PermAdminViewUsers)).Get(userPath+"/{username}", getUserByUsername) //nolint:goconst
				router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Put(userPath+"/{username}", updateUser)
				router.With(s.checkPerms(dataprovider.PermAdminDeleteUsers)).Delete(userPath+"/{username}", deleteUser)
				router.With(s.checkPerms(dataprovider.PermAdminDisableMFA)).Put(userPath+"/{username}/2fa/disable", disableUser2FA) //nolint:goconst
				router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Get(folderPath, getFolders)
				router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Get(folderPath+"/{name}", getFolderByName) //nolint:goconst
				router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Post(folderPath, addFolder)
				router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Put(folderPath+"/{name}", updateFolder)
				router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Delete(folderPath+"/{name}", deleteFolder)
				router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Get(groupPath, getGroups)
				router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Get(groupPath+"/{name}", getGroupByName)
				router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Post(groupPath, addGroup)
				router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Put(groupPath+"/{name}", updateGroup)
				router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Delete(groupPath+"/{name}", deleteGroup)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(dumpDataPath, dumpData)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(loadDataPath, loadData)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(loadDataPath, loadDataFromRequest)
				router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/users/{username}/usage",
					updateUserQuotaUsage)
				router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/users/{username}/transfer-usage",
					updateUserTransferQuotaUsage)
				router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Put(quotasBasePath+"/folders/{name}/usage",
					updateFolderQuotaUsage)
				router.With(s.checkPerms(dataprovider.PermAdminViewDefender)).Get(defenderHosts, getDefenderHosts)
				router.With(s.checkPerms(dataprovider.PermAdminViewDefender)).Get(defenderHosts+"/{id}", getDefenderHostByID)
				router.With(s.checkPerms(dataprovider.PermAdminManageDefender)).Delete(defenderHosts+"/{id}", deleteDefenderHostByID)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(adminPath, getAdmins)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(adminPath, addAdmin)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(adminPath+"/{username}", getAdminByUsername)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(adminPath+"/{username}", updateAdmin)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(adminPath+"/{username}", deleteAdmin)
				router.With(s.checkPerms(dataprovider.PermAdminDisableMFA)).Put(adminPath+"/{username}/2fa/disable", disableAdmin2FA)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(retentionChecksPath, getRetentionChecks)
				// Event search results can be large, so they are compressed.
				router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler).
					Get(fsEventsPath, searchFsEvents)
				router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler).
					Get(providerEventsPath, searchProviderEvents)
				router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler).
					Get(logEventsPath, searchLogEvents)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(eventActionsPath, getEventActions)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(eventActionsPath+"/{name}", getEventActionByName)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(eventActionsPath, addEventAction)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(eventActionsPath+"/{name}", updateEventAction)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(eventActionsPath+"/{name}", deleteEventAction)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(eventRulesPath, getEventRules)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(eventRulesPath+"/{name}", getEventRuleByName)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(eventRulesPath, addEventRule)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(eventRulesPath+"/{name}", updateEventRule)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(eventRulesPath+"/{name}", deleteEventRule)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(eventRulesPath+"/run/{name}", runOnDemandRule)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(rolesPath, getRoles)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(rolesPath, addRole)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(rolesPath+"/{name}", getRoleByName)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(rolesPath+"/{name}", updateRole)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(rolesPath+"/{name}", deleteRole)
				router.With(s.checkPerms(dataprovider.PermAdminAny),
					compressor.Handler).Get(ipListsPath+"/{type}", getIPListEntries) //nolint:goconst
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(ipListsPath+"/{type}", addIPListEntry)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(ipListsPath+"/{type}/{ipornet}", getIPListEntry) //nolint:goconst
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Put(ipListsPath+"/{type}/{ipornet}", updateIPListEntry)
				router.With(s.checkPerms(dataprovider.PermAdminAny)).Delete(ipListsPath+"/{type}/{ipornet}", deleteIPListEntry)
			})
		})

		// share API available to external users
		s.router.Get(sharesPath+"/{id}", s.downloadFromShare)
		s.router.Post(sharesPath+"/{id}", s.uploadFilesToShare)
		s.router.Post(sharesPath+"/{id}/{name}", s.uploadFileToShare)
		s.router.With(compressor.Handler).Get(sharesPath+"/{id}/dirs", s.readBrowsableShareContents)
		s.router.Get(sharesPath+"/{id}/files", s.downloadBrowsableSharedFile)

		if !s.binding.isUserTokenEndpointDisabled() {
			s.router.Get(userTokenPath, s.getUserToken)
			s.router.Post(userPath+"/{username}/forgot-password", forgotUserPassword)
			s.router.Post(userPath+"/{username}/reset-password", resetUserPassword)
		}

		// User API group: optional API key auth plus JWT from the
		// Authorization header.
		s.router.Group(func(router chi.Router) {
			if !s.binding.isUserAPIKeyAuthDisabled() {
				router.Use(checkAPIKeyAuth(s.tokenAuth, dataprovider.APIKeyScopeUser))
			}
			router.Use(jwt.Verify(s.tokenAuth, jwt.TokenFromHeader))
			router.Use(jwtAuthenticatorAPIUser)

			router.With(forbidAPIKeyAuthentication).Get(userLogoutPath, s.logout)
			router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientPasswordChangeDisabled)).
				Put(userPwdPath, changeUserPassword)
			router.With(forbidAPIKeyAuthentication).Get(userProfilePath, getUserProfile)
			router.With(forbidAPIKeyAuthentication, s.checkAuthRequirements).Put(userProfilePath, updateUserProfile)
			// user TOTP APIs
			router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)).
				Get(userTOTPConfigsPath, getTOTPConfigs)
			router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)).
				Post(userTOTPGeneratePath, generateTOTPSecret)
			router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)).
				Post(userTOTPValidatePath, validateTOTPPasscode)
			router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)).
				Post(userTOTPSavePath, saveTOTPConfig)
			router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)).
				Get(user2FARecoveryCodesPath, getRecoveryCodes)
			router.With(forbidAPIKeyAuthentication, s.checkHTTPUserPerm(sdk.WebClientMFADisabled)).
				Post(user2FARecoveryCodesPath, generateRecoveryCodes)

			// Filesystem APIs; write operations require the write permission.
			router.With(s.checkAuthRequirements, compressor.Handler).Get(userDirsPath, readUserFolder)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Post(userDirsPath, createUserDir)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Patch(userDirsPath, renameUserFsEntry)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Delete(userDirsPath, deleteUserDir)
			router.With(s.checkAuthRequirements).Get(userFilesPath, getUserFile)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Post(userFilesPath, uploadUserFiles)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Patch(userFilesPath, renameUserFsEntry)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Delete(userFilesPath, deleteUserFile)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Post(userFileActionsPath+"/move", renameUserFsEntry)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Post(userFileActionsPath+"/copy", copyUserFsEntry)
			router.With(s.checkAuthRequirements).Post(userStreamZipPath, getUserFilesAsZipStream)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)).
				Get(userSharesPath, getShares)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)).
				Post(userSharesPath, addShare)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)).
				Get(userSharesPath+"/{id}", getShareByID)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)).
				Put(userSharesPath+"/{id}", updateShare)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)).
				Delete(userSharesPath+"/{id}", deleteShare)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Post(userUploadFilePath, uploadUserFile)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled)).
				Patch(userFilesDirsMetadataPath, setFileDirMetadata)
		})

		// Serve the OpenAPI viewer only when explicitly enabled.
		if s.renderOpenAPI {
			s.router.Group(func(router chi.Router) {
				router.Use(cleanCacheControlMiddleware)
				router.Use(compressor.Handler)
				serveStaticDir(router, webOpenAPIPath, s.openAPIPath, false)
			})
		}
	}
}

// setupWebClientRoutes mounts the web client UI when enabled: branding
// assets, login (form, OIDC and two-factor flows), public share routes and
// the JWT-cookie-protected client area.
func (s *httpdServer) setupWebClientRoutes() {
	if s.enableWebClient {
		s.router.Get(webBaseClientPath, func(w http.ResponseWriter, r *http.Request) {
			r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
			http.Redirect(w, r, webClientLoginPath, http.StatusFound)
		})
		// Branding images come from the (possibly database-backed) config.
		s.router.With(cleanCacheControlMiddleware).Get(path.Join(webStaticFilesPath, "branding/webclient/logo.png"),
			func(w http.ResponseWriter, r *http.Request) {
				r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
				renderPNGImage(w, r, dbBrandingConfig.getWebClientLogo())
			})
		s.router.With(cleanCacheControlMiddleware).Get(path.Join(webStaticFilesPath, "branding/webclient/favicon.png"),
			func(w http.ResponseWriter, r *http.Request) {
				r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
				renderPNGImage(w, r, dbBrandingConfig.getWebClientFavicon())
			})
		s.router.Get(webClientLoginPath, s.handleClientWebLogin)
		if s.binding.OIDC.isEnabled() && !s.binding.isWebClientOIDCLoginDisabled() {
			s.router.Get(webClientOIDCLoginPath, s.handleWebClientOIDCLogin)
		}
		if !s.binding.isWebClientLoginFormDisabled() {
			// Form POSTs are CSRF-protected via the CSRF token cookie.
			s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
				Post(webClientLoginPath, s.handleWebClientLoginPost)
			s.router.Get(webClientForgotPwdPath, s.handleWebClientForgotPwd)
			s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
				Post(webClientForgotPwdPath, s.handleWebClientForgotPwdPost)
			s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
				Get(webClientResetPwdPath, s.handleWebClientPasswordReset)
			s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
// NOTE(review): continuation of the web client login-form routes above.
				Post(webClientResetPwdPath, s.handleWebClientPasswordResetPost)
			// Two-factor pages accept only the partial token issued after the
			// first authentication step.
			s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
				s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
				Get(webClientTwoFactorPath, s.handleWebClientTwoFactor)
			s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
				s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
				Post(webClientTwoFactorPath, s.handleWebClientTwoFactorPost)
			s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
				s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
				Get(webClientTwoFactorRecoveryPath, s.handleWebClientTwoFactorRecovery)
			s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
				s.jwtAuthenticatorPartial(tokenAudienceWebClientPartial)).
				Post(webClientTwoFactorRecoveryPath, s.handleWebClientTwoFactorRecoveryPost)
		}
		// share routes available to external users
		s.router.Get(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginGet)
		s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
			Post(webClientPubSharesPath+"/{id}/login", s.handleClientShareLoginPost)
		s.router.Get(webClientPubSharesPath+"/{id}/logout", s.handleClientShareLogout)
		s.router.Get(webClientPubSharesPath+"/{id}", s.downloadFromShare)
		s.router.Post(webClientPubSharesPath+"/{id}/partial", s.handleClientSharePartialDownload)
		s.router.Get(webClientPubSharesPath+"/{id}/browse", s.handleShareGetFiles)
		s.router.Post(webClientPubSharesPath+"/{id}/browse/exist", s.handleClientShareCheckExist)
		s.router.Get(webClientPubSharesPath+"/{id}/download", s.handleClientSharedFile)
		s.router.Get(webClientPubSharesPath+"/{id}/upload", s.handleClientUploadToShare)
		s.router.With(compressor.Handler).Get(webClientPubSharesPath+"/{id}/dirs", s.handleShareGetDirContents)
		s.router.Post(webClientPubSharesPath+"/{id}", s.uploadFilesToShare)
		s.router.Post(webClientPubSharesPath+"/{id}/{name}", s.uploadFileToShare)
		s.router.Get(webClientPubSharesPath+"/{id}/viewpdf", s.handleShareViewPDF)
		s.router.Get(webClientPubSharesPath+"/{id}/getpdf", s.handleShareGetPDF)

		// Authenticated client area: OIDC (when enabled) or JWT cookie auth.
		s.router.Group(func(router chi.Router) {
			if s.binding.OIDC.isEnabled() {
				router.Use(s.oidcTokenAuthenticator(tokenAudienceWebClient))
			}
			router.Use(jwt.Verify(s.tokenAuth, oidcTokenFromContext, jwt.TokenFromCookie))
			router.Use(jwtAuthenticatorWebClient)

			router.Get(webClientLogoutPath, s.handleWebClientLogout)
			router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientFilesPath, s.handleClientGetFiles)
			router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientViewPDFPath, s.handleClientViewPDF)
			router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientGetPDFPath, s.handleClientGetPDF)
			router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientFilePath, getUserFile)
			router.With(s.checkAuthRequirements, s.refreshCookie, s.verifyCSRFHeader).Get(webClientTasksPath+"/{id}",
				getWebTask)
			router.With(s.checkAuthRequirements,
				s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
				Post(webClientFilePath, uploadUserFile)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
				Post(webClientExistPath, s.handleClientCheckExist)
			router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientEditFilePath, s.handleClientEditFile)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
				Delete(webClientFilesPath, deleteUserFile)
			router.With(s.checkAuthRequirements, compressor.Handler, s.refreshCookie).
				Get(webClientDirsPath, s.handleClientGetDirContents)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
				Post(webClientDirsPath, createUserDir)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
				Delete(webClientDirsPath, taskDeleteDir)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
				Post(webClientFileActionsPath+"/move", taskRenameFsEntry)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientWriteDisabled), s.verifyCSRFHeader).
				Post(webClientFileActionsPath+"/copy", taskCopyFsEntry)
			router.With(s.checkAuthRequirements, s.refreshCookie).
				Post(webClientDownloadZipPath, s.handleWebClientDownloadZip)
			router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientPingPath, handlePingRequest)
			router.With(s.checkAuthRequirements, s.refreshCookie).Get(webClientProfilePath,
				s.handleClientGetProfile)
			router.With(s.checkAuthRequirements).Post(webClientProfilePath, s.handleWebClientProfilePost)
			router.With(s.checkHTTPUserPerm(sdk.WebClientPasswordChangeDisabled)).
				Get(webChangeClientPwdPath, s.handleWebClientChangePwd)
			router.With(s.checkHTTPUserPerm(sdk.WebClientPasswordChangeDisabled)).
				Post(webChangeClientPwdPath, s.handleWebClientChangePwdPost)
			router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.refreshCookie).
				Get(webClientMFAPath, s.handleWebClientMFA)
			router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.refreshCookie).
				Get(webClientMFAPath+"/qrcode", getQRCode)
			router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
				Post(webClientTOTPGeneratePath, generateTOTPSecret)
			router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
				Post(webClientTOTPValidatePath, validateTOTPPasscode)
			router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
				Post(webClientTOTPSavePath, saveTOTPConfig)
			router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader, s.refreshCookie).
				Get(webClientRecoveryCodesPath, getRecoveryCodes)
			router.With(s.checkHTTPUserPerm(sdk.WebClientMFADisabled), s.verifyCSRFHeader).
				Post(webClientRecoveryCodesPath, generateRecoveryCodes)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), compressor.Handler, s.refreshCookie).
				Get(webClientSharesPath+jsonAPISuffix, getAllShares)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.refreshCookie).
				Get(webClientSharesPath, s.handleClientGetShares)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.refreshCookie).
				Get(webClientSharePath, s.handleClientAddShareGet)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)).
				Post(webClientSharePath, s.handleClientAddSharePost)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.refreshCookie).
				Get(webClientSharePath+"/{id}", s.handleClientUpdateShareGet)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled)).
				Post(webClientSharePath+"/{id}", s.handleClientUpdateSharePost)
			router.With(s.checkAuthRequirements, s.checkHTTPUserPerm(sdk.WebClientSharesDisabled), s.verifyCSRFHeader).
				Delete(webClientSharePath+"/{id}", deleteShare)
		})
	}
}

// setupWebAdminRoutes mounts the web admin UI when enabled: branding assets,
// login (form, OIDC, OAuth2 redirect and initial setup) and the
// JWT-cookie-protected admin area. The definition continues past this chunk.
func (s *httpdServer) setupWebAdminRoutes() {
	if s.enableWebAdmin {
		s.router.Get(webBaseAdminPath, func(w http.ResponseWriter, r *http.Request) {
			r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
			s.redirectToWebPath(w, r, webAdminLoginPath)
		})
		s.router.With(cleanCacheControlMiddleware).Get(path.Join(webStaticFilesPath, "branding/webadmin/logo.png"),
			func(w http.ResponseWriter, r *http.Request) {
				r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
				renderPNGImage(w, r, dbBrandingConfig.getWebAdminLogo())
			})
		s.router.With(cleanCacheControlMiddleware).Get(path.Join(webStaticFilesPath, "branding/webadmin/favicon.png"),
			func(w http.ResponseWriter, r *http.Request) {
				r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
				renderPNGImage(w, r, dbBrandingConfig.getWebAdminFavicon())
			})
		s.router.Get(webAdminLoginPath, s.handleWebAdminLogin)
		if s.binding.OIDC.hasRoles() && !s.binding.isWebAdminOIDCLoginDisabled() {
			s.router.Get(webAdminOIDCLoginPath, s.handleWebAdminOIDCLogin)
		}
		s.router.Get(webOAuth2RedirectPath, s.handleOAuth2TokenRedirect)
		s.router.Get(webAdminSetupPath, s.handleWebAdminSetupGet)
		s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
			Post(webAdminSetupPath, s.handleWebAdminSetupPost)
		if !s.binding.isWebAdminLoginFormDisabled() {
			s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)).
				Post(webAdminLoginPath, s.handleWebAdminLoginPost)
			// Two-factor pages accept only the partial admin token.
			s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
				s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
				Get(webAdminTwoFactorPath, s.handleWebAdminTwoFactor)
			s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie),
				s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)).
+ Post(webAdminTwoFactorPath, s.handleWebAdminTwoFactorPost) + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), + s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). + Get(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecovery) + s.router.With(jwt.Verify(s.tokenAuth, jwt.TokenFromCookie), + s.jwtAuthenticatorPartial(tokenAudienceWebAdminPartial)). + Post(webAdminTwoFactorRecoveryPath, s.handleWebAdminTwoFactorRecoveryPost) + s.router.Get(webAdminForgotPwdPath, s.handleWebAdminForgotPwd) + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). + Post(webAdminForgotPwdPath, s.handleWebAdminForgotPwdPost) + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). + Get(webAdminResetPwdPath, s.handleWebAdminPasswordReset) + s.router.With(jwt.Verify(s.csrfTokenAuth, jwt.TokenFromCookie)). + Post(webAdminResetPwdPath, s.handleWebAdminPasswordResetPost) + } + + s.router.Group(func(router chi.Router) { + if s.binding.OIDC.isEnabled() { + router.Use(s.oidcTokenAuthenticator(tokenAudienceWebAdmin)) + } + router.Use(jwt.Verify(s.tokenAuth, oidcTokenFromContext, jwt.TokenFromCookie)) + router.Use(jwtAuthenticatorWebAdmin) + + router.Get(webLogoutPath, s.handleWebAdminLogout) + router.With(s.refreshCookie, s.checkAuthRequirements, s.requireBuiltinLogin).Get( + webAdminProfilePath, s.handleWebAdminProfile) + router.With(s.checkAuthRequirements, s.requireBuiltinLogin).Post(webAdminProfilePath, s.handleWebAdminProfilePost) + router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webChangeAdminPwdPath, s.handleWebAdminChangePwd) + router.With(s.requireBuiltinLogin).Post(webChangeAdminPwdPath, s.handleWebAdminChangePwdPost) + + router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath, s.handleWebAdminMFA) + router.With(s.refreshCookie, s.requireBuiltinLogin).Get(webAdminMFAPath+"/qrcode", getQRCode) + router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPGeneratePath, generateTOTPSecret) + 
router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPValidatePath, validateTOTPPasscode) + router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminTOTPSavePath, saveTOTPConfig) + router.With(s.verifyCSRFHeader, s.requireBuiltinLogin, s.refreshCookie).Get(webAdminRecoveryCodesPath, + getRecoveryCodes) + router.With(s.verifyCSRFHeader, s.requireBuiltinLogin).Post(webAdminRecoveryCodesPath, generateRecoveryCodes) + + router.Group(func(router chi.Router) { + router.Use(s.checkAuthRequirements) + + router.With(s.checkPerms(dataprovider.PermAdminViewUsers), s.refreshCookie). + Get(webUsersPath, s.handleGetWebUsers) + router.With(s.checkPerms(dataprovider.PermAdminViewUsers), compressor.Handler, s.refreshCookie). + Get(webUsersPath+jsonAPISuffix, getAllUsers) + router.With(s.checkPerms(dataprovider.PermAdminAddUsers), s.refreshCookie). + Get(webUserPath, s.handleWebAddUserGet) + router.With(s.checkPerms(dataprovider.PermAdminChangeUsers), s.refreshCookie). + Get(webUserPath+"/{username}", s.handleWebUpdateUserGet) + router.With(s.checkPerms(dataprovider.PermAdminAddUsers)).Post(webUserPath, s.handleWebAddUserPost) + router.With(s.checkPerms(dataprovider.PermAdminChangeUsers)).Post(webUserPath+"/{username}", + s.handleWebUpdateUserPost) + router.With(s.checkPerms(dataprovider.PermAdminManageGroups), s.refreshCookie). + Get(webGroupsPath, s.handleWebGetGroups) + router.With(s.checkPerms(dataprovider.PermAdminManageGroups), compressor.Handler, s.refreshCookie). + Get(webGroupsPath+jsonAPISuffix, getAllGroups) + router.With(s.checkPerms(dataprovider.PermAdminManageGroups), s.refreshCookie). + Get(webGroupPath, s.handleWebAddGroupGet) + router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Post(webGroupPath, s.handleWebAddGroupPost) + router.With(s.checkPerms(dataprovider.PermAdminManageGroups), s.refreshCookie). 
+ Get(webGroupPath+"/{name}", s.handleWebUpdateGroupGet) + router.With(s.checkPerms(dataprovider.PermAdminManageGroups)).Post(webGroupPath+"/{name}", + s.handleWebUpdateGroupPost) + router.With(s.checkPerms(dataprovider.PermAdminManageGroups), s.verifyCSRFHeader). + Delete(webGroupPath+"/{name}", deleteGroup) + router.With(s.checkPerms(dataprovider.PermAdminViewConnections), s.refreshCookie). + Get(webConnectionsPath, s.handleWebGetConnections) + router.With(s.checkPerms(dataprovider.PermAdminViewConnections), s.refreshCookie). + Get(webConnectionsPath+jsonAPISuffix, getActiveConnections) + router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.refreshCookie). + Get(webFoldersPath, s.handleWebGetFolders) + router.With(s.checkPerms(dataprovider.PermAdminManageFolders), compressor.Handler, s.refreshCookie). + Get(webFoldersPath+jsonAPISuffix, getAllFolders) + router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.refreshCookie). + Get(webFolderPath, s.handleWebAddFolderGet) + router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Post(webFolderPath, s.handleWebAddFolderPost) + router.With(s.checkPerms(dataprovider.PermAdminViewServerStatus), s.refreshCookie). + Get(webStatusPath, s.handleWebGetStatus) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). + Get(webAdminsPath, s.handleGetWebAdmins) + router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). + Get(webAdminsPath+jsonAPISuffix, getAllAdmins) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). + Get(webAdminPath, s.handleWebAddAdminGet) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). 
+ Get(webAdminPath+"/{username}", s.handleWebUpdateAdminGet) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminPath, s.handleWebAddAdminPost) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminPath+"/{username}", + s.handleWebUpdateAdminPost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). + Delete(webAdminPath+"/{username}", deleteAdmin) + router.With(s.checkPerms(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader). + Put(webAdminPath+"/{username}/2fa/disable", disableAdmin2FA) + router.With(s.checkPerms(dataprovider.PermAdminCloseConnections), s.verifyCSRFHeader). + Delete(webConnectionsPath+"/{connectionID}", handleCloseConnection) + router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.refreshCookie). + Get(webFolderPath+"/{name}", s.handleWebUpdateFolderGet) + router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Post(webFolderPath+"/{name}", + s.handleWebUpdateFolderPost) + router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.verifyCSRFHeader). + Delete(webFolderPath+"/{name}", deleteFolder) + router.With(s.checkPerms(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader). + Post(webScanVFolderPath+"/{name}", startFolderQuotaScan) + router.With(s.checkPerms(dataprovider.PermAdminDeleteUsers), s.verifyCSRFHeader). + Delete(webUserPath+"/{username}", deleteUser) + router.With(s.checkPerms(dataprovider.PermAdminDisableMFA), s.verifyCSRFHeader). + Put(webUserPath+"/{username}/2fa/disable", disableUser2FA) + router.With(s.checkPerms(dataprovider.PermAdminQuotaScans), s.verifyCSRFHeader). 
+ Post(webQuotaScanPath+"/{username}", startUserQuotaScan) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(webMaintenancePath, s.handleWebMaintenance) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(webBackupPath, dumpData) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webRestorePath, s.handleWebRestore) + router.With(s.checkPerms(dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers), s.refreshCookie). + Get(webTemplateUser, s.handleWebTemplateUserGet) + router.With(s.checkPerms(dataprovider.PermAdminAddUsers, dataprovider.PermAdminChangeUsers)). + Post(webTemplateUser, s.handleWebTemplateUserPost) + router.With(s.checkPerms(dataprovider.PermAdminManageFolders), s.refreshCookie). + Get(webTemplateFolder, s.handleWebTemplateFolderGet) + router.With(s.checkPerms(dataprovider.PermAdminManageFolders)).Post(webTemplateFolder, s.handleWebTemplateFolderPost) + router.With(s.checkPerms(dataprovider.PermAdminViewDefender)).Get(webDefenderPath, s.handleWebDefenderPage) + router.With(s.checkPerms(dataprovider.PermAdminViewDefender)).Get(webDefenderHostsPath, getDefenderHosts) + router.With(s.checkPerms(dataprovider.PermAdminManageDefender), s.verifyCSRFHeader). + Delete(webDefenderHostsPath+"/{id}", deleteDefenderHostByID) + router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). + Get(webAdminEventActionsPath+jsonAPISuffix, getAllActions) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). + Get(webAdminEventActionsPath, s.handleWebGetEventActions) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). + Get(webAdminEventActionPath, s.handleWebAddEventActionGet) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminEventActionPath, + s.handleWebAddEventActionPost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). 
+ Get(webAdminEventActionPath+"/{name}", s.handleWebUpdateEventActionGet) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminEventActionPath+"/{name}", + s.handleWebUpdateEventActionPost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). + Delete(webAdminEventActionPath+"/{name}", deleteEventAction) + router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). + Get(webAdminEventRulesPath+jsonAPISuffix, getAllRules) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). + Get(webAdminEventRulesPath, s.handleWebGetEventRules) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). + Get(webAdminEventRulePath, s.handleWebAddEventRuleGet) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminEventRulePath, + s.handleWebAddEventRulePost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). + Get(webAdminEventRulePath+"/{name}", s.handleWebUpdateEventRuleGet) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminEventRulePath+"/{name}", + s.handleWebUpdateEventRulePost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). + Delete(webAdminEventRulePath+"/{name}", deleteEventRule) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). + Post(webAdminEventRulePath+"/run/{name}", runOnDemandRule) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). + Get(webAdminRolesPath, s.handleWebGetRoles) + router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). + Get(webAdminRolesPath+jsonAPISuffix, getAllRoles) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). + Get(webAdminRolePath, s.handleWebAddRoleGet) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminRolePath, s.handleWebAddRolePost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie). 
+ Get(webAdminRolePath+"/{name}", s.handleWebUpdateRoleGet) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webAdminRolePath+"/{name}", + s.handleWebUpdateRolePost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). + Delete(webAdminRolePath+"/{name}", deleteRole) + router.With(s.checkPerms(dataprovider.PermAdminViewEvents), s.refreshCookie).Get(webEventsPath, + s.handleWebGetEvents) + router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler, s.refreshCookie). + Get(webEventsFsSearchPath, searchFsEvents) + router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler, s.refreshCookie). + Get(webEventsProviderSearchPath, searchProviderEvents) + router.With(s.checkPerms(dataprovider.PermAdminViewEvents), compressor.Handler, s.refreshCookie). + Get(webEventsLogSearchPath, searchLogEvents) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Get(webIPListsPath, s.handleWebIPListsPage) + router.With(s.checkPerms(dataprovider.PermAdminAny), compressor.Handler, s.refreshCookie). + Get(webIPListsPath+"/{type}", getIPListEntries) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie).Get(webIPListPath+"/{type}", + s.handleWebAddIPListEntryGet) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webIPListPath+"/{type}", + s.handleWebAddIPListEntryPost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie).Get(webIPListPath+"/{type}/{ipornet}", + s.handleWebUpdateIPListEntryGet) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webIPListPath+"/{type}/{ipornet}", + s.handleWebUpdateIPListEntryPost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader). 
+ Delete(webIPListPath+"/{type}/{ipornet}", deleteIPListEntry) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.refreshCookie).Get(webConfigsPath, s.handleWebConfigs) + router.With(s.checkPerms(dataprovider.PermAdminAny)).Post(webConfigsPath, s.handleWebConfigsPost) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader, s.refreshCookie). + Post(webConfigsPath+"/smtp/test", testSMTPConfig) + router.With(s.checkPerms(dataprovider.PermAdminAny), s.verifyCSRFHeader, s.refreshCookie). + Post(webOAuth2TokenPath, s.handleSMTPOAuth2TokenRequestPost) + }) + }) + } +} diff --git a/internal/httpd/token.go b/internal/httpd/token.go new file mode 100644 index 00000000..4a0f9873 --- /dev/null +++ b/internal/httpd/token.go @@ -0,0 +1,95 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+
+package httpd
+
+import (
+	"crypto/sha256"
+	"encoding/hex"
+	"sync"
+	"time"
+
+	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+	"github.com/drakkan/sftpgo/v2/internal/util"
+)
+
+func newTokenManager(isShared int) tokenManager { // picks the invalidated-token store: provider-backed when the data provider is shared (isShared == 1), in-memory otherwise
+	if isShared == 1 {
+		logger.Info(logSender, "", "using provider token manager")
+		return &dbTokenManager{}
+	}
+	logger.Info(logSender, "", "using memory token manager")
+	return &memoryTokenManager{}
+}
+
+type tokenManager interface { // store for invalidated (e.g. logged-out) JWTs, kept until they expire
+	Add(token string, expiresAt time.Time) // mark token as invalidated until expiresAt
+	Get(token string) bool                 // report whether token was invalidated
+	Cleanup()                              // drop expired entries
+}
+
+type memoryTokenManager struct { // single-instance implementation backed by a sync.Map
+	invalidatedJWTTokens sync.Map // token string -> expiration time.Time
+}
+
+func (m *memoryTokenManager) Add(token string, expiresAt time.Time) {
+	m.invalidatedJWTTokens.Store(token, expiresAt)
+}
+
+func (m *memoryTokenManager) Get(token string) bool {
+	_, ok := m.invalidatedJWTTokens.Load(token)
+	return ok
+}
+
+func (m *memoryTokenManager) Cleanup() {
+	m.invalidatedJWTTokens.Range(func(key, value any) bool { // delete expired (or malformed) entries, keep iterating
+		exp, ok := value.(time.Time)
+		if !ok || exp.Before(time.Now().UTC()) {
+			m.invalidatedJWTTokens.Delete(key)
+		}
+		return true
+	})
+}
+
+type dbTokenManager struct{} // implementation backed by shared provider sessions, usable across multiple nodes
+
+func (m *dbTokenManager) getKey(token string) string { // hash the token so the session key has a fixed, storage-safe length
+	digest := sha256.Sum256([]byte(token))
+	return hex.EncodeToString(digest[:])
+}
+
+func (m *dbTokenManager) Add(token string, expiresAt time.Time) {
+	key := m.getKey(token)
+	data := map[string]string{
+		"jwt": token,
+	}
+	session := dataprovider.Session{
+		Key:       key,
+		Data:      data,
+		Type:      dataprovider.SessionTypeInvalidToken,
+		Timestamp: util.GetTimeAsMsSinceEpoch(expiresAt), // the stored session expires together with the token
+	}
+	dataprovider.AddSharedSession(session) //nolint:errcheck
+}
+
+func (m *dbTokenManager) Get(token string) bool {
+	key := m.getKey(token)
+	_, err := dataprovider.GetSharedSession(key, dataprovider.SessionTypeInvalidToken)
+	return err == nil
+}
+
+func (m *dbTokenManager) Cleanup() {
+	dataprovider.CleanupSharedSessions(dataprovider.SessionTypeInvalidToken, time.Now()) //nolint:errcheck
+}
diff --git a/internal/httpd/web.go b/internal/httpd/web.go
new file mode 100644
index 00000000..da7e7542
--- /dev/null
+++ b/internal/httpd/web.go
@@ -0,0 +1,164 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package httpd
+
+import (
+	"errors"
+	"net/http"
+	"strings"
+
+	"github.com/go-chi/render"
+	"github.com/unrolled/secure"
+
+	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/internal/version"
+)
+
+const (
+	webDateTimeFormat         = "2006-01-02 15:04:05" // YYYY-MM-DD HH:MM:SS
+	redactedSecret            = "[**redacted**]"
+	csrfFormToken             = "_form_token"
+	csrfHeaderToken           = "X-CSRF-TOKEN"
+	templateCommonDir         = "common"
+	templateTwoFactor         = "twofactor.html"
+	templateTwoFactorRecovery = "twofactor-recovery.html"
+	templateForgotPassword    = "forgot-password.html"
+	templateResetPassword     = "reset-password.html"
+	templateChangePwd         = "changepassword.html"
+	templateMessage           = "message.html"
+	templateCommonBase        = "base.html"
+	templateCommonBaseLogin   = "baselogin.html"
+	templateCommonLogin       = "login.html"
+)
+
+var (
+	errInvalidTokenClaims = errors.New("invalid token claims")
+)
+
+type commonBasePage struct { // fields shared by every rendered web page
+	CSPNonce  string // per-request Content-Security-Policy nonce
+	StaticURL string
+	Version   string
+}
+
+type loginPage struct { // template data for the login page
+	commonBasePage
+	CurrentURL     string
+	Error          *util.I18nError
+	CSRFToken      string
+	AltLoginURL    string
+	AltLoginName   string
+	ForgotPwdURL   string
+	OpenIDLoginURL string
+	Title          string
+	Branding       UIBranding
+	Languages      []string
+	FormDisabled   bool
+	CheckRedirect  bool
+}
+
+type twoFactorPage struct { // template data for the two-factor authentication page
+	commonBasePage
+	CurrentURL    string
+	Error         *util.I18nError
+	CSRFToken     string
+	RecoveryURL   string
+	Title         string
+	Branding      UIBranding
+	Languages     []string
+	CheckRedirect bool
+}
+
+type forgotPwdPage struct { // template data for the forgot-password page
+	commonBasePage
+	CurrentURL    string
+	Error         *util.I18nError
+	CSRFToken     string
+	LoginURL      string
+	Title         string
+	Branding      UIBranding
+	Languages     []string
+	CheckRedirect bool
+}
+
+type resetPwdPage struct { // template data for the reset-password page
+	commonBasePage
+	CurrentURL    string
+	Error         *util.I18nError
+	CSRFToken     string
+	LoginURL      string
+	Title         string
+	Branding      UIBranding
+	Languages     []string
+	CheckRedirect bool
+}
+
+func getSliceFromDelimitedValues(values, delimiter string) []string { // splits values on delimiter, trimming whitespace and dropping empty items
+	result := []string{}
+	for v := range strings.SplitSeq(values, delimiter) {
+		cleaned := strings.TrimSpace(v)
+		if cleaned != "" {
+			result = append(result, cleaned)
+		}
+	}
+	return result
+}
+
+func hasPrefixAndSuffix(key, prefix, suffix string) bool {
+	return strings.HasPrefix(key, prefix) && strings.HasSuffix(key, suffix)
+}
+
+func getCommonBasePage(r *http.Request) commonBasePage { // base page data for the current request
+	return commonBasePage{
+		CSPNonce:  secure.CSPNonce(r.Context()),
+		StaticURL: webStaticFilesPath,
+		Version:   version.GetServerVersion(" ", true),
+	}
+}
+
+func i18nListDirMsg(status int) string { // i18n message key for directory-listing errors, by HTTP status
+	if status == http.StatusForbidden {
+		return util.I18nErrorDirList403
+	}
+	return util.I18nErrorDirListGeneric
+}
+
+func i18nFsMsg(status int) string { // i18n message key for filesystem errors, by HTTP status
+	if status == http.StatusForbidden {
+		return util.I18nError403Message
+	}
+	return util.I18nErrorFsGeneric
+}
+
+func getI18NErrorString(err error, fallback string) string { // message key from err if it wraps an I18nError, fallback otherwise
+	var errI18n *util.I18nError
+	if errors.As(err, &errI18n) {
+		return errI18n.Message
+	}
+	return fallback
+}
+
+func getI18nError(err error) *util.I18nError { // wraps a non-nil err in a generic 500 I18nError; nil in, nil out
+	var errI18n *util.I18nError
+	if err != nil {
+		errI18n = util.NewI18nError(err, util.I18nError500Message)
+	}
+	return errI18n
+}
+
+func handlePingRequest(w http.ResponseWriter, r *http.Request) { // health-check endpoint, replies with plain-text "PONG"
+	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
+	render.PlainText(w, r, "PONG")
+}
diff --git a/internal/httpd/webadmin.go b/internal/httpd/webadmin.go
new file mode 100644
index 00000000..a565c05e
--- /dev/null
+++ b/internal/httpd/webadmin.go
@@ -0,0 +1,4461 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +package httpd + +import ( + "context" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "html/template" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "slices" + "sort" + "strconv" + "strings" + "time" + + "github.com/sftpgo/sdk" + sdkkms "github.com/sftpgo/sdk/kms" + "golang.org/x/oauth2" + + "github.com/drakkan/sftpgo/v2/internal/acme" + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/ftpd" + "github.com/drakkan/sftpgo/v2/internal/jwt" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/smtp" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +type userPageMode int + +const ( + userPageModeAdd userPageMode = iota + 1 + userPageModeUpdate + userPageModeTemplate +) + +type folderPageMode int + +const ( + folderPageModeAdd folderPageMode = iota + 1 + folderPageModeUpdate + folderPageModeTemplate +) + +type genericPageMode int + +const ( + genericPageModeAdd genericPageMode = iota + 1 + genericPageModeUpdate +) + +const ( + templateAdminDir = "webadmin" + templateBase = "base.html" + templateFsConfig = "fsconfig.html" + templateUsers = "users.html" + templateUser = "user.html" + templateAdmins = "admins.html" + templateAdmin = "admin.html" + templateConnections = "connections.html" + templateGroups = "groups.html" + templateGroup = "group.html" + templateFolders = "folders.html" + templateFolder = "folder.html" + templateEventRules = "eventrules.html" + templateEventRule = "eventrule.html" + templateEventActions = "eventactions.html" + templateEventAction = "eventaction.html" + templateRoles = "roles.html" + templateRole = "role.html" + templateEvents = "events.html" + templateStatus 
= "status.html" + templateDefender = "defender.html" + templateIPLists = "iplists.html" + templateIPList = "iplist.html" + templateConfigs = "configs.html" + templateProfile = "profile.html" + templateMaintenance = "maintenance.html" + templateMFA = "mfa.html" + templateSetup = "adminsetup.html" + defaultQueryLimit = 1000 + inversePatternType = "inverse" +) + +var ( + adminTemplates = make(map[string]*template.Template) +) + +type basePage struct { + commonBasePage + Title string + CurrentURL string + UsersURL string + UserURL string + UserTemplateURL string + AdminsURL string + AdminURL string + QuotaScanURL string + ConnectionsURL string + GroupsURL string + GroupURL string + FoldersURL string + FolderURL string + FolderTemplateURL string + DefenderURL string + IPListsURL string + IPListURL string + EventsURL string + ConfigsURL string + LogoutURL string + LoginURL string + ProfileURL string + ChangePwdURL string + MFAURL string + EventRulesURL string + EventRuleURL string + EventActionsURL string + EventActionURL string + RolesURL string + RoleURL string + FolderQuotaScanURL string + StatusURL string + MaintenanceURL string + CSRFToken string + IsEventManagerPage bool + IsIPManagerPage bool + IsServerManagerPage bool + HasDefender bool + HasSearcher bool + HasExternalLogin bool + LoggedUser *dataprovider.Admin + IsLoggedToShare bool + Branding UIBranding + Languages []string +} + +type statusPage struct { + basePage + Status *ServicesStatus +} + +type fsWrapper struct { + vfs.Filesystem + IsUserPage bool + IsGroupPage bool + IsHidden bool + HasUsersBaseDir bool + DirPath string +} + +type userPage struct { + basePage + User *dataprovider.User + RootPerms []string + Error *util.I18nError + ValidPerms []string + ValidLoginMethods []string + ValidProtocols []string + TwoFactorProtocols []string + WebClientOptions []string + RootDirPerms []string + Mode userPageMode + VirtualFolders []vfs.BaseVirtualFolder + Groups []dataprovider.Group + Roles []dataprovider.Role + 
CanImpersonate bool + FsWrapper fsWrapper + CanUseTLSCerts bool +} + +type adminPage struct { + basePage + Admin *dataprovider.Admin + Groups []dataprovider.Group + Roles []dataprovider.Role + Error *util.I18nError + IsAdd bool +} + +type profilePage struct { + basePage + Error *util.I18nError + AllowAPIKeyAuth bool + Email string + Description string +} + +type changePasswordPage struct { + basePage + Error *util.I18nError +} + +type mfaPage struct { + basePage + TOTPConfigs []string + TOTPConfig dataprovider.AdminTOTPConfig + GenerateTOTPURL string + ValidateTOTPURL string + SaveTOTPURL string + RecCodesURL string + RequireTwoFactor bool +} + +type maintenancePage struct { + basePage + BackupPath string + RestorePath string + Error *util.I18nError +} + +type defenderHostsPage struct { + basePage + DefenderHostsURL string +} + +type ipListsPage struct { + basePage + IPListsSearchURL string + RateLimitersStatus bool + RateLimitersProtocols string + IsAllowListEnabled bool +} + +type ipListPage struct { + basePage + Entry *dataprovider.IPListEntry + Error *util.I18nError + Mode genericPageMode +} + +type setupPage struct { + commonBasePage + CurrentURL string + Error *util.I18nError + CSRFToken string + Username string + HasInstallationCode bool + InstallationCodeHint string + HideSupportLink bool + Title string + Branding UIBranding + Languages []string + CheckRedirect bool +} + +type folderPage struct { + basePage + Folder vfs.BaseVirtualFolder + Error *util.I18nError + Mode folderPageMode + FsWrapper fsWrapper +} + +type groupPage struct { + basePage + Group *dataprovider.Group + Error *util.I18nError + Mode genericPageMode + ValidPerms []string + ValidLoginMethods []string + ValidProtocols []string + TwoFactorProtocols []string + WebClientOptions []string + VirtualFolders []vfs.BaseVirtualFolder + FsWrapper fsWrapper +} + +type rolePage struct { + basePage + Role *dataprovider.Role + Error *util.I18nError + Mode genericPageMode +} + +type eventActionPage struct 
{ + basePage + Action dataprovider.BaseEventAction + ActionTypes []dataprovider.EnumMapping + FsActions []dataprovider.EnumMapping + HTTPMethods []string + EnabledCommands []string + RedactedSecret string + Error *util.I18nError + Mode genericPageMode +} + +type eventRulePage struct { + basePage + Rule dataprovider.EventRule + TriggerTypes []dataprovider.EnumMapping + Actions []dataprovider.BaseEventAction + FsEvents []string + Protocols []string + ProviderEvents []string + ProviderObjects []string + Error *util.I18nError + Mode genericPageMode + IsShared bool +} + +type eventsPage struct { + basePage + FsEventsSearchURL string + ProviderEventsSearchURL string + LogEventsSearchURL string +} + +type configsPage struct { + basePage + Configs dataprovider.Configs + ConfigSection int + RedactedSecret string + OAuth2TokenURL string + OAuth2RedirectURL string + WebClientBranding UIBranding + Error *util.I18nError +} + +type messagePage struct { + basePage + Error *util.I18nError + Success string + Text string +} + +type userTemplateFields struct { + Username string + Password string + PublicKeys []string + RequirePwdChange bool +} + +func loadAdminTemplates(templatesPath string) { + usersPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateUsers), + } + userPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateFsConfig), + filepath.Join(templatesPath, templateAdminDir, templateUser), + } + adminsPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateAdmins), + } + adminPaths := []string{ + 
filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateAdmin), + } + profilePaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateProfile), + } + changePwdPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateCommonDir, templateChangePwd), + } + connectionsPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateConnections), + } + messagePaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateCommonDir, templateMessage), + } + foldersPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateFolders), + } + folderPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateFsConfig), + filepath.Join(templatesPath, templateAdminDir, templateFolder), + } + groupsPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateGroups), + } + groupPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, 
templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateFsConfig), + filepath.Join(templatesPath, templateAdminDir, templateGroup), + } + eventRulesPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateEventRules), + } + eventRulePaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateEventRule), + } + eventActionsPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateEventActions), + } + eventActionPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateEventAction), + } + statusPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateStatus), + } + loginPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateCommonLogin), + } + maintenancePaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateMaintenance), + } + defenderPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, 
templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateDefender), + } + ipListsPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateIPLists), + } + ipListPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateIPList), + } + mfaPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateMFA), + } + twoFactorPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateTwoFactor), + } + twoFactorRecoveryPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateTwoFactorRecovery), + } + setupPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateAdminDir, templateSetup), + } + forgotPwdPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateForgotPassword), + } + resetPwdPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, 
templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateResetPassword), + } + rolesPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateRoles), + } + rolePaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateRole), + } + eventsPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateEvents), + } + configsPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateAdminDir, templateBase), + filepath.Join(templatesPath, templateAdminDir, templateConfigs), + } + + fsBaseTpl := template.New("fsBaseTemplate").Funcs(template.FuncMap{ + "HumanizeBytes": util.ByteCountSI, + }) + usersTmpl := util.LoadTemplate(nil, usersPaths...) + userTmpl := util.LoadTemplate(fsBaseTpl, userPaths...) + adminsTmpl := util.LoadTemplate(nil, adminsPaths...) + adminTmpl := util.LoadTemplate(nil, adminPaths...) + connectionsTmpl := util.LoadTemplate(nil, connectionsPaths...) + messageTmpl := util.LoadTemplate(nil, messagePaths...) + groupsTmpl := util.LoadTemplate(nil, groupsPaths...) + groupTmpl := util.LoadTemplate(fsBaseTpl, groupPaths...) + foldersTmpl := util.LoadTemplate(nil, foldersPaths...) + folderTmpl := util.LoadTemplate(fsBaseTpl, folderPaths...) + eventRulesTmpl := util.LoadTemplate(nil, eventRulesPaths...) + eventRuleTmpl := util.LoadTemplate(fsBaseTpl, eventRulePaths...) + eventActionsTmpl := util.LoadTemplate(nil, eventActionsPaths...) + eventActionTmpl := util.LoadTemplate(nil, eventActionPaths...) 
+ statusTmpl := util.LoadTemplate(nil, statusPaths...) + loginTmpl := util.LoadTemplate(nil, loginPaths...) + profileTmpl := util.LoadTemplate(nil, profilePaths...) + changePwdTmpl := util.LoadTemplate(nil, changePwdPaths...) + maintenanceTmpl := util.LoadTemplate(nil, maintenancePaths...) + defenderTmpl := util.LoadTemplate(nil, defenderPaths...) + ipListsTmpl := util.LoadTemplate(nil, ipListsPaths...) + ipListTmpl := util.LoadTemplate(nil, ipListPaths...) + mfaTmpl := util.LoadTemplate(nil, mfaPaths...) + twoFactorTmpl := util.LoadTemplate(nil, twoFactorPaths...) + twoFactorRecoveryTmpl := util.LoadTemplate(nil, twoFactorRecoveryPaths...) + setupTmpl := util.LoadTemplate(nil, setupPaths...) + forgotPwdTmpl := util.LoadTemplate(nil, forgotPwdPaths...) + resetPwdTmpl := util.LoadTemplate(nil, resetPwdPaths...) + rolesTmpl := util.LoadTemplate(nil, rolesPaths...) + roleTmpl := util.LoadTemplate(nil, rolePaths...) + eventsTmpl := util.LoadTemplate(nil, eventsPaths...) + configsTmpl := util.LoadTemplate(nil, configsPaths...) 
	// Register every parsed template in the package-level adminTemplates map,
	// keyed by the template name later passed to renderAdminTemplate.
	adminTemplates[templateUsers] = usersTmpl
	adminTemplates[templateUser] = userTmpl
	adminTemplates[templateAdmins] = adminsTmpl
	adminTemplates[templateAdmin] = adminTmpl
	adminTemplates[templateConnections] = connectionsTmpl
	adminTemplates[templateMessage] = messageTmpl
	adminTemplates[templateGroups] = groupsTmpl
	adminTemplates[templateGroup] = groupTmpl
	adminTemplates[templateFolders] = foldersTmpl
	adminTemplates[templateFolder] = folderTmpl
	adminTemplates[templateEventRules] = eventRulesTmpl
	adminTemplates[templateEventRule] = eventRuleTmpl
	adminTemplates[templateEventActions] = eventActionsTmpl
	adminTemplates[templateEventAction] = eventActionTmpl
	adminTemplates[templateStatus] = statusTmpl
	adminTemplates[templateCommonLogin] = loginTmpl
	adminTemplates[templateProfile] = profileTmpl
	adminTemplates[templateChangePwd] = changePwdTmpl
	adminTemplates[templateMaintenance] = maintenanceTmpl
	adminTemplates[templateDefender] = defenderTmpl
	adminTemplates[templateIPLists] = ipListsTmpl
	adminTemplates[templateIPList] = ipListTmpl
	adminTemplates[templateMFA] = mfaTmpl
	adminTemplates[templateTwoFactor] = twoFactorTmpl
	adminTemplates[templateTwoFactorRecovery] = twoFactorRecoveryTmpl
	adminTemplates[templateSetup] = setupTmpl
	adminTemplates[templateForgotPassword] = forgotPwdTmpl
	adminTemplates[templateResetPassword] = resetPwdTmpl
	adminTemplates[templateRoles] = rolesTmpl
	adminTemplates[templateRole] = roleTmpl
	adminTemplates[templateEvents] = eventsTmpl
	adminTemplates[templateConfigs] = configsTmpl
}

// isEventManagerResource reports whether currentURL belongs to the event
// manager area: the rules/actions list pages or a rule/action add/edit
// sub-page. Consumed by getBasePageData to set IsEventManagerPage.
func isEventManagerResource(currentURL string) bool {
	if currentURL == webAdminEventRulesPath {
		return true
	}
	if currentURL == webAdminEventActionsPath {
		return true
	}
	// The exact path matches the "add" page, the "/"-prefixed variants
	// match the edit pages for a specific rule/action.
	if currentURL == webAdminEventRulePath || strings.HasPrefix(currentURL, webAdminEventRulePath+"/") {
		return true
	}
	if currentURL == webAdminEventActionPath || strings.HasPrefix(currentURL, webAdminEventActionPath+"/") {
		return true
	}
	return false
}

// isIPListsResource reports whether currentURL belongs to the IP management
// area (defender page, IP lists page, or a specific IP list sub-page).
// Consumed by getBasePageData to set IsIPManagerPage.
func isIPListsResource(currentURL string) bool {
	if currentURL == webDefenderPath {
		return true
	}
	if currentURL == webIPListsPath {
		return true
	}
	if strings.HasPrefix(currentURL, webIPListPath+"/") {
		return true
	}
	return false
}

// isServerManagerResource reports whether currentURL is one of the server
// manager pages (events, status, maintenance, configs).
func isServerManagerResource(currentURL string) bool {
	return currentURL == webEventsPath || currentURL == webStatusPath || currentURL == webMaintenancePath ||
		currentURL == webConfigsPath
}

// getBasePageData builds the basePage struct shared by all admin pages:
// page title, current URL, every navigation URL, the logged-in admin taken
// from the request token, and section flags derived from currentURL.
// A CSRF token is only generated when currentURL is set (i.e. for pages
// that render a form posting back to a known URL).
func (s *httpdServer) getBasePageData(title, currentURL string, w http.ResponseWriter, r *http.Request) basePage {
	var csrfToken string
	if currentURL != "" {
		csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath)
	}
	return basePage{
		commonBasePage:     getCommonBasePage(r),
		Title:              title,
		CurrentURL:         currentURL,
		UsersURL:           webUsersPath,
		UserURL:            webUserPath,
		UserTemplateURL:    webTemplateUser,
		AdminsURL:          webAdminsPath,
		AdminURL:           webAdminPath,
		GroupsURL:          webGroupsPath,
		GroupURL:           webGroupPath,
		FoldersURL:         webFoldersPath,
		FolderURL:          webFolderPath,
		FolderTemplateURL:  webTemplateFolder,
		DefenderURL:        webDefenderPath,
		IPListsURL:         webIPListsPath,
		IPListURL:          webIPListPath,
		EventsURL:          webEventsPath,
		ConfigsURL:         webConfigsPath,
		LogoutURL:          webLogoutPath,
		LoginURL:           webAdminLoginPath,
		ProfileURL:         webAdminProfilePath,
		ChangePwdURL:       webChangeAdminPwdPath,
		MFAURL:             webAdminMFAPath,
		EventRulesURL:      webAdminEventRulesPath,
		EventRuleURL:       webAdminEventRulePath,
		EventActionsURL:    webAdminEventActionsPath,
		EventActionURL:     webAdminEventActionPath,
		RolesURL:           webAdminRolesPath,
		RoleURL:            webAdminRolePath,
		QuotaScanURL:       webQuotaScanPath,
		ConnectionsURL:     webConnectionsPath,
		StatusURL:          webStatusPath,
		FolderQuotaScanURL: webScanVFolderPath,
		MaintenanceURL:     webMaintenancePath,
		LoggedUser:         getAdminFromToken(r),
		IsEventManagerPage: isEventManagerResource(currentURL),
		IsIPManagerPage:    isIPListsResource(currentURL),
		IsServerManagerPage: isServerManagerResource(currentURL),
		HasDefender:         common.Config.DefenderConfig.Enabled,
		HasSearcher:         plugin.Handler.HasSearcher(),
		HasExternalLogin:    isLoggedInWithOIDC(r),
		CSRFToken:           csrfToken,
		Branding:            s.binding.webAdminBranding(),
		Languages:           s.binding.languages(),
	}
}

// renderAdminTemplate executes the named admin template with the given data.
// On execution failure it replies with a plain 500; note the template may
// already have written part of the response at that point.
func renderAdminTemplate(w http.ResponseWriter, tmplName string, data any) {
	err := adminTemplates[tmplName].ExecuteTemplate(w, tmplName, data)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}

// renderMessagePageWithString renders the generic message page with the given
// HTTP status, an optional error (converted to an i18n error), a success
// message and free-form text. An empty currentURL means no CSRF token is
// generated for this page.
func (s *httpdServer) renderMessagePageWithString(w http.ResponseWriter, r *http.Request, title string, statusCode int,
	err error, message, text string,
) {
	data := messagePage{
		basePage: s.getBasePageData(title, "", w, r),
		Error:    getI18nError(err),
		Success:  message,
		Text:     text,
	}
	w.WriteHeader(statusCode)
	renderAdminTemplate(w, templateMessage, data)
}

// renderMessagePage is renderMessagePageWithString without the free-form text.
func (s *httpdServer) renderMessagePage(w http.ResponseWriter, r *http.Request, title string, statusCode int,
	err error, message string,
) {
	s.renderMessagePageWithString(w, r, title, statusCode, err, message, "")
}

// renderInternalServerErrorPage renders the message page as a 500 error.
func (s *httpdServer) renderInternalServerErrorPage(w http.ResponseWriter, r *http.Request, err error) {
	s.renderMessagePage(w, r, util.I18nError500Title, http.StatusInternalServerError,
		util.NewI18nError(err, util.I18nError500Message), "")
}

// renderBadRequestPage renders the message page as a 400 error.
func (s *httpdServer) renderBadRequestPage(w http.ResponseWriter, r *http.Request, err error) {
	s.renderMessagePage(w, r, util.I18nError400Title, http.StatusBadRequest,
		util.NewI18nError(err, util.I18nError400Message), "")
}

// renderForbiddenPage renders the message page as a 403 error.
func (s *httpdServer) renderForbiddenPage(w http.ResponseWriter, r *http.Request, err error) {
	s.renderMessagePage(w, r, util.I18nError403Title, http.StatusForbidden,
		util.NewI18nError(err, util.I18nError403Message), "")
}

// renderNotFoundPage renders the message page as a 404 error.
func (s *httpdServer) renderNotFoundPage(w http.ResponseWriter, r *http.Request, err error) {
	s.renderMessagePage(w, r, util.I18nError404Title, http.StatusNotFound,
		util.NewI18nError(err, util.I18nError404Message), "")
}

// renderForgotPwdPage renders the "forgot password" page. This page is shown
// before login, so it builds commonBasePage directly instead of basePage and
// seeds the CSRF token with a random value.
func (s *httpdServer) renderForgotPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	data := forgotPwdPage{
		commonBasePage: getCommonBasePage(r),
		CurrentURL:     webAdminForgotPwdPath,
		Error:          err,
		CSRFToken:      createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath),
		LoginURL:       webAdminLoginPath,
		Title:          util.I18nForgotPwdTitle,
		Branding:       s.binding.webAdminBranding(),
		Languages:      s.binding.languages(),
	}
	renderAdminTemplate(w, templateForgotPassword, data)
}

// renderResetPwdPage renders the password reset page (pre-login page).
func (s *httpdServer) renderResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	data := resetPwdPage{
		commonBasePage: getCommonBasePage(r),
		CurrentURL:     webAdminResetPwdPath,
		Error:          err,
		CSRFToken:      createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
		LoginURL:       webAdminLoginPath,
		Title:          util.I18nResetPwdTitle,
		Branding:       s.binding.webAdminBranding(),
		Languages:      s.binding.languages(),
	}
	renderAdminTemplate(w, templateResetPassword, data)
}

// renderTwoFactorPage renders the second-factor (TOTP) login page.
func (s *httpdServer) renderTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	data := twoFactorPage{
		commonBasePage: getCommonBasePage(r),
		Title:          util.I18n2FATitle,
		CurrentURL:     webAdminTwoFactorPath,
		Error:          err,
		CSRFToken:      createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
		RecoveryURL:    webAdminTwoFactorRecoveryPath,
		Branding:       s.binding.webAdminBranding(),
		Languages:      s.binding.languages(),
	}
	renderAdminTemplate(w, templateTwoFactor, data)
}

// renderTwoFactorRecoveryPage renders the recovery-code login page.
func (s *httpdServer) renderTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	data := twoFactorPage{
		commonBasePage: getCommonBasePage(r),
		Title:          util.I18n2FATitle,
		CurrentURL:     webAdminTwoFactorRecoveryPath,
		Error:          err,
		CSRFToken:      createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseAdminPath),
		Branding:
		s.binding.webAdminBranding(),
		Languages: s.binding.languages(),
	}
	renderAdminTemplate(w, templateTwoFactorRecovery, data)
}

// renderMFAPage renders the two-factor configuration page for the logged-in
// admin. The admin record is re-read from the provider to get the current
// TOTP configuration; on lookup failure a 500 page is rendered instead.
func (s *httpdServer) renderMFAPage(w http.ResponseWriter, r *http.Request) {
	data := mfaPage{
		basePage:        s.getBasePageData(util.I18n2FATitle, webAdminMFAPath, w, r),
		TOTPConfigs:     mfa.GetAvailableTOTPConfigNames(),
		GenerateTOTPURL: webAdminTOTPGeneratePath,
		ValidateTOTPURL: webAdminTOTPValidatePath,
		SaveTOTPURL:     webAdminTOTPSavePath,
		RecCodesURL:     webAdminRecoveryCodesPath,
	}
	admin, err := dataprovider.AdminExists(data.LoggedUser.Username)
	if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	data.TOTPConfig = admin.Filters.TOTPConfig
	data.RequireTwoFactor = admin.Filters.RequireTwoFactor
	renderAdminTemplate(w, templateMFA, data)
}

// renderProfilePage renders the profile page for the logged-in admin,
// re-reading the admin record from the provider for the editable fields.
func (s *httpdServer) renderProfilePage(w http.ResponseWriter, r *http.Request, err error) {
	data := profilePage{
		basePage: s.getBasePageData(util.I18nProfileTitle, webAdminProfilePath, w, r),
		Error:    getI18nError(err),
	}
	admin, err := dataprovider.AdminExists(data.LoggedUser.Username)
	if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	data.AllowAPIKeyAuth = admin.Filters.AllowAPIKeyAuth
	data.Email = admin.Email
	data.Description = admin.Description

	renderAdminTemplate(w, templateProfile, data)
}

// renderChangePasswordPage renders the admin change-password form.
func (s *httpdServer) renderChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	data := changePasswordPage{
		basePage: s.getBasePageData(util.I18nChangePwdTitle, webChangeAdminPwdPath, w, r),
		Error:    err,
	}

	renderAdminTemplate(w, templateChangePwd, data)
}

// renderMaintenancePage renders the backup/restore page.
func (s *httpdServer) renderMaintenancePage(w http.ResponseWriter, r *http.Request, err error) {
	data := maintenancePage{
		basePage:    s.getBasePageData(util.I18nMaintenanceTitle, webMaintenancePath, w, r),
		BackupPath:  webBackupPath,
		RestorePath: webRestorePath,
		Error:       getI18nError(err),
	}

	renderAdminTemplate(w, templateMaintenance, data)
}

// renderConfigsPage renders the server configuration page for the given
// section, applying display defaults (SMTP port 587 with its auth/encryption
// defaults, ACME HTTP-01 port 80) when the stored values are unset.
func (s *httpdServer) renderConfigsPage(w http.ResponseWriter, r *http.Request, configs dataprovider.Configs,
	err error, section int,
) {
	configs.SetNilsToEmpty()
	if configs.SMTP.Port == 0 {
		configs.SMTP.Port = 587
		configs.SMTP.AuthType = 1
		configs.SMTP.Encryption = 2
	}
	if configs.ACME.HTTP01Challenge.Port == 0 {
		configs.ACME.HTTP01Challenge.Port = 80
	}
	data := configsPage{
		basePage:          s.getBasePageData(util.I18nConfigsTitle, webConfigsPath, w, r),
		Configs:           configs,
		ConfigSection:     section,
		RedactedSecret:    redactedSecret,
		OAuth2TokenURL:    webOAuth2TokenPath,
		OAuth2RedirectURL: webOAuth2RedirectPath,
		WebClientBranding: s.binding.webClientBranding(),
		Error:             getI18nError(err),
	}

	renderAdminTemplate(w, templateConfigs, data)
}

// renderAdminSetupPage renders the initial setup page shown when no admin
// exists yet (pre-login page, random CSRF token seed).
func (s *httpdServer) renderAdminSetupPage(w http.ResponseWriter, r *http.Request, username string, err *util.I18nError) {
	data := setupPage{
		commonBasePage:       getCommonBasePage(r),
		Title:                util.I18nSetupTitle,
		CurrentURL:           webAdminSetupPath,
		CSRFToken:            createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseAdminPath),
		Username:             username,
		HasInstallationCode:  installationCode != "",
		InstallationCodeHint: installationCodeHint,
		HideSupportLink:      hideSupportLink,
		Error:                err,
		Branding:             s.binding.webAdminBranding(),
		Languages:            s.binding.languages(),
	}

	renderAdminTemplate(w, templateSetup, data)
}

// renderAddUpdateAdminPage renders the add/update admin form. Group and role
// lists are loaded first; their helpers render an error page themselves on
// failure, so we just return. For updates the current URL embeds the
// (path-escaped) username.
func (s *httpdServer) renderAddUpdateAdminPage(w http.ResponseWriter, r *http.Request, admin *dataprovider.Admin,
	err error, isAdd bool) {
	groups, errGroups := s.getWebGroups(w, r, defaultQueryLimit, true)
	if errGroups != nil {
		return
	}
	roles, errRoles := s.getWebRoles(w, r, 10, true)
	if errRoles != nil {
		return
	}
	currentURL := webAdminPath
	title := util.I18nAddAdminTitle
	if !isAdd {
		currentURL = fmt.Sprintf("%v/%v", webAdminPath, url.PathEscape(admin.Username))
		title = util.I18nUpdateAdminTitle
	}
	data := adminPage{
		basePage: s.getBasePageData(title, currentURL, w, r),
		Admin:    admin,
		Groups:   groups,
		Roles:    roles,
		Error:    getI18nError(err),
		IsAdd:    isAdd,
	}

	renderAdminTemplate(w, templateAdmin, data)
}

// getUserPageTitleAndURL maps a user page mode (add/update/template) to the
// page title and the current URL; for updates the URL embeds the
// path-escaped username.
func (s *httpdServer) getUserPageTitleAndURL(mode userPageMode, username string) (string, string) {
	var title, currentURL string
	switch mode {
	case userPageModeAdd:
		title = util.I18nAddUserTitle
		currentURL = webUserPath
	case userPageModeUpdate:
		title = util.I18nUpdateUserTitle
		currentURL = fmt.Sprintf("%v/%v", webUserPath, url.PathEscape(username))
	case userPageModeTemplate:
		title = util.I18nTemplateUserTitle
		currentURL = webTemplateUser
	}
	return title, currentURL
}

// renderUserPage renders the add/update/template user form. Hashed passwords
// are redacted (update) or cleared (add/template); for add/template with no
// groups the logged-in admin's default groups are pre-applied. Roles are only
// loaded for admins not bound to a role. The list helpers render their own
// error page on failure, hence the bare returns.
func (s *httpdServer) renderUserPage(w http.ResponseWriter, r *http.Request, user *dataprovider.User,
	mode userPageMode, err error, admin *dataprovider.Admin,
) {
	user.SetEmptySecretsIfNil()
	title, currentURL := s.getUserPageTitleAndURL(mode, user.Username)
	if user.Password != "" && user.IsPasswordHashed() {
		switch mode {
		case userPageModeUpdate:
			// keep a placeholder so the form shows a password is set
			user.Password = redactedSecret
		default:
			user.Password = ""
		}
	}
	user.FsConfig.RedactedSecret = redactedSecret
	basePage := s.getBasePageData(title, currentURL, w, r)
	if (mode == userPageModeAdd || mode == userPageModeTemplate) && len(user.Groups) == 0 && admin != nil {
		for _, group := range admin.Groups {
			user.Groups = append(user.Groups, sdk.GroupMapping{
				Name: group.Name,
				Type: group.Options.GetUserGroupType(),
			})
		}
	}
	var roles []dataprovider.Role
	if basePage.LoggedUser.Role == "" {
		var errRoles error
		roles, errRoles = s.getWebRoles(w, r, 10, true)
		if errRoles != nil {
			return
		}
	}
	folders, errFolders := s.getWebVirtualFolders(w, r, defaultQueryLimit, true)
	if errFolders != nil {
		return
	}
	groups, errGroups := s.getWebGroups(w, r, defaultQueryLimit, true)
	if errGroups != nil {
		return
	}
	data := userPage{
		basePage:           basePage,
		Mode:               mode,
		Error:              getI18nError(err),
		User:               user,
		ValidPerms:         dataprovider.ValidPerms,
		ValidLoginMethods:  dataprovider.ValidLoginMethods,
		ValidProtocols:     dataprovider.ValidProtocols,
		TwoFactorProtocols: dataprovider.MFAProtocols,
		WebClientOptions:   sdk.WebClientOptions,
		RootDirPerms:       user.GetPermissionsForPath("/"),
		VirtualFolders:     folders,
		Groups:             groups,
		Roles:              roles,
		// impersonation requires running as root
		CanImpersonate: os.Getuid() == 0,
		CanUseTLSCerts: ftpd.GetStatus().IsActive || webdavd.GetStatus().IsActive,
		FsWrapper: fsWrapper{
			Filesystem:      user.FsConfig,
			IsUserPage:      true,
			IsGroupPage:     false,
			IsHidden:        basePage.LoggedUser.Filters.Preferences.HideFilesystem(),
			HasUsersBaseDir: dataprovider.HasUsersBaseDir(),
			DirPath:         user.HomeDir,
		},
	}
	renderAdminTemplate(w, templateUser, data)
}

// renderIPListPage renders the add/update form for an IP list entry. The
// current URL embeds the list type and, for updates, the escaped IP/network.
func (s *httpdServer) renderIPListPage(w http.ResponseWriter, r *http.Request, entry dataprovider.IPListEntry,
	mode genericPageMode, err error,
) {
	var title, currentURL string
	switch mode {
	case genericPageModeAdd:
		title = util.I18nAddIPListTitle
		currentURL = fmt.Sprintf("%s/%d", webIPListPath, entry.Type)
	case genericPageModeUpdate:
		title = util.I18nUpdateIPListTitle
		currentURL = fmt.Sprintf("%s/%d/%s", webIPListPath, entry.Type, url.PathEscape(entry.IPOrNet))
	}
	data := ipListPage{
		basePage: s.getBasePageData(title, currentURL, w, r),
		Error:    getI18nError(err),
		Entry:    &entry,
		Mode:     mode,
	}
	renderAdminTemplate(w, templateIPList, data)
}

// renderRolePage renders the add/update form for a role.
func (s *httpdServer) renderRolePage(w http.ResponseWriter, r *http.Request, role dataprovider.Role,
	mode genericPageMode, err error,
) {
	var title, currentURL string
	switch mode {
	case genericPageModeAdd:
		title = util.I18nRoleAddTitle
		currentURL = webAdminRolePath
	case genericPageModeUpdate:
		title = util.I18nRoleUpdateTitle
		currentURL = fmt.Sprintf("%s/%s", webAdminRolePath, url.PathEscape(role.Name))
	}
	data := rolePage{
		basePage: s.getBasePageData(title, currentURL, w, r),
		Error:
getI18nError(err), + Role: &role, + Mode: mode, + } + renderAdminTemplate(w, templateRole, data) +} + +func (s *httpdServer) renderGroupPage(w http.ResponseWriter, r *http.Request, group dataprovider.Group, + mode genericPageMode, err error, +) { + folders, errFolders := s.getWebVirtualFolders(w, r, defaultQueryLimit, true) + if errFolders != nil { + return + } + group.SetEmptySecretsIfNil() + group.UserSettings.FsConfig.RedactedSecret = redactedSecret + var title, currentURL string + switch mode { + case genericPageModeAdd: + title = util.I18nAddGroupTitle + currentURL = webGroupPath + case genericPageModeUpdate: + title = util.I18nUpdateGroupTitle + currentURL = fmt.Sprintf("%v/%v", webGroupPath, url.PathEscape(group.Name)) + } + group.UserSettings.FsConfig.RedactedSecret = redactedSecret + group.UserSettings.FsConfig.SetEmptySecretsIfNil() + + data := groupPage{ + basePage: s.getBasePageData(title, currentURL, w, r), + Error: getI18nError(err), + Group: &group, + Mode: mode, + ValidPerms: dataprovider.ValidPerms, + ValidLoginMethods: dataprovider.ValidLoginMethods, + ValidProtocols: dataprovider.ValidProtocols, + TwoFactorProtocols: dataprovider.MFAProtocols, + WebClientOptions: sdk.WebClientOptions, + VirtualFolders: folders, + FsWrapper: fsWrapper{ + Filesystem: group.UserSettings.FsConfig, + IsUserPage: false, + IsGroupPage: true, + HasUsersBaseDir: false, + DirPath: group.UserSettings.HomeDir, + }, + } + renderAdminTemplate(w, templateGroup, data) +} + +func (s *httpdServer) renderEventActionPage(w http.ResponseWriter, r *http.Request, action dataprovider.BaseEventAction, + mode genericPageMode, err error, +) { + action.Options.SetEmptySecretsIfNil() + var title, currentURL string + switch mode { + case genericPageModeAdd: + title = util.I18nAddActionTitle + currentURL = webAdminEventActionPath + case genericPageModeUpdate: + title = util.I18nUpdateActionTitle + currentURL = fmt.Sprintf("%s/%s", webAdminEventActionPath, url.PathEscape(action.Name)) + } + if 
action.Options.HTTPConfig.Timeout == 0 { + action.Options.HTTPConfig.Timeout = 20 + } + if action.Options.CmdConfig.Timeout == 0 { + action.Options.CmdConfig.Timeout = 20 + } + if action.Options.PwdExpirationConfig.Threshold == 0 { + action.Options.PwdExpirationConfig.Threshold = 10 + } + + data := eventActionPage{ + basePage: s.getBasePageData(title, currentURL, w, r), + Action: action, + ActionTypes: dataprovider.EventActionTypes, + FsActions: dataprovider.FsActionTypes, + HTTPMethods: dataprovider.SupportedHTTPActionMethods, + EnabledCommands: dataprovider.EnabledActionCommands, + RedactedSecret: redactedSecret, + Error: getI18nError(err), + Mode: mode, + } + renderAdminTemplate(w, templateEventAction, data) +} + +func (s *httpdServer) renderEventRulePage(w http.ResponseWriter, r *http.Request, rule dataprovider.EventRule, + mode genericPageMode, err error, +) { + actions, errActions := s.getWebEventActions(w, r, defaultQueryLimit, true) + if errActions != nil { + return + } + var title, currentURL string + switch mode { + case genericPageModeAdd: + title = util.I18nAddRuleTitle + currentURL = webAdminEventRulePath + case genericPageModeUpdate: + title = util.I18nUpdateRuleTitle + currentURL = fmt.Sprintf("%v/%v", webAdminEventRulePath, url.PathEscape(rule.Name)) + } + + data := eventRulePage{ + basePage: s.getBasePageData(title, currentURL, w, r), + Rule: rule, + TriggerTypes: dataprovider.EventTriggerTypes, + Actions: actions, + FsEvents: dataprovider.SupportedFsEvents, + Protocols: dataprovider.SupportedRuleConditionProtocols, + ProviderEvents: dataprovider.SupportedProviderEvents, + ProviderObjects: dataprovider.SupporteRuleConditionProviderObjects, + Error: getI18nError(err), + Mode: mode, + IsShared: s.isShared > 0, + } + renderAdminTemplate(w, templateEventRule, data) +} + +func (s *httpdServer) renderFolderPage(w http.ResponseWriter, r *http.Request, folder vfs.BaseVirtualFolder, + mode folderPageMode, err error, +) { + var title, currentURL string + 
	switch mode {
	case folderPageModeAdd:
		title = util.I18nAddFolderTitle
		currentURL = webFolderPath
	case folderPageModeUpdate:
		title = util.I18nUpdateFolderTitle
		currentURL = fmt.Sprintf("%v/%v", webFolderPath, url.PathEscape(folder.Name))
	case folderPageModeTemplate:
		title = util.I18nTemplateFolderTitle
		currentURL = webTemplateFolder
	}
	folder.FsConfig.RedactedSecret = redactedSecret
	folder.FsConfig.SetEmptySecretsIfNil()

	data := folderPage{
		basePage: s.getBasePageData(title, currentURL, w, r),
		Error:    getI18nError(err),
		Folder:   folder,
		Mode:     mode,
		FsWrapper: fsWrapper{
			Filesystem:      folder.FsConfig,
			IsUserPage:      false,
			IsGroupPage:     false,
			HasUsersBaseDir: false,
			DirPath:         folder.MappedPath,
		},
	}
	renderAdminTemplate(w, templateFolder, data)
}

// getFoldersForTemplate collects the unique, non-empty folder names submitted
// with the template form. Names may arrive either as repeated "tpl_foldername"
// values or as "template_folders[...][tpl_foldername]" keys; the latter are
// copied into "tpl_foldername" first. Order of first appearance is preserved.
func getFoldersForTemplate(r *http.Request) []string {
	var res []string
	for k := range r.Form {
		if hasPrefixAndSuffix(k, "template_folders[", "][tpl_foldername]") {
			r.Form.Add("tpl_foldername", r.Form.Get(k))
		}
	}
	folderNames := r.Form["tpl_foldername"]
	folders := make(map[string]bool)
	for _, name := range folderNames {
		name = strings.TrimSpace(name)
		if name == "" {
			continue
		}
		if _, ok := folders[name]; ok {
			continue
		}
		folders[name] = true
		res = append(res, name)
	}
	return res
}

// getUsersForTemplate builds the per-user template fields from the parallel
// "tpl_username"/"tpl_password"/"tpl_public_keys" form slices. Entries with an
// empty or duplicate username are skipped; password/key default to empty when
// their slice is shorter than the username slice.
func getUsersForTemplate(r *http.Request) []userTemplateFields {
	var res []userTemplateFields
	tplUsernames := r.Form["tpl_username"]
	tplPasswords := r.Form["tpl_password"]
	tplPublicKeys := r.Form["tpl_public_keys"]

	users := make(map[string]bool)
	for idx := range tplUsernames {
		username := tplUsernames[idx]
		password := ""
		publicKey := ""
		if len(tplPasswords) > idx {
			password = strings.TrimSpace(tplPasswords[idx])
		}
		if len(tplPublicKeys) > idx {
			publicKey = strings.TrimSpace(tplPublicKeys[idx])
		}
		if username == "" {
			continue
		}
		if _, ok := users[username]; ok {
			continue
		}

		users[username] = true
		res = append(res, userTemplateFields{
			Username:         username,
			Password:         password,
			PublicKeys:       []string{publicKey},
			RequirePwdChange: r.Form.Get("tpl_require_password_change") != "",
		})
	}

	return res
}

// getVirtualFoldersFromPostFields builds virtual folder mappings from the
// parallel "vfolder_*" form slices. Entries need both a path and a name;
// quota values default to -1 (unlimited/unset) and are only overridden when
// they parse successfully.
func getVirtualFoldersFromPostFields(r *http.Request) []vfs.VirtualFolder {
	var virtualFolders []vfs.VirtualFolder
	folderPaths := r.Form["vfolder_path"]
	folderNames := r.Form["vfolder_name"]
	folderQuotaSizes := r.Form["vfolder_quota_size"]
	folderQuotaFiles := r.Form["vfolder_quota_files"]
	for idx, p := range folderPaths {
		name := ""
		if len(folderNames) > idx {
			name = folderNames[idx]
		}
		if p != "" && name != "" {
			vfolder := vfs.VirtualFolder{
				BaseVirtualFolder: vfs.BaseVirtualFolder{
					Name: name,
				},
				VirtualPath: p,
				QuotaFiles:  -1,
				QuotaSize:   -1,
			}
			if len(folderQuotaSizes) > idx {
				quotaSize, err := util.ParseBytes(folderQuotaSizes[idx])
				if err == nil {
					vfolder.QuotaSize = quotaSize
				}
			}
			if len(folderQuotaFiles) > idx {
				quotaFiles, err := strconv.Atoi(folderQuotaFiles[idx])
				if err == nil {
					vfolder.QuotaFiles = quotaFiles
				}
			}
			virtualFolders = append(virtualFolders, vfolder)
		}
	}

	return virtualFolders
}

// getSubDirPermissionsFromPostFields maps each non-empty "sub_perm_path" form
// value to the "sub_perm_permissionsN" slice for its index N.
func getSubDirPermissionsFromPostFields(r *http.Request) map[string][]string {
	permissions := make(map[string][]string)

	for idx, p := range r.Form["sub_perm_path"] {
		if p != "" {
			permissions[p] = r.Form["sub_perm_permissions"+strconv.Itoa(idx)]
		}
	}

	return permissions
}

// getUserPermissionsFromPostFields returns the per-directory permissions plus
// the root ("/") permissions taken from the "permissions" form field.
func getUserPermissionsFromPostFields(r *http.Request) map[string][]string {
	permissions := getSubDirPermissionsFromPostFields(r)
	permissions["/"] = r.Form["permissions"]

	return permissions
}

// getAccessTimeRestrictionsFromPostFields builds time-of-day access periods
// from the parallel "access_time_*" form slices; entries are kept only when
// the day parses as an integer and both start and end are non-empty.
func getAccessTimeRestrictionsFromPostFields(r *http.Request) []sdk.TimePeriod {
	var result []sdk.TimePeriod

	dayOfWeeks := r.Form["access_time_day_of_week"]
	starts := r.Form["access_time_start"]
	ends := r.Form["access_time_end"]

	for idx, dayOfWeek := range dayOfWeeks {
		dayOfWeek =
strings.TrimSpace(dayOfWeek) + start := "" + if len(starts) > idx { + start = strings.TrimSpace(starts[idx]) + } + end := "" + if len(ends) > idx { + end = strings.TrimSpace(ends[idx]) + } + dayNumber, err := strconv.Atoi(dayOfWeek) + if err == nil && start != "" && end != "" { + result = append(result, sdk.TimePeriod{ + DayOfWeek: dayNumber, + From: start, + To: end, + }) + } + } + + return result +} + +func getBandwidthLimitsFromPostFields(r *http.Request) ([]sdk.BandwidthLimit, error) { + var result []sdk.BandwidthLimit + bwSources := r.Form["bandwidth_limit_sources"] + uploadSources := r.Form["upload_bandwidth_source"] + downloadSources := r.Form["download_bandwidth_source"] + + for idx, bwSource := range bwSources { + sources := getSliceFromDelimitedValues(bwSource, ",") + if len(sources) > 0 { + bwLimit := sdk.BandwidthLimit{ + Sources: sources, + } + ul := "" + dl := "" + if len(uploadSources) > idx { + ul = uploadSources[idx] + } + if len(downloadSources) > idx { + dl = downloadSources[idx] + } + if ul != "" { + bandwidthUL, err := strconv.ParseInt(ul, 10, 64) + if err != nil { + return result, fmt.Errorf("invalid upload_bandwidth_source%v %q: %w", idx, ul, err) + } + bwLimit.UploadBandwidth = bandwidthUL + } + if dl != "" { + bandwidthDL, err := strconv.ParseInt(dl, 10, 64) + if err != nil { + return result, fmt.Errorf("invalid download_bandwidth_source%v %q: %w", idx, ul, err) + } + bwLimit.DownloadBandwidth = bandwidthDL + } + result = append(result, bwLimit) + } + } + + return result, nil +} + +func getPatterDenyPolicyFromString(policy string) int { + denyPolicy := sdk.DenyPolicyDefault + if policy == "1" { + denyPolicy = sdk.DenyPolicyHide + } + return denyPolicy +} + +func getFilePatternsFromPostField(r *http.Request) []sdk.PatternsFilter { + var result []sdk.PatternsFilter + patternPaths := r.Form["pattern_path"] + patterns := r.Form["patterns"] + patternTypes := r.Form["pattern_type"] + policies := r.Form["pattern_policy"] + + allowedPatterns := 
make(map[string][]string)
	deniedPatterns := make(map[string][]string)
	patternPolicies := make(map[string]string)

	// First pass: bucket the comma-separated patterns by path and type;
	// spaces are stripped so "a, b" and "a,b" are equivalent.
	for idx := range patternPaths {
		p := patternPaths[idx]
		filters := strings.ReplaceAll(patterns[idx], " ", "")
		patternType := patternTypes[idx]
		patternPolicy := policies[idx]
		if p != "" && filters != "" {
			if patternType == "allowed" {
				allowedPatterns[p] = append(allowedPatterns[p], strings.Split(filters, ",")...)
			} else {
				deniedPatterns[p] = append(deniedPatterns[p], strings.Split(filters, ",")...)
			}
			if patternPolicy != "" && patternPolicy != "0" {
				patternPolicies[p] = patternPolicy
			}
		}
	}

	// Second pass: one filter per path with allowed patterns, attaching the
	// denied patterns for the same path when present.
	for dirAllowed, allowPatterns := range allowedPatterns {
		filter := sdk.PatternsFilter{
			Path:            dirAllowed,
			AllowedPatterns: allowPatterns,
			DenyPolicy:      getPatterDenyPolicyFromString(patternPolicies[dirAllowed]),
		}
		for dirDenied, denPatterns := range deniedPatterns {
			if dirAllowed == dirDenied {
				filter.DeniedPatterns = denPatterns
				break
			}
		}
		result = append(result, filter)
	}
	// Finally add deny-only paths that were not merged above.
	for dirDenied, denPatterns := range deniedPatterns {
		found := false
		for _, res := range result {
			if res.Path == dirDenied {
				found = true
				break
			}
		}
		if !found {
			result = append(result, sdk.PatternsFilter{
				Path:           dirDenied,
				DeniedPatterns: denPatterns,
				DenyPolicy:     getPatterDenyPolicyFromString(patternPolicies[dirDenied]),
			})
		}
	}
	return result
}

// getGroupsFromUserPostFields builds the user's group memberships from the
// "primary_group", "secondary_groups" and "membership_groups" form fields.
func getGroupsFromUserPostFields(r *http.Request) []sdk.GroupMapping {
	var groups []sdk.GroupMapping

	primaryGroup := strings.TrimSpace(r.Form.Get("primary_group"))
	if primaryGroup != "" {
		groups = append(groups, sdk.GroupMapping{
			Name: primaryGroup,
			Type: sdk.GroupTypePrimary,
		})
	}
	secondaryGroups := r.Form["secondary_groups"]
	for _, name := range secondaryGroups {
		groups = append(groups, sdk.GroupMapping{
			Name: strings.TrimSpace(name),
			Type: sdk.GroupTypeSecondary,
		})
	}
	membershipGroups := r.Form["membership_groups"]
	for _, name := range membershipGroups {
		groups = append(groups, sdk.GroupMapping{
			Name: strings.TrimSpace(name),
			Type: sdk.GroupTypeMembership,
		})
	}
	return groups
}

// getFiltersFromUserPostFields parses every user filter from the submitted
// form: bandwidth limits, IP allow/deny lists, denied login methods and
// protocols, file patterns, share/password expirations, hooks and more.
// Numeric fields return a descriptive error when they fail to parse.
func getFiltersFromUserPostFields(r *http.Request) (sdk.BaseUserFilters, error) {
	var filters sdk.BaseUserFilters
	bwLimits, err := getBandwidthLimitsFromPostFields(r)
	if err != nil {
		return filters, err
	}
	maxFileSize, err := util.ParseBytes(r.Form.Get("max_upload_file_size"))
	if err != nil {
		return filters, util.NewI18nError(fmt.Errorf("invalid max upload file size: %w", err), util.I18nErrorInvalidMaxFilesize)
	}
	defaultSharesExpiration, err := strconv.Atoi(r.Form.Get("default_shares_expiration"))
	if err != nil {
		return filters, fmt.Errorf("invalid default shares expiration: %w", err)
	}
	maxSharesExpiration, err := strconv.Atoi(r.Form.Get("max_shares_expiration"))
	if err != nil {
		return filters, fmt.Errorf("invalid max shares expiration: %w", err)
	}
	passwordExpiration, err := strconv.Atoi(r.Form.Get("password_expiration"))
	if err != nil {
		return filters, fmt.Errorf("invalid password expiration: %w", err)
	}
	passwordStrength, err := strconv.Atoi(r.Form.Get("password_strength"))
	if err != nil {
		return filters, fmt.Errorf("invalid password strength: %w", err)
	}
	if r.Form.Get("ftp_security") == "1" {
		filters.FTPSecurity = 1
	}
	filters.BandwidthLimits = bwLimits
	filters.AllowedIP = getSliceFromDelimitedValues(r.Form.Get("allowed_ip"), ",")
	filters.DeniedIP = getSliceFromDelimitedValues(r.Form.Get("denied_ip"), ",")
	filters.DeniedLoginMethods = r.Form["denied_login_methods"]
	filters.DeniedProtocols = r.Form["denied_protocols"]
	filters.TwoFactorAuthProtocols = r.Form["required_two_factor_protocols"]
	filters.FilePatterns = getFilePatternsFromPostField(r)
	filters.TLSUsername = sdk.TLSUsername(strings.TrimSpace(r.Form.Get("tls_username")))
	filters.WebClient = r.Form["web_client_options"]
	filters.DefaultSharesExpiration =
defaultSharesExpiration + filters.MaxSharesExpiration = maxSharesExpiration + filters.PasswordExpiration = passwordExpiration + filters.PasswordStrength = passwordStrength + filters.AccessTime = getAccessTimeRestrictionsFromPostFields(r) + hooks := r.Form["hooks"] + if slices.Contains(hooks, "external_auth_disabled") { + filters.Hooks.ExternalAuthDisabled = true + } + if slices.Contains(hooks, "pre_login_disabled") { + filters.Hooks.PreLoginDisabled = true + } + if slices.Contains(hooks, "check_password_disabled") { + filters.Hooks.CheckPasswordDisabled = true + } + filters.IsAnonymous = r.Form.Get("is_anonymous") != "" + filters.DisableFsChecks = r.Form.Get("disable_fs_checks") != "" + filters.AllowAPIKeyAuth = r.Form.Get("allow_api_key_auth") != "" + filters.StartDirectory = strings.TrimSpace(r.Form.Get("start_directory")) + filters.MaxUploadFileSize = maxFileSize + filters.ExternalAuthCacheTime, err = strconv.ParseInt(r.Form.Get("external_auth_cache_time"), 10, 64) + if err != nil { + return filters, fmt.Errorf("invalid external auth cache time: %w", err) + } + return filters, nil +} + +func getSecretFromFormField(r *http.Request, field string) *kms.Secret { + secret := kms.NewPlainSecret(r.Form.Get(field)) + if strings.TrimSpace(secret.GetPayload()) == redactedSecret { + secret.SetStatus(sdkkms.SecretStatusRedacted) + } + if strings.TrimSpace(secret.GetPayload()) == "" { + secret.SetStatus("") + } + return secret +} + +func getS3Config(r *http.Request) (vfs.S3FsConfig, error) { + var err error + config := vfs.S3FsConfig{} + config.Bucket = strings.TrimSpace(r.Form.Get("s3_bucket")) + config.Region = strings.TrimSpace(r.Form.Get("s3_region")) + config.AccessKey = strings.TrimSpace(r.Form.Get("s3_access_key")) + config.RoleARN = strings.TrimSpace(r.Form.Get("s3_role_arn")) + config.AccessSecret = getSecretFromFormField(r, "s3_access_secret") + config.SSECustomerKey = getSecretFromFormField(r, "s3_sse_customer_key") + config.Endpoint = 
strings.TrimSpace(r.Form.Get("s3_endpoint")) + config.StorageClass = strings.TrimSpace(r.Form.Get("s3_storage_class")) + config.ACL = strings.TrimSpace(r.Form.Get("s3_acl")) + config.KeyPrefix = strings.TrimSpace(strings.TrimPrefix(r.Form.Get("s3_key_prefix"), "/")) + config.UploadPartSize, err = strconv.ParseInt(r.Form.Get("s3_upload_part_size"), 10, 64) + if err != nil { + return config, fmt.Errorf("invalid s3 upload part size: %w", err) + } + config.UploadConcurrency, err = strconv.Atoi(r.Form.Get("s3_upload_concurrency")) + if err != nil { + return config, fmt.Errorf("invalid s3 upload concurrency: %w", err) + } + config.DownloadPartSize, err = strconv.ParseInt(r.Form.Get("s3_download_part_size"), 10, 64) + if err != nil { + return config, fmt.Errorf("invalid s3 download part size: %w", err) + } + config.DownloadConcurrency, err = strconv.Atoi(r.Form.Get("s3_download_concurrency")) + if err != nil { + return config, fmt.Errorf("invalid s3 download concurrency: %w", err) + } + config.ForcePathStyle = r.Form.Get("s3_force_path_style") != "" + config.SkipTLSVerify = r.Form.Get("s3_skip_tls_verify") != "" + config.DownloadPartMaxTime, err = strconv.Atoi(r.Form.Get("s3_download_part_max_time")) + if err != nil { + return config, fmt.Errorf("invalid s3 download part max time: %w", err) + } + config.UploadPartMaxTime, err = strconv.Atoi(r.Form.Get("s3_upload_part_max_time")) + if err != nil { + return config, fmt.Errorf("invalid s3 upload part max time: %w", err) + } + return config, nil +} + +func getGCSConfig(r *http.Request) (vfs.GCSFsConfig, error) { + var err error + config := vfs.GCSFsConfig{} + + config.Bucket = strings.TrimSpace(r.Form.Get("gcs_bucket")) + config.StorageClass = strings.TrimSpace(r.Form.Get("gcs_storage_class")) + config.ACL = strings.TrimSpace(r.Form.Get("gcs_acl")) + config.KeyPrefix = strings.TrimSpace(strings.TrimPrefix(r.Form.Get("gcs_key_prefix"), "/")) + uploadPartSize, err := strconv.ParseInt(r.Form.Get("gcs_upload_part_size"), 10, 64) 
+ if err == nil { + config.UploadPartSize = uploadPartSize + } + uploadPartMaxTime, err := strconv.Atoi(r.Form.Get("gcs_upload_part_max_time")) + if err == nil { + config.UploadPartMaxTime = uploadPartMaxTime + } + autoCredentials := r.Form.Get("gcs_auto_credentials") + if autoCredentials != "" { + config.AutomaticCredentials = 1 + } else { + config.AutomaticCredentials = 0 + } + credentials, _, err := r.FormFile("gcs_credential_file") + if errors.Is(err, http.ErrMissingFile) { + return config, nil + } + if err != nil { + return config, err + } + defer credentials.Close() + fileBytes, err := io.ReadAll(credentials) + if err != nil || len(fileBytes) == 0 { + if len(fileBytes) == 0 { + err = errors.New("credentials file size must be greater than 0") + } + return config, err + } + config.Credentials = kms.NewPlainSecret(util.BytesToString(fileBytes)) + config.AutomaticCredentials = 0 + return config, err +} + +func getSFTPConfig(r *http.Request) (vfs.SFTPFsConfig, error) { + var err error + config := vfs.SFTPFsConfig{} + config.Endpoint = strings.TrimSpace(r.Form.Get("sftp_endpoint")) + config.Username = strings.TrimSpace(r.Form.Get("sftp_username")) + config.Password = getSecretFromFormField(r, "sftp_password") + config.PrivateKey = getSecretFromFormField(r, "sftp_private_key") + config.KeyPassphrase = getSecretFromFormField(r, "sftp_key_passphrase") + fingerprintsFormValue := r.Form.Get("sftp_fingerprints") + config.Fingerprints = getSliceFromDelimitedValues(fingerprintsFormValue, "\n") + config.Prefix = strings.TrimSpace(r.Form.Get("sftp_prefix")) + config.DisableCouncurrentReads = r.Form.Get("sftp_disable_concurrent_reads") != "" + config.BufferSize, err = strconv.ParseInt(r.Form.Get("sftp_buffer_size"), 10, 64) + if r.Form.Get("sftp_equality_check_mode") != "" { + config.EqualityCheckMode = 1 + } else { + config.EqualityCheckMode = 0 + } + if err != nil { + return config, fmt.Errorf("invalid SFTP buffer size: %w", err) + } + return config, nil +} + +func 
getHTTPFsConfig(r *http.Request) vfs.HTTPFsConfig { + config := vfs.HTTPFsConfig{} + config.Endpoint = strings.TrimSpace(r.Form.Get("http_endpoint")) + config.Username = strings.TrimSpace(r.Form.Get("http_username")) + config.SkipTLSVerify = r.Form.Get("http_skip_tls_verify") != "" + config.Password = getSecretFromFormField(r, "http_password") + config.APIKey = getSecretFromFormField(r, "http_api_key") + if r.Form.Get("http_equality_check_mode") != "" { + config.EqualityCheckMode = 1 + } else { + config.EqualityCheckMode = 0 + } + return config +} + +func getAzureConfig(r *http.Request) (vfs.AzBlobFsConfig, error) { + var err error + config := vfs.AzBlobFsConfig{} + config.Container = strings.TrimSpace(r.Form.Get("az_container")) + config.AccountName = strings.TrimSpace(r.Form.Get("az_account_name")) + config.AccountKey = getSecretFromFormField(r, "az_account_key") + config.SASURL = getSecretFromFormField(r, "az_sas_url") + config.Endpoint = strings.TrimSpace(r.Form.Get("az_endpoint")) + config.KeyPrefix = strings.TrimSpace(strings.TrimPrefix(r.Form.Get("az_key_prefix"), "/")) + config.AccessTier = strings.TrimSpace(r.Form.Get("az_access_tier")) + config.UseEmulator = r.Form.Get("az_use_emulator") != "" + config.UploadPartSize, err = strconv.ParseInt(r.Form.Get("az_upload_part_size"), 10, 64) + if err != nil { + return config, fmt.Errorf("invalid azure upload part size: %w", err) + } + config.UploadConcurrency, err = strconv.Atoi(r.Form.Get("az_upload_concurrency")) + if err != nil { + return config, fmt.Errorf("invalid azure upload concurrency: %w", err) + } + config.DownloadPartSize, err = strconv.ParseInt(r.Form.Get("az_download_part_size"), 10, 64) + if err != nil { + return config, fmt.Errorf("invalid azure download part size: %w", err) + } + config.DownloadConcurrency, err = strconv.Atoi(r.Form.Get("az_download_concurrency")) + if err != nil { + return config, fmt.Errorf("invalid azure download concurrency: %w", err) + } + return config, nil +} + +func 
getOsConfigFromPostFields(r *http.Request, readBufferField, writeBufferField string) sdk.OSFsConfig { + config := sdk.OSFsConfig{} + readBuffer, err := strconv.Atoi(r.Form.Get(readBufferField)) + if err == nil { + config.ReadBufferSize = readBuffer + } + writeBuffer, err := strconv.Atoi(r.Form.Get(writeBufferField)) + if err == nil { + config.WriteBufferSize = writeBuffer + } + return config +} + +func getFsConfigFromPostFields(r *http.Request) (vfs.Filesystem, error) { + var fs vfs.Filesystem + fs.Provider = dataprovider.GetProviderFromValue(r.Form.Get("fs_provider")) + switch fs.Provider { + case sdk.LocalFilesystemProvider: + fs.OSConfig = getOsConfigFromPostFields(r, "osfs_read_buffer_size", "osfs_write_buffer_size") + case sdk.S3FilesystemProvider: + config, err := getS3Config(r) + if err != nil { + return fs, err + } + fs.S3Config = config + case sdk.AzureBlobFilesystemProvider: + config, err := getAzureConfig(r) + if err != nil { + return fs, err + } + fs.AzBlobConfig = config + case sdk.GCSFilesystemProvider: + config, err := getGCSConfig(r) + if err != nil { + return fs, err + } + fs.GCSConfig = config + case sdk.CryptedFilesystemProvider: + fs.CryptConfig.Passphrase = getSecretFromFormField(r, "crypt_passphrase") + fs.CryptConfig.OSFsConfig = getOsConfigFromPostFields(r, "cryptfs_read_buffer_size", "cryptfs_write_buffer_size") + case sdk.SFTPFilesystemProvider: + config, err := getSFTPConfig(r) + if err != nil { + return fs, err + } + fs.SFTPConfig = config + case sdk.HTTPFilesystemProvider: + fs.HTTPConfig = getHTTPFsConfig(r) + } + return fs, nil +} + +func getAdminHiddenUserPageSections(r *http.Request) int { + var result int + + for _, val := range r.Form["user_page_hidden_sections"] { + switch val { + case "1": + result++ + case "2": + result += 2 + case "3": + result += 4 + case "4": + result += 8 + case "5": + result += 16 + case "6": + result += 32 + case "7": + result += 64 + } + } + + return result +} + +func getAdminFromPostFields(r 
*http.Request) (dataprovider.Admin, error) { + var admin dataprovider.Admin + err := r.ParseForm() + if err != nil { + return admin, util.NewI18nError(err, util.I18nErrorInvalidForm) + } + status, err := strconv.Atoi(r.Form.Get("status")) + if err != nil { + return admin, fmt.Errorf("invalid status: %w", err) + } + admin.Username = strings.TrimSpace(r.Form.Get("username")) + admin.Password = strings.TrimSpace(r.Form.Get("password")) + admin.Permissions = r.Form["permissions"] + admin.Email = strings.TrimSpace(r.Form.Get("email")) + admin.Status = status + admin.Role = strings.TrimSpace(r.Form.Get("role")) + admin.Filters.AllowList = getSliceFromDelimitedValues(r.Form.Get("allowed_ip"), ",") + admin.Filters.AllowAPIKeyAuth = r.Form.Get("allow_api_key_auth") != "" + admin.Filters.RequireTwoFactor = r.Form.Get("require_two_factor") != "" + admin.Filters.RequirePasswordChange = r.Form.Get("require_password_change") != "" + admin.AdditionalInfo = r.Form.Get("additional_info") + admin.Description = r.Form.Get("description") + admin.Filters.Preferences.HideUserPageSections = getAdminHiddenUserPageSections(r) + admin.Filters.Preferences.DefaultUsersExpiration = 0 + if val := r.Form.Get("default_users_expiration"); val != "" { + defaultUsersExpiration, err := strconv.Atoi(r.Form.Get("default_users_expiration")) + if err != nil { + return admin, fmt.Errorf("invalid default users expiration: %w", err) + } + admin.Filters.Preferences.DefaultUsersExpiration = defaultUsersExpiration + } + for k := range r.Form { + if hasPrefixAndSuffix(k, "groups[", "][group]") { + groupName := strings.TrimSpace(r.Form.Get(k)) + if groupName != "" { + group := dataprovider.AdminGroupMapping{ + Name: groupName, + } + base, _ := strings.CutSuffix(k, "[group]") + addAsGroupType := strings.TrimSpace(r.Form.Get(base + "[group_type]")) + switch addAsGroupType { + case "1": + group.Options.AddToUsersAs = dataprovider.GroupAddToUsersAsPrimary + case "2": + group.Options.AddToUsersAs = 
dataprovider.GroupAddToUsersAsSecondary + default: + group.Options.AddToUsersAs = dataprovider.GroupAddToUsersAsMembership + } + admin.Groups = append(admin.Groups, group) + } + } + } + return admin, nil +} + +func replacePlaceholders(field string, replacements map[string]string) string { + for k, v := range replacements { + field = strings.ReplaceAll(field, k, v) + } + return field +} + +func getFolderFromTemplate(folder vfs.BaseVirtualFolder, name string) vfs.BaseVirtualFolder { + folder.Name = name + replacements := make(map[string]string) + replacements["%name%"] = folder.Name + + folder.MappedPath = replacePlaceholders(folder.MappedPath, replacements) + folder.Description = replacePlaceholders(folder.Description, replacements) + switch folder.FsConfig.Provider { + case sdk.CryptedFilesystemProvider: + folder.FsConfig.CryptConfig = getCryptFsFromTemplate(folder.FsConfig.CryptConfig, replacements) + case sdk.S3FilesystemProvider: + folder.FsConfig.S3Config = getS3FsFromTemplate(folder.FsConfig.S3Config, replacements) + case sdk.GCSFilesystemProvider: + folder.FsConfig.GCSConfig = getGCSFsFromTemplate(folder.FsConfig.GCSConfig, replacements) + case sdk.AzureBlobFilesystemProvider: + folder.FsConfig.AzBlobConfig = getAzBlobFsFromTemplate(folder.FsConfig.AzBlobConfig, replacements) + case sdk.SFTPFilesystemProvider: + folder.FsConfig.SFTPConfig = getSFTPFsFromTemplate(folder.FsConfig.SFTPConfig, replacements) + case sdk.HTTPFilesystemProvider: + folder.FsConfig.HTTPConfig = getHTTPFsFromTemplate(folder.FsConfig.HTTPConfig, replacements) + } + + return folder +} + +func getCryptFsFromTemplate(fsConfig vfs.CryptFsConfig, replacements map[string]string) vfs.CryptFsConfig { + if fsConfig.Passphrase != nil { + if fsConfig.Passphrase.IsPlain() { + payload := replacePlaceholders(fsConfig.Passphrase.GetPayload(), replacements) + fsConfig.Passphrase = kms.NewPlainSecret(payload) + } + } + return fsConfig +} + +func getS3FsFromTemplate(fsConfig vfs.S3FsConfig, replacements 
map[string]string) vfs.S3FsConfig { + fsConfig.KeyPrefix = replacePlaceholders(fsConfig.KeyPrefix, replacements) + fsConfig.AccessKey = replacePlaceholders(fsConfig.AccessKey, replacements) + if fsConfig.AccessSecret != nil && fsConfig.AccessSecret.IsPlain() { + payload := replacePlaceholders(fsConfig.AccessSecret.GetPayload(), replacements) + fsConfig.AccessSecret = kms.NewPlainSecret(payload) + } + if fsConfig.SSECustomerKey != nil && fsConfig.SSECustomerKey.IsPlain() { + payload := replacePlaceholders(fsConfig.SSECustomerKey.GetPayload(), replacements) + fsConfig.SSECustomerKey = kms.NewPlainSecret(payload) + } + return fsConfig +} + +func getGCSFsFromTemplate(fsConfig vfs.GCSFsConfig, replacements map[string]string) vfs.GCSFsConfig { + fsConfig.KeyPrefix = replacePlaceholders(fsConfig.KeyPrefix, replacements) + return fsConfig +} + +func getAzBlobFsFromTemplate(fsConfig vfs.AzBlobFsConfig, replacements map[string]string) vfs.AzBlobFsConfig { + fsConfig.KeyPrefix = replacePlaceholders(fsConfig.KeyPrefix, replacements) + fsConfig.AccountName = replacePlaceholders(fsConfig.AccountName, replacements) + if fsConfig.AccountKey != nil && fsConfig.AccountKey.IsPlain() { + payload := replacePlaceholders(fsConfig.AccountKey.GetPayload(), replacements) + fsConfig.AccountKey = kms.NewPlainSecret(payload) + } + return fsConfig +} + +func getSFTPFsFromTemplate(fsConfig vfs.SFTPFsConfig, replacements map[string]string) vfs.SFTPFsConfig { + fsConfig.Prefix = replacePlaceholders(fsConfig.Prefix, replacements) + fsConfig.Username = replacePlaceholders(fsConfig.Username, replacements) + if fsConfig.Password != nil && fsConfig.Password.IsPlain() { + payload := replacePlaceholders(fsConfig.Password.GetPayload(), replacements) + fsConfig.Password = kms.NewPlainSecret(payload) + } + return fsConfig +} + +func getHTTPFsFromTemplate(fsConfig vfs.HTTPFsConfig, replacements map[string]string) vfs.HTTPFsConfig { + fsConfig.Username = replacePlaceholders(fsConfig.Username, replacements) + 
return fsConfig +} + +func getUserFromTemplate(user dataprovider.User, template userTemplateFields) dataprovider.User { + user.Username = template.Username + user.Password = template.Password + user.PublicKeys = template.PublicKeys + user.Filters.RequirePasswordChange = template.RequirePwdChange + replacements := make(map[string]string) + replacements["%username%"] = user.Username + if user.Password != "" && !user.IsPasswordHashed() { + user.Password = replacePlaceholders(user.Password, replacements) + replacements["%password%"] = user.Password + } + + user.HomeDir = replacePlaceholders(user.HomeDir, replacements) + var vfolders []vfs.VirtualFolder + for _, vfolder := range user.VirtualFolders { + vfolder.Name = replacePlaceholders(vfolder.Name, replacements) + vfolder.VirtualPath = replacePlaceholders(vfolder.VirtualPath, replacements) + vfolders = append(vfolders, vfolder) + } + user.VirtualFolders = vfolders + user.Description = replacePlaceholders(user.Description, replacements) + user.AdditionalInfo = replacePlaceholders(user.AdditionalInfo, replacements) + user.Filters.StartDirectory = replacePlaceholders(user.Filters.StartDirectory, replacements) + + switch user.FsConfig.Provider { + case sdk.CryptedFilesystemProvider: + user.FsConfig.CryptConfig = getCryptFsFromTemplate(user.FsConfig.CryptConfig, replacements) + case sdk.S3FilesystemProvider: + user.FsConfig.S3Config = getS3FsFromTemplate(user.FsConfig.S3Config, replacements) + case sdk.GCSFilesystemProvider: + user.FsConfig.GCSConfig = getGCSFsFromTemplate(user.FsConfig.GCSConfig, replacements) + case sdk.AzureBlobFilesystemProvider: + user.FsConfig.AzBlobConfig = getAzBlobFsFromTemplate(user.FsConfig.AzBlobConfig, replacements) + case sdk.SFTPFilesystemProvider: + user.FsConfig.SFTPConfig = getSFTPFsFromTemplate(user.FsConfig.SFTPConfig, replacements) + case sdk.HTTPFilesystemProvider: + user.FsConfig.HTTPConfig = getHTTPFsFromTemplate(user.FsConfig.HTTPConfig, replacements) + } + + return user +} + +func 
getTransferLimits(r *http.Request) (int64, int64, int64, error) { + dataTransferUL, err := strconv.ParseInt(r.Form.Get("upload_data_transfer"), 10, 64) + if err != nil { + return 0, 0, 0, fmt.Errorf("invalid upload data transfer: %w", err) + } + dataTransferDL, err := strconv.ParseInt(r.Form.Get("download_data_transfer"), 10, 64) + if err != nil { + return 0, 0, 0, fmt.Errorf("invalid download data transfer: %w", err) + } + dataTransferTotal, err := strconv.ParseInt(r.Form.Get("total_data_transfer"), 10, 64) + if err != nil { + return 0, 0, 0, fmt.Errorf("invalid total data transfer: %w", err) + } + return dataTransferUL, dataTransferDL, dataTransferTotal, nil +} + +func getQuotaLimits(r *http.Request) (int64, int, error) { + quotaSize, err := util.ParseBytes(r.Form.Get("quota_size")) + if err != nil { + return 0, 0, util.NewI18nError(fmt.Errorf("invalid quota size: %w", err), util.I18nErrorInvalidQuotaSize) + } + quotaFiles, err := strconv.Atoi(r.Form.Get("quota_files")) + if err != nil { + return 0, 0, fmt.Errorf("invalid quota files: %w", err) + } + return quotaSize, quotaFiles, nil +} + +func updateRepeaterFormFields(r *http.Request) { + for k := range r.Form { + if hasPrefixAndSuffix(k, "public_keys[", "][public_key]") { + key := r.Form.Get(k) + if strings.TrimSpace(key) != "" { + r.Form.Add("public_keys", key) + } + continue + } + if hasPrefixAndSuffix(k, "tls_certs[", "][tls_cert]") { + cert := strings.TrimSpace(r.Form.Get(k)) + if cert != "" { + r.Form.Add("tls_certs", cert) + } + continue + } + if hasPrefixAndSuffix(k, "additional_emails[", "][additional_email]") { + email := strings.TrimSpace(r.Form.Get(k)) + if email != "" { + r.Form.Add("additional_emails", email) + } + continue + } + if hasPrefixAndSuffix(k, "virtual_folders[", "][vfolder_path]") { + base, _ := strings.CutSuffix(k, "[vfolder_path]") + r.Form.Add("vfolder_path", strings.TrimSpace(r.Form.Get(k))) + r.Form.Add("vfolder_name", strings.TrimSpace(r.Form.Get(base+"[vfolder_name]"))) + 
r.Form.Add("vfolder_quota_files", strings.TrimSpace(r.Form.Get(base+"[vfolder_quota_files]"))) + r.Form.Add("vfolder_quota_size", strings.TrimSpace(r.Form.Get(base+"[vfolder_quota_size]"))) + continue + } + if hasPrefixAndSuffix(k, "directory_permissions[", "][sub_perm_path]") { + base, _ := strings.CutSuffix(k, "[sub_perm_path]") + r.Form.Add("sub_perm_path", strings.TrimSpace(r.Form.Get(k))) + r.Form["sub_perm_permissions"+strconv.Itoa(len(r.Form["sub_perm_path"])-1)] = r.Form[base+"[sub_perm_permissions][]"] + continue + } + if hasPrefixAndSuffix(k, "directory_patterns[", "][pattern_path]") { + base, _ := strings.CutSuffix(k, "[pattern_path]") + r.Form.Add("pattern_path", strings.TrimSpace(r.Form.Get(k))) + r.Form.Add("patterns", strings.TrimSpace(r.Form.Get(base+"[patterns]"))) + r.Form.Add("pattern_type", strings.TrimSpace(r.Form.Get(base+"[pattern_type]"))) + r.Form.Add("pattern_policy", strings.TrimSpace(r.Form.Get(base+"[pattern_policy]"))) + continue + } + if hasPrefixAndSuffix(k, "access_time_restrictions[", "][access_time_day_of_week]") { + base, _ := strings.CutSuffix(k, "[access_time_day_of_week]") + r.Form.Add("access_time_day_of_week", strings.TrimSpace(r.Form.Get(k))) + r.Form.Add("access_time_start", strings.TrimSpace(r.Form.Get(base+"[access_time_start]"))) + r.Form.Add("access_time_end", strings.TrimSpace(r.Form.Get(base+"[access_time_end]"))) + continue + } + if hasPrefixAndSuffix(k, "src_bandwidth_limits[", "][bandwidth_limit_sources]") { + base, _ := strings.CutSuffix(k, "[bandwidth_limit_sources]") + r.Form.Add("bandwidth_limit_sources", r.Form.Get(k)) + r.Form.Add("upload_bandwidth_source", strings.TrimSpace(r.Form.Get(base+"[upload_bandwidth_source]"))) + r.Form.Add("download_bandwidth_source", strings.TrimSpace(r.Form.Get(base+"[download_bandwidth_source]"))) + continue + } + if hasPrefixAndSuffix(k, "template_users[", "][tpl_username]") { + base, _ := strings.CutSuffix(k, "[tpl_username]") + r.Form.Add("tpl_username", 
strings.TrimSpace(r.Form.Get(k))) + r.Form.Add("tpl_password", strings.TrimSpace(r.Form.Get(base+"[tpl_password]"))) + r.Form.Add("tpl_public_keys", strings.TrimSpace(r.Form.Get(base+"[tpl_public_keys]"))) + continue + } + } +} + +func getUserFromPostFields(r *http.Request) (dataprovider.User, error) { + user := dataprovider.User{} + err := r.ParseMultipartForm(maxRequestSize) + if err != nil { + return user, util.NewI18nError(err, util.I18nErrorInvalidForm) + } + defer r.MultipartForm.RemoveAll() //nolint:errcheck + + updateRepeaterFormFields(r) + + uid, err := strconv.Atoi(r.Form.Get("uid")) + if err != nil { + return user, fmt.Errorf("invalid uid: %w", err) + } + gid, err := strconv.Atoi(r.Form.Get("gid")) + if err != nil { + return user, fmt.Errorf("invalid uid: %w", err) + } + maxSessions, err := strconv.Atoi(r.Form.Get("max_sessions")) + if err != nil { + return user, fmt.Errorf("invalid max sessions: %w", err) + } + quotaSize, quotaFiles, err := getQuotaLimits(r) + if err != nil { + return user, err + } + bandwidthUL, err := strconv.ParseInt(r.Form.Get("upload_bandwidth"), 10, 64) + if err != nil { + return user, fmt.Errorf("invalid upload bandwidth: %w", err) + } + bandwidthDL, err := strconv.ParseInt(r.Form.Get("download_bandwidth"), 10, 64) + if err != nil { + return user, fmt.Errorf("invalid download bandwidth: %w", err) + } + dataTransferUL, dataTransferDL, dataTransferTotal, err := getTransferLimits(r) + if err != nil { + return user, err + } + status, err := strconv.Atoi(r.Form.Get("status")) + if err != nil { + return user, fmt.Errorf("invalid status: %w", err) + } + expirationDateMillis := int64(0) + expirationDateString := r.Form.Get("expiration_date") + if strings.TrimSpace(expirationDateString) != "" { + expirationDate, err := time.Parse(webDateTimeFormat, expirationDateString) + if err != nil { + return user, err + } + expirationDateMillis = util.GetTimeAsMsSinceEpoch(expirationDate) + } + fsConfig, err := getFsConfigFromPostFields(r) + if err 
!= nil { + return user, err + } + filters, err := getFiltersFromUserPostFields(r) + if err != nil { + return user, err + } + filters.TLSCerts = r.Form["tls_certs"] + user = dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: strings.TrimSpace(r.Form.Get("username")), + Email: strings.TrimSpace(r.Form.Get("email")), + Password: strings.TrimSpace(r.Form.Get("password")), + PublicKeys: r.Form["public_keys"], + HomeDir: strings.TrimSpace(r.Form.Get("home_dir")), + UID: uid, + GID: gid, + Permissions: getUserPermissionsFromPostFields(r), + MaxSessions: maxSessions, + QuotaSize: quotaSize, + QuotaFiles: quotaFiles, + UploadBandwidth: bandwidthUL, + DownloadBandwidth: bandwidthDL, + UploadDataTransfer: dataTransferUL, + DownloadDataTransfer: dataTransferDL, + TotalDataTransfer: dataTransferTotal, + Status: status, + ExpirationDate: expirationDateMillis, + AdditionalInfo: r.Form.Get("additional_info"), + Description: r.Form.Get("description"), + Role: strings.TrimSpace(r.Form.Get("role")), + }, + Filters: dataprovider.UserFilters{ + BaseUserFilters: filters, + RequirePasswordChange: r.Form.Get("require_password_change") != "", + AdditionalEmails: r.Form["additional_emails"], + }, + VirtualFolders: getVirtualFoldersFromPostFields(r), + FsConfig: fsConfig, + Groups: getGroupsFromUserPostFields(r), + } + return user, nil +} + +func getGroupFromPostFields(r *http.Request) (dataprovider.Group, error) { + group := dataprovider.Group{} + err := r.ParseMultipartForm(maxRequestSize) + if err != nil { + return group, util.NewI18nError(err, util.I18nErrorInvalidForm) + } + defer r.MultipartForm.RemoveAll() //nolint:errcheck + + updateRepeaterFormFields(r) + + maxSessions, err := strconv.Atoi(r.Form.Get("max_sessions")) + if err != nil { + return group, fmt.Errorf("invalid max sessions: %w", err) + } + quotaSize, quotaFiles, err := getQuotaLimits(r) + if err != nil { + return group, err + } + bandwidthUL, err := strconv.ParseInt(r.Form.Get("upload_bandwidth"), 10, 64) + if err != 
nil {
		return group, fmt.Errorf("invalid upload bandwidth: %w", err)
	}
	bandwidthDL, err := strconv.ParseInt(r.Form.Get("download_bandwidth"), 10, 64)
	if err != nil {
		return group, fmt.Errorf("invalid download bandwidth: %w", err)
	}
	dataTransferUL, dataTransferDL, dataTransferTotal, err := getTransferLimits(r)
	if err != nil {
		return group, err
	}
	expiresIn, err := strconv.Atoi(r.Form.Get("expires_in"))
	if err != nil {
		return group, fmt.Errorf("invalid expires in: %w", err)
	}
	fsConfig, err := getFsConfigFromPostFields(r)
	if err != nil {
		return group, err
	}
	filters, err := getFiltersFromUserPostFields(r)
	if err != nil {
		return group, err
	}
	// All scalar fields parsed above; assemble the group in one literal so a
	// partial parse failure never returns a half-populated group.
	group = dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name:        strings.TrimSpace(r.Form.Get("name")),
			Description: r.Form.Get("description"),
		},
		UserSettings: dataprovider.GroupUserSettings{
			BaseGroupUserSettings: sdk.BaseGroupUserSettings{
				HomeDir:              strings.TrimSpace(r.Form.Get("home_dir")),
				MaxSessions:          maxSessions,
				QuotaSize:            quotaSize,
				QuotaFiles:           quotaFiles,
				Permissions:          getSubDirPermissionsFromPostFields(r),
				UploadBandwidth:      bandwidthUL,
				DownloadBandwidth:    bandwidthDL,
				UploadDataTransfer:   dataTransferUL,
				DownloadDataTransfer: dataTransferDL,
				TotalDataTransfer:    dataTransferTotal,
				ExpiresIn:            expiresIn,
				Filters:              filters,
			},
			FsConfig: fsConfig,
		},
		VirtualFolders: getVirtualFoldersFromPostFields(r),
	}
	return group, nil
}

// getKeyValsFromPostFields pairs the repeated form values submitted under the
// "key" and "val" field names, skipping rows where either side is empty.
func getKeyValsFromPostFields(r *http.Request, key, val string) []dataprovider.KeyValue {
	var res []dataprovider.KeyValue

	keys := r.Form[key]
	values := r.Form[val]

	// NOTE(review): values[idx] assumes len(values) >= len(keys). The web UI
	// emits the two fields pairwise so the lengths match, but a hand-crafted
	// request with fewer values would panic here — confirm upstream handling.
	for idx, k := range keys {
		v := values[idx]
		if k != "" && v != "" {
			res = append(res, dataprovider.KeyValue{
				Key:   k,
				Value: v,
			})
		}
	}

	return res
}

// getRenameConfigsFromPostFields builds the filesystem rename source/target
// pairs; the per-row checkbox options arrive as "fs_rename_options<idx>"
// (flattened earlier by updateRepeaterFormActionFields).
func getRenameConfigsFromPostFields(r *http.Request) []dataprovider.RenameConfig {
	var res []dataprovider.RenameConfig
	keys := r.Form["fs_rename_source"]
	values := r.Form["fs_rename_target"]

	for idx, k := range keys {
		v := values[idx]
		if k != "" && v != "" {
			opts := r.Form["fs_rename_options"+strconv.Itoa(idx)]
			res = append(res, dataprovider.RenameConfig{
				KeyValue: dataprovider.KeyValue{
					Key:   k,
					Value: v,
				},
				// option "1" == preserve/update modification time
				UpdateModTime: slices.Contains(opts, "1"),
			})
		}
	}

	return res
}

// getFoldersRetentionFromPostFields parses the per-folder data-retention rows.
// The retention value must parse as an integer; rows with an empty path are
// silently skipped.
func getFoldersRetentionFromPostFields(r *http.Request) ([]dataprovider.FolderRetention, error) {
	var res []dataprovider.FolderRetention
	paths := r.Form["folder_retention_path"]
	values := r.Form["folder_retention_val"]

	for idx, p := range paths {
		if p != "" {
			retention, err := strconv.Atoi(values[idx])
			if err != nil {
				return nil, fmt.Errorf("invalid retention for path %q: %w", p, err)
			}
			opts := r.Form["folder_retention_options"+strconv.Itoa(idx)]
			res = append(res, dataprovider.FolderRetention{
				Path:            p,
				Retention:       retention,
				// option "1" == also delete directories left empty
				DeleteEmptyDirs: slices.Contains(opts, "1"),
			})
		}
	}

	return res, nil
}

// getHTTPPartsFromPostFields builds the multipart parts for an HTTP event
// action. Rows whose order field is not numeric are silently skipped; the
// result is sorted ascending by the submitted order.
func getHTTPPartsFromPostFields(r *http.Request) []dataprovider.HTTPPart {
	var result []dataprovider.HTTPPart

	names := r.Form["http_part_name"]
	files := r.Form["http_part_file"]
	headers := r.Form["http_part_headers"]
	bodies := r.Form["http_part_body"]
	orders := r.Form["http_part_order"]

	for idx, partName := range names {
		if partName != "" {
			order, err := strconv.Atoi(orders[idx])
			if err == nil {
				filePath := files[idx]
				body := bodies[idx]
				// headers for one part arrive newline-delimited as "Key: Value"
				concatHeaders := getSliceFromDelimitedValues(headers[idx], "\n")
				// shadows the outer form slice from here on
				var headers []dataprovider.KeyValue
				for _, h := range concatHeaders {
					values := strings.SplitN(h, ":", 2)
					if len(values) > 1 {
						headers = append(headers, dataprovider.KeyValue{
							Key:   strings.TrimSpace(values[0]),
							Value: strings.TrimSpace(values[1]),
						})
					}
				}
				result = append(result, dataprovider.HTTPPart{
					Name:     partName,
					Filepath: filePath,
					Headers:  headers,
					Body:     body,
					Order:    order,
				})
			}
		}
	}

	sort.Slice(result, func(i, j int) bool {
		return result[i].Order < result[j].Order
	})
	return result
}

// updateRepeaterFormActionFields flattens the JS repeater field names
// ("http_headers[N][http_header_key]" etc.) into the flat, index-parallel
// keys the get*FromPostFields helpers above expect.
func updateRepeaterFormActionFields(r *http.Request) {
	for k := range r.Form {
		if hasPrefixAndSuffix(k, "http_headers[", "][http_header_key]") {
			base, _ := strings.CutSuffix(k, "[http_header_key]")
			r.Form.Add("http_header_key", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("http_header_value", strings.TrimSpace(r.Form.Get(base+"[http_header_value]")))
			continue
		}
		if hasPrefixAndSuffix(k, "query_parameters[", "][http_query_key]") {
			base, _ := strings.CutSuffix(k, "[http_query_key]")
			r.Form.Add("http_query_key", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("http_query_value", strings.TrimSpace(r.Form.Get(base+"[http_query_value]")))
			continue
		}
		if hasPrefixAndSuffix(k, "multipart_body[", "][http_part_name]") {
			base, _ := strings.CutSuffix(k, "[http_part_name]")
			// the repeater index N doubles as the part order
			order, _ := strings.CutPrefix(k, "multipart_body[")
			order, _ = strings.CutSuffix(order, "][http_part_name]")
			r.Form.Add("http_part_name", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("http_part_file", strings.TrimSpace(r.Form.Get(base+"[http_part_file]")))
			r.Form.Add("http_part_headers", strings.TrimSpace(r.Form.Get(base+"[http_part_headers]")))
			r.Form.Add("http_part_body", strings.TrimSpace(r.Form.Get(base+"[http_part_body]")))
			r.Form.Add("http_part_order", order)
			continue
		}
		if hasPrefixAndSuffix(k, "env_vars[", "][cmd_env_key]") {
			base, _ := strings.CutSuffix(k, "[cmd_env_key]")
			r.Form.Add("cmd_env_key", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("cmd_env_value", strings.TrimSpace(r.Form.Get(base+"[cmd_env_value]")))
			continue
		}
		if hasPrefixAndSuffix(k, "data_retention[", "][folder_retention_path]") {
			base, _ := strings.CutSuffix(k, "[folder_retention_path]")
			r.Form.Add("folder_retention_path", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("folder_retention_val", strings.TrimSpace(r.Form.Get(base+"[folder_retention_val]")))
			// per-row options are re-keyed by the flat index just appended
			r.Form["folder_retention_options"+strconv.Itoa(len(r.Form["folder_retention_path"])-1)] =
				r.Form[base+"[folder_retention_options][]"]
			continue
		}
		if hasPrefixAndSuffix(k, "fs_rename[", "][fs_rename_source]") {
			base, _ := strings.CutSuffix(k, "[fs_rename_source]")
			r.Form.Add("fs_rename_source", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("fs_rename_target", strings.TrimSpace(r.Form.Get(base+"[fs_rename_target]")))
			r.Form["fs_rename_options"+strconv.Itoa(len(r.Form["fs_rename_source"])-1)] =
				r.Form[base+"[fs_rename_options][]"]
			continue
		}
		if hasPrefixAndSuffix(k, "fs_copy[", "][fs_copy_source]") {
			base, _ := strings.CutSuffix(k, "[fs_copy_source]")
			r.Form.Add("fs_copy_source", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("fs_copy_target", strings.TrimSpace(r.Form.Get(base+"[fs_copy_target]")))
			continue
		}
	}
}

// getEventActionOptionsFromPostFields parses every option section of the
// event-action form (HTTP, command, email, retention, filesystem, password
// expiration, inactivity, IDP) into a BaseEventActionOptions.
// Numeric fields that must be present return an error when unparsable; the
// inactivity thresholds are best-effort and default to 0.
func getEventActionOptionsFromPostFields(r *http.Request) (dataprovider.BaseEventActionOptions, error) {
	// flatten repeater field names before reading any repeated field
	updateRepeaterFormActionFields(r)
	httpTimeout, err := strconv.Atoi(r.Form.Get("http_timeout"))
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, fmt.Errorf("invalid http timeout: %w", err)
	}
	cmdTimeout, err := strconv.Atoi(r.Form.Get("cmd_timeout"))
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, fmt.Errorf("invalid command timeout: %w", err)
	}
	foldersRetention, err := getFoldersRetentionFromPostFields(r)
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, err
	}
	fsActionType, err := strconv.Atoi(r.Form.Get("fs_action_type"))
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, fmt.Errorf("invalid fs action type: %w", err)
	}
	pwdExpirationThreshold, err := strconv.Atoi(r.Form.Get("pwd_expiration_threshold"))
	if err != nil {
		return dataprovider.BaseEventActionOptions{}, fmt.Errorf("invalid password expiration threshold: %w", err)
	}
	var disableThreshold, deleteThreshold int
	if val, err := strconv.Atoi(r.Form.Get("inactivity_disable_threshold")); err == nil {
		disableThreshold = val
	}
	if val, err := strconv.Atoi(r.Form.Get("inactivity_delete_threshold")); err == nil {
		deleteThreshold = val
	}
	var emailAttachments []string
	if r.Form.Get("email_attachments") != "" {
		emailAttachments = getSliceFromDelimitedValues(r.Form.Get("email_attachments"), ",")
	}
	var cmdArgs []string
	if r.Form.Get("cmd_arguments") != "" {
		cmdArgs = getSliceFromDelimitedValues(r.Form.Get("cmd_arguments"), ",")
	}
	idpMode := 0
	if r.Form.Get("idp_mode") == "1" {
		idpMode = 1
	}
	emailContentType := 0
	if r.Form.Get("email_content_type") == "1" {
		emailContentType = 1
	}
	options := dataprovider.BaseEventActionOptions{
		HTTPConfig: dataprovider.EventActionHTTPConfig{
			Endpoint:        strings.TrimSpace(r.Form.Get("http_endpoint")),
			Username:        strings.TrimSpace(r.Form.Get("http_username")),
			Password:        getSecretFromFormField(r, "http_password"),
			Headers:         getKeyValsFromPostFields(r, "http_header_key", "http_header_value"),
			Timeout:         httpTimeout,
			SkipTLSVerify:   r.Form.Get("http_skip_tls_verify") != "",
			Method:          r.Form.Get("http_method"),
			QueryParameters: getKeyValsFromPostFields(r, "http_query_key", "http_query_value"),
			Body:            r.Form.Get("http_body"),
			Parts:           getHTTPPartsFromPostFields(r),
		},
		CmdConfig: dataprovider.EventActionCommandConfig{
			Cmd:     strings.TrimSpace(r.Form.Get("cmd_path")),
			Args:    cmdArgs,
			Timeout: cmdTimeout,
			EnvVars: getKeyValsFromPostFields(r, "cmd_env_key", "cmd_env_value"),
		},
		EmailConfig: dataprovider.EventActionEmailConfig{
			Recipients:  getSliceFromDelimitedValues(r.Form.Get("email_recipients"), ","),
			Bcc:         getSliceFromDelimitedValues(r.Form.Get("email_bcc"), ","),
			Subject:     r.Form.Get("email_subject"),
			ContentType: emailContentType,
			Body:        r.Form.Get("email_body"),
			Attachments: emailAttachments,
		},
		RetentionConfig: dataprovider.EventActionDataRetentionConfig{
			Folders: foldersRetention,
		},
		FsConfig: dataprovider.EventActionFilesystemConfig{
			Type:    fsActionType,
			Renames: getRenameConfigsFromPostFields(r),
			Deletes: getSliceFromDelimitedValues(r.Form.Get("fs_delete_paths"), ","),
			MkDirs:  getSliceFromDelimitedValues(r.Form.Get("fs_mkdir_paths"), ","),
			Exist:   getSliceFromDelimitedValues(r.Form.Get("fs_exist_paths"), ","),
			Copy:    getKeyValsFromPostFields(r, "fs_copy_source", "fs_copy_target"),
			Compress: dataprovider.EventActionFsCompress{
				Name:  strings.TrimSpace(r.Form.Get("fs_compress_name")),
				Paths: getSliceFromDelimitedValues(r.Form.Get("fs_compress_paths"), ","),
			},
		},
		PwdExpirationConfig: dataprovider.EventActionPasswordExpiration{
			Threshold: pwdExpirationThreshold,
		},
		UserInactivityConfig: dataprovider.EventActionUserInactivity{
			DisableThreshold: disableThreshold,
			DeleteThreshold:  deleteThreshold,
		},
		IDPConfig: dataprovider.EventActionIDPAccountCheck{
			Mode:          idpMode,
			TemplateUser:  strings.TrimSpace(r.Form.Get("idp_user")),
			TemplateAdmin: strings.TrimSpace(r.Form.Get("idp_admin")),
		},
	}
	return options, nil
}

// getEventActionFromPostFields parses the event-action form into a
// BaseEventAction; the form must already be parseable or an i18n-wrapped
// error is returned.
func getEventActionFromPostFields(r *http.Request) (dataprovider.BaseEventAction, error) {
	err := r.ParseForm()
	if err != nil {
		return dataprovider.BaseEventAction{}, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	actionType, err := strconv.Atoi(r.Form.Get("type"))
	if err != nil {
		return dataprovider.BaseEventAction{}, fmt.Errorf("invalid action type: %w", err)
	}
	options, err := getEventActionOptionsFromPostFields(r)
	if err != nil {
		return dataprovider.BaseEventAction{}, err
	}
	action := dataprovider.BaseEventAction{
		Name:        strings.TrimSpace(r.Form.Get("name")),
		Description: r.Form.Get("description"),
		Type:        actionType,
		Options:     options,
	}
	return action, nil
}

// getIDPLoginEventFromPostField maps the "idp_login_event" form value to its
// numeric code; anything other than "1" or "2" falls back to 0.
func getIDPLoginEventFromPostField(r *http.Request) int {
	switch r.Form.Get("idp_login_event") {
	case "1":
		return 1
	case "2":
		return 2
	default:
		return 0
	}
}

func
// getEventRuleConditionsFromPostFields parses the schedule rows, the
// name/group/role/path pattern filters and the size/status limits of the
// event-rule form into an EventConditions value.
getEventRuleConditionsFromPostFields(r *http.Request) (dataprovider.EventConditions, error) {
	var schedules []dataprovider.Schedule
	var names, groupNames, roleNames, fsPaths []dataprovider.ConditionPattern

	scheduleHours := r.Form["schedule_hour"]
	scheduleDayOfWeeks := r.Form["schedule_day_of_week"]
	scheduleDayOfMonths := r.Form["schedule_day_of_month"]
	scheduleMonths := r.Form["schedule_month"]

	// cron-like schedule rows; a row with an empty hour field is skipped
	for idx, hour := range scheduleHours {
		if hour != "" {
			schedules = append(schedules, dataprovider.Schedule{
				Hours:      hour,
				DayOfWeek:  scheduleDayOfWeeks[idx],
				DayOfMonth: scheduleDayOfMonths[idx],
				Month:      scheduleMonths[idx],
			})
		}
	}

	for idx, name := range r.Form["name_pattern"] {
		if name != "" {
			names = append(names, dataprovider.ConditionPattern{
				Pattern:      name,
				InverseMatch: r.Form["type_name_pattern"][idx] == inversePatternType,
			})
		}
	}

	for idx, name := range r.Form["group_name_pattern"] {
		if name != "" {
			groupNames = append(groupNames, dataprovider.ConditionPattern{
				Pattern:      name,
				InverseMatch: r.Form["type_group_name_pattern"][idx] == inversePatternType,
			})
		}
	}

	for idx, name := range r.Form["role_name_pattern"] {
		if name != "" {
			roleNames = append(roleNames, dataprovider.ConditionPattern{
				Pattern:      name,
				InverseMatch: r.Form["type_role_name_pattern"][idx] == inversePatternType,
			})
		}
	}

	for idx, name := range r.Form["fs_path_pattern"] {
		if name != "" {
			fsPaths = append(fsPaths, dataprovider.ConditionPattern{
				Pattern:      name,
				InverseMatch: r.Form["type_fs_path_pattern"][idx] == inversePatternType,
			})
		}
	}

	// sizes accept human-readable byte suffixes via util.ParseBytes
	minFileSize, err := util.ParseBytes(r.Form.Get("fs_min_size"))
	if err != nil {
		return dataprovider.EventConditions{}, util.NewI18nError(fmt.Errorf("invalid min file size: %w", err), util.I18nErrorInvalidMinSize)
	}
	maxFileSize, err := util.ParseBytes(r.Form.Get("fs_max_size"))
	if err != nil {
		return dataprovider.EventConditions{}, util.NewI18nError(fmt.Errorf("invalid max file size: %w", err), util.I18nErrorInvalidMaxSize)
	}
	// non-numeric status values are silently ignored
	var eventStatuses []int
	for _, s := range r.Form["fs_statuses"] {
		status, err := strconv.ParseInt(s, 10, 32)
		if err == nil {
			eventStatuses = append(eventStatuses, int(status))
		}
	}
	conditions := dataprovider.EventConditions{
		FsEvents:       r.Form["fs_events"],
		ProviderEvents: r.Form["provider_events"],
		IDPLoginEvent:  getIDPLoginEventFromPostField(r),
		Schedules:      schedules,
		Options: dataprovider.ConditionOptions{
			Names:               names,
			GroupNames:          groupNames,
			RoleNames:           roleNames,
			FsPaths:             fsPaths,
			Protocols:           r.Form["fs_protocols"],
			EventStatuses:       eventStatuses,
			ProviderObjects:     r.Form["provider_objects"],
			MinFileSize:         minFileSize,
			MaxFileSize:         maxFileSize,
			ConcurrentExecution: r.Form.Get("concurrent_execution") != "",
		},
	}
	return conditions, nil
}

// getEventRuleActionsFromPostFields builds the ordered action list of an
// event rule. Rows with a non-numeric order are skipped; the stored order is
// 1-based (submitted order + 1).
func getEventRuleActionsFromPostFields(r *http.Request) []dataprovider.EventAction {
	var actions []dataprovider.EventAction

	names := r.Form["action_name"]
	orders := r.Form["action_order"]

	for idx, name := range names {
		if name != "" {
			order, err := strconv.Atoi(orders[idx])
			if err == nil {
				opts := r.Form["action_options"+strconv.Itoa(idx)]
				actions = append(actions, dataprovider.EventAction{
					BaseEventAction: dataprovider.BaseEventAction{
						Name: name,
					},
					Order: order + 1,
					Options: dataprovider.EventActionOptions{
						// checkbox codes: 1=failure action, 2=stop on failure, 3=sync
						IsFailureAction: slices.Contains(opts, "1"),
						StopOnFailure:   slices.Contains(opts, "2"),
						ExecuteSync:     slices.Contains(opts, "3"),
					},
				})
			}
		}
	}

	return actions
}

// updateRepeaterFormRuleFields flattens the JS repeater field names of the
// event-rule form ("schedules[N][schedule_hour]" etc.) into the flat keys
// read by getEventRuleConditionsFromPostFields/getEventRuleActionsFromPostFields.
func updateRepeaterFormRuleFields(r *http.Request) {
	for k := range r.Form {
		if hasPrefixAndSuffix(k, "schedules[", "][schedule_hour]") {
			base, _ := strings.CutSuffix(k, "[schedule_hour]")
			r.Form.Add("schedule_hour", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("schedule_day_of_week", strings.TrimSpace(r.Form.Get(base+"[schedule_day_of_week]")))
			r.Form.Add("schedule_day_of_month", strings.TrimSpace(r.Form.Get(base+"[schedule_day_of_month]")))
			r.Form.Add("schedule_month", strings.TrimSpace(r.Form.Get(base+"[schedule_month]")))
			continue
		}
		if hasPrefixAndSuffix(k, "name_filters[", "][name_pattern]") {
			base, _ := strings.CutSuffix(k, "[name_pattern]")
			r.Form.Add("name_pattern", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("type_name_pattern", strings.TrimSpace(r.Form.Get(base+"[type_name_pattern]")))
			continue
		}
		if hasPrefixAndSuffix(k, "group_name_filters[", "][group_name_pattern]") {
			base, _ := strings.CutSuffix(k, "[group_name_pattern]")
			r.Form.Add("group_name_pattern", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("type_group_name_pattern", strings.TrimSpace(r.Form.Get(base+"[type_group_name_pattern]")))
			continue
		}
		if hasPrefixAndSuffix(k, "role_name_filters[", "][role_name_pattern]") {
			base, _ := strings.CutSuffix(k, "[role_name_pattern]")
			r.Form.Add("role_name_pattern", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("type_role_name_pattern", strings.TrimSpace(r.Form.Get(base+"[type_role_name_pattern]")))
			continue
		}
		if hasPrefixAndSuffix(k, "path_filters[", "][fs_path_pattern]") {
			base, _ := strings.CutSuffix(k, "[fs_path_pattern]")
			r.Form.Add("fs_path_pattern", strings.TrimSpace(r.Form.Get(k)))
			r.Form.Add("type_fs_path_pattern", strings.TrimSpace(r.Form.Get(base+"[type_fs_path_pattern]")))
			continue
		}
		if hasPrefixAndSuffix(k, "actions[", "][action_name]") {
			base, _ := strings.CutSuffix(k, "[action_name]")
			// the repeater index N doubles as the action order
			order, _ := strings.CutPrefix(k, "actions[")
			order, _ = strings.CutSuffix(order, "][action_name]")
			r.Form.Add("action_name", strings.TrimSpace(r.Form.Get(k)))
			r.Form["action_options"+strconv.Itoa(len(r.Form["action_name"])-1)] = r.Form[base+"[action_options][]"]
			r.Form.Add("action_order", order)
			continue
		}
	}
}

// getEventRuleFromPostFields parses the complete event-rule form.
func getEventRuleFromPostFields(r *http.Request) (dataprovider.EventRule, error) {
	err := r.ParseForm()
	if err != nil {
		return dataprovider.EventRule{}, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	// flatten repeater names before any repeated field is read
	updateRepeaterFormRuleFields(r)
	status, err := strconv.Atoi(r.Form.Get("status"))
	if err != nil {
		return dataprovider.EventRule{}, fmt.Errorf("invalid status: %w", err)
	}
	trigger, err := strconv.Atoi(r.Form.Get("trigger"))
	if err != nil {
		return dataprovider.EventRule{}, fmt.Errorf("invalid trigger: %w", err)
	}
	conditions, err := getEventRuleConditionsFromPostFields(r)
	if err != nil {
		return dataprovider.EventRule{}, err
	}
	rule := dataprovider.EventRule{
		Name:        strings.TrimSpace(r.Form.Get("name")),
		Status:      status,
		Description: r.Form.Get("description"),
		Trigger:     trigger,
		Conditions:  conditions,
		Actions:     getEventRuleActionsFromPostFields(r),
	}
	return rule, nil
}

// getRoleFromPostFields parses the (name, description) role form.
func getRoleFromPostFields(r *http.Request) (dataprovider.Role, error) {
	err := r.ParseForm()
	if err != nil {
		return dataprovider.Role{}, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}

	return dataprovider.Role{
		Name:        strings.TrimSpace(r.Form.Get("name")),
		Description: r.Form.Get("description"),
	}, nil
}

// getIPListEntryFromPostFields parses an IP list entry form. The "mode"
// field is only meaningful for defender lists; other list types are forced
// to mode 1.
func getIPListEntryFromPostFields(r *http.Request, listType dataprovider.IPListType) (dataprovider.IPListEntry, error) {
	err := r.ParseForm()
	if err != nil {
		return dataprovider.IPListEntry{}, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	var mode int
	if listType == dataprovider.IPListTypeDefender {
		mode, err = strconv.Atoi(r.Form.Get("mode"))
		if err != nil {
			return dataprovider.IPListEntry{}, fmt.Errorf("invalid mode: %w", err)
		}
	} else {
		mode = 1
	}
	// NOTE(review): protocol codes are summed, so a duplicated form value
	// would be counted twice; presumably the codes are bit flags emitted once
	// each by the UI — confirm against the template.
	protocols := 0
	for _, proto := range r.Form["protocols"] {
		p, err := strconv.Atoi(proto)
		if err == nil {
			protocols += p
		}
	}

	return dataprovider.IPListEntry{
		IPOrNet:     strings.TrimSpace(r.Form.Get("ipornet")),
		Mode:        mode,
		Protocols:   protocols,
		Description: r.Form.Get("description"),
	}, nil
}

// getSFTPConfigsFromPostFields reads the SFTPD algorithm overrides straight
// from the multi-select form fields.
func getSFTPConfigsFromPostFields(r *http.Request) *dataprovider.SFTPDConfigs {
	return &dataprovider.SFTPDConfigs{
		HostKeyAlgos:   r.Form["sftp_host_key_algos"],
		PublicKeyAlgos: r.Form["sftp_pub_key_algos"],
		KexAlgorithms:  r.Form["sftp_kex_algos"],
		Ciphers:        r.Form["sftp_ciphers"],
		MACs:           r.Form["sftp_macs"],
	}
}

// getACMEConfigsFromPostFields parses the ACME settings; an unparsable port
// falls back to 80 and the selected protocols are encoded as a bitmask
// (1=HTTP, 2=FTP, 4=WebDAV codes summed from values "1"/"2"/"3").
func getACMEConfigsFromPostFields(r *http.Request) *dataprovider.ACMEConfigs {
	port, err := strconv.Atoi(r.Form.Get("acme_port"))
	if err != nil {
		port = 80
	}
	var protocols int
	for _, val := range r.Form["acme_protocols"] {
		switch val {
		case "1":
			protocols++
		case "2":
			protocols += 2
		case "3":
			protocols += 4
		}
	}

	return &dataprovider.ACMEConfigs{
		Domain:          strings.TrimSpace(r.Form.Get("acme_domain")),
		Email:           strings.TrimSpace(r.Form.Get("acme_email")),
		HTTP01Challenge: dataprovider.ACMEHTTP01Challenge{Port: port},
		Protocols:       protocols,
	}
}

// getSMTPConfigsFromPostFields parses the SMTP settings; unparsable numeric
// fields fall back to defaults (port 587, plain auth, no encryption).
func getSMTPConfigsFromPostFields(r *http.Request) *dataprovider.SMTPConfigs {
	port, err := strconv.Atoi(r.Form.Get("smtp_port"))
	if err != nil {
		port = 587
	}
	authType, err := strconv.Atoi(r.Form.Get("smtp_auth"))
	if err != nil {
		authType = 0
	}
	encryption, err := strconv.Atoi(r.Form.Get("smtp_encryption"))
	if err != nil {
		encryption = 0
	}
	debug := 0
	if r.Form.Get("smtp_debug") != "" {
		debug = 1
	}
	oauth2Provider := 0
	if r.Form.Get("smtp_oauth2_provider") == "1" {
		oauth2Provider = 1
	}
	return &dataprovider.SMTPConfigs{
		Host:       strings.TrimSpace(r.Form.Get("smtp_host")),
		Port:       port,
		From:       strings.TrimSpace(r.Form.Get("smtp_from")),
		User:       strings.TrimSpace(r.Form.Get("smtp_username")),
		Password:   getSecretFromFormField(r, "smtp_password"),
		AuthType:   authType,
		Encryption: encryption,
		Domain:     strings.TrimSpace(r.Form.Get("smtp_domain")),
		Debug:      debug,
		OAuth2: dataprovider.SMTPOAuth2{
			Provider:     oauth2Provider,
			Tenant:       strings.TrimSpace(r.Form.Get("smtp_oauth2_tenant")),
			ClientID:     strings.TrimSpace(r.Form.Get("smtp_oauth2_client_id")),
			ClientSecret: getSecretFromFormField(r, "smtp_oauth2_client_secret"),
			RefreshToken: getSecretFromFormField(r, "smtp_oauth2_refresh_token"),
		},
	}
}

// getImageInputBytes returns the uploaded image for fieldName, the existing
// defaultVal when no file was uploaded, or nil when the matching "remove"
// checkbox was set and no replacement was uploaded.
func getImageInputBytes(r *http.Request, fieldName, removeFieldName string, defaultVal []byte) ([]byte, error) {
	var result []byte
	remove := r.Form.Get(removeFieldName)
	if remove == "" || remove == "0" {
		// keep the current image unless a new upload overrides it below
		result = defaultVal
	}
	f, _, err := r.FormFile(fieldName)
	if err != nil {
		if errors.Is(err, http.ErrMissingFile) {
			return result, nil
		}
		return nil, err
	}
	defer f.Close()

	return io.ReadAll(f)
}

// getBrandingConfigFromPostFields merges the posted branding form with the
// existing config (used as the fallback for images that were not re-uploaded).
func getBrandingConfigFromPostFields(r *http.Request, config *dataprovider.BrandingConfigs) (
	*dataprovider.BrandingConfigs, error,
) {
	if config == nil {
		config = &dataprovider.BrandingConfigs{}
	}
	adminLogo, err := getImageInputBytes(r, "branding_webadmin_logo", "branding_webadmin_logo_remove", config.WebAdmin.Logo)
	if err != nil {
		return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	adminFavicon, err := getImageInputBytes(r, "branding_webadmin_favicon", "branding_webadmin_favicon_remove",
		config.WebAdmin.Favicon)
	if err != nil {
		return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	clientLogo, err := getImageInputBytes(r, "branding_webclient_logo", "branding_webclient_logo_remove",
		config.WebClient.Logo)
	if err != nil {
		return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	clientFavicon, err := getImageInputBytes(r, "branding_webclient_favicon", "branding_webclient_favicon_remove",
		config.WebClient.Favicon)
	if err != nil {
		return nil, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}

	branding := &dataprovider.BrandingConfigs{
		WebAdmin: dataprovider.BrandingConfig{
			Name:           strings.TrimSpace(r.Form.Get("branding_webadmin_name")),
			ShortName:      strings.TrimSpace(r.Form.Get("branding_webadmin_short_name")),
			Logo:           adminLogo,
			Favicon:        adminFavicon,
			DisclaimerName:
			strings.TrimSpace(r.Form.Get("branding_webadmin_disclaimer_name")),
			DisclaimerURL:  strings.TrimSpace(r.Form.Get("branding_webadmin_disclaimer_url")),
		},
		WebClient: dataprovider.BrandingConfig{
			Name:           strings.TrimSpace(r.Form.Get("branding_webclient_name")),
			ShortName:      strings.TrimSpace(r.Form.Get("branding_webclient_short_name")),
			Logo:           clientLogo,
			Favicon:        clientFavicon,
			DisclaimerName: strings.TrimSpace(r.Form.Get("branding_webclient_disclaimer_name")),
			DisclaimerURL:  strings.TrimSpace(r.Form.Get("branding_webclient_disclaimer_url")),
		},
	}
	return branding, nil
}

// handleWebAdminForgotPwd renders the forgot-password page; hidden (404)
// when SMTP is not configured since the reset email could not be sent.
func (s *httpdServer) handleWebAdminForgotPwd(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	if !smtp.IsEnabled() {
		s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
		return
	}
	s.renderForgotPwdPage(w, r, nil)
}

// handleWebAdminForgotPwdPost validates the CSRF token and triggers the
// password-reset email, then redirects to the reset page.
func (s *httpdServer) handleWebAdminForgotPwdPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	err := r.ParseForm()
	if err != nil {
		s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	err = handleForgotPassword(r, r.Form.Get("username"), true)
	if err != nil {
		s.renderForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric))
		return
	}
	http.Redirect(w, r, webAdminResetPwdPath, http.StatusFound)
}

// handleWebAdminPasswordReset renders the reset page; 404 without SMTP.
func (s *httpdServer) handleWebAdminPasswordReset(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	if !smtp.IsEnabled() {
		s.renderNotFoundPage(w, r, errors.New("this page does not exist"))
		return
	}
	s.renderResetPwdPage(w, r, nil)
}

// handleWebAdminTwoFactor renders the TOTP challenge page.
func (s *httpdServer) handleWebAdminTwoFactor(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderTwoFactorPage(w, r, nil)
}

// handleWebAdminTwoFactorRecovery renders the recovery-code page.
func (s *httpdServer) handleWebAdminTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderTwoFactorRecoveryPage(w, r, nil)
}

// handleWebAdminMFA renders the MFA configuration page.
func (s *httpdServer) handleWebAdminMFA(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderMFAPage(w, r)
}

// handleWebAdminProfile renders the profile page.
func (s *httpdServer) handleWebAdminProfile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderProfilePage(w, r, nil)
}

// handleWebAdminChangePwd renders the change-password page.
func (s *httpdServer) handleWebAdminChangePwd(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderChangePasswordPage(w, r, nil)
}

// handleWebAdminProfilePost updates the self-editable admin fields
// (API-key auth flag, email, description) for the logged-in admin.
func (s *httpdServer) handleWebAdminProfilePost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	err := r.ParseForm()
	if err != nil {
		s.renderProfilePage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderProfilePage(w, r, util.NewI18nError(err, util.I18nErrorInvalidToken))
		return
	}
	admin, err := dataprovider.AdminExists(claims.Username)
	if err != nil {
		s.renderProfilePage(w, r, err)
		return
	}
	admin.Filters.AllowAPIKeyAuth = r.Form.Get("allow_api_key_auth") != ""
	admin.Email = r.Form.Get("email")
	admin.Description = r.Form.Get("description")
	err = dataprovider.UpdateAdmin(&admin, dataprovider.ActionExecutorSelf, ipAddr, admin.Role)
	if err != nil {
		s.renderProfilePage(w, r, err)
		return
	}
	s.renderMessagePage(w, r, util.I18nProfileTitle, http.StatusOK, nil, util.I18nProfileUpdated)
}

// handleWebMaintenance renders the backup/restore page.
func (s *httpdServer) handleWebMaintenance(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderMaintenancePage(w, r, nil)
}

// handleWebRestore restores a backup uploaded via the maintenance page.
// The multipart form must be parsed before the CSRF check since the token is
// read from the form; the uploaded file must be non-empty.
func (s *httpdServer) handleWebRestore(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, MaxRestoreSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	err = r.ParseMultipartForm(MaxRestoreSize)
	if err != nil {
		s.renderMaintenancePage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	restoreMode, err := strconv.Atoi(r.Form.Get("mode"))
	if err != nil {
		s.renderMaintenancePage(w, r, err)
		return
	}
	scanQuota, err := strconv.Atoi(r.Form.Get("quota"))
	if err != nil {
		s.renderMaintenancePage(w, r, err)
		return
	}
	backupFile, _, err := r.FormFile("backup_file")
	if err != nil {
		s.renderMaintenancePage(w, r, util.NewI18nError(err, util.I18nErrorBackupFile))
		return
	}
	defer backupFile.Close()

	backupContent, err := io.ReadAll(backupFile)
	if err != nil || len(backupContent) == 0 {
		if len(backupContent) == 0 {
			err = errors.New("backup file size must be greater than 0")
		}
		s.renderMaintenancePage(w, r, util.NewI18nError(err, util.I18nErrorBackupFile))
		return
	}

	if err := restoreBackup(backupContent, "", scanQuota, restoreMode, claims.Username, ipAddr, claims.Role); err != nil {
		s.renderMaintenancePage(w, r, util.NewI18nError(err, util.I18nErrorRestore))
		return
	}

	s.renderMessagePage(w, r, util.I18nMaintenanceTitle,
		http.StatusOK, nil, util.I18nBackupOK)
}

// getAllAdmins streams the full admin list as a JSON array, paging through
// the provider with defaultQueryLimit.
func getAllAdmins(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden)
		return
	}

	dataGetter := func(limit, offset int) ([]byte, int, error) {
		results, err := dataprovider.GetAdmins(limit, offset, dataprovider.OrderASC)
		if err != nil {
			return nil, 0, err
		}
		data, err := json.Marshal(results)
		return data, len(results), err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleGetWebAdmins renders the admins list page.
func (s *httpdServer) handleGetWebAdmins(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	data := s.getBasePageData(util.I18nAdminsTitle, webAdminsPath, w, r)
	renderAdminTemplate(w, templateAdmins, data)
}

// handleWebAdminSetupGet renders the initial-setup page, or redirects to the
// login page once at least one admin exists.
func (s *httpdServer) handleWebAdminSetupGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	if dataprovider.HasAdmin() {
		http.Redirect(w, r, webAdminLoginPath, http.StatusFound)
		return
	}
	s.renderAdminSetupPage(w, r, "", nil)
}

// handleWebAddAdminGet renders the add-admin page with sensible defaults
// (enabled, full permissions).
func (s *httpdServer) handleWebAddAdminGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	admin := &dataprovider.Admin{
		Status:      1,
		Permissions: []string{dataprovider.PermAdminAny},
	}
	s.renderAddUpdateAdminPage(w, r, admin, nil, true)
}

// handleWebUpdateAdminGet renders the edit page for an existing admin.
func (s *httpdServer) handleWebUpdateAdminGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	username := getURLParam(r, "username")
	admin, err := dataprovider.AdminExists(username)
	if err == nil {
		s.renderAddUpdateAdminPage(w, r, &admin, nil, false)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebAddAdminPost creates a new admin from the submitted form.
func (s *httpdServer) handleWebAddAdminPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	admin, err := getAdminFromPostFields(r)
	if err != nil {
		s.renderAddUpdateAdminPage(w, r, &admin, err, true)
		return
	}
	if admin.Password == "" {
		// Administrators can be used with OpenID Connect or for authentication
		// via API key, in these cases the password is not necessary, we create
		// a non-usable one. This feature is only useful for WebAdmin, in REST
		// API you can create an unusable password externally.
		admin.Password = util.GenerateUniqueID()
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	err = dataprovider.AddAdmin(&admin, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderAddUpdateAdminPage(w, r, &admin, err, true)
		return
	}
	http.Redirect(w, r, webAdminsPath, http.StatusSeeOther)
}

// handleWebUpdateAdminPost updates an existing admin. Password, TOTP config
// and recovery codes are preserved when not re-submitted, and an admin
// editing itself may not change its own permissions, status or role.
func (s *httpdServer) handleWebUpdateAdminPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	username := getURLParam(r, "username")
	admin, err := dataprovider.AdminExists(username)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}

	updatedAdmin, err := getAdminFromPostFields(r)
	if err != nil {
		s.renderAddUpdateAdminPage(w, r, &updatedAdmin, err, false)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	updatedAdmin.ID = admin.ID
	updatedAdmin.Username = admin.Username
	if updatedAdmin.Password == "" {
		updatedAdmin.Password = admin.Password
	}
	updatedAdmin.Filters.TOTPConfig = admin.Filters.TOTPConfig
	updatedAdmin.Filters.RecoveryCodes = admin.Filters.RecoveryCodes
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderAddUpdateAdminPage(w, r, &updatedAdmin, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken), false)
		return
	}
	if username == claims.Username {
		// self-edit guards: no privilege or status changes on yourself
		if !util.SlicesEqual(admin.Permissions, updatedAdmin.Permissions) {
			s.renderAddUpdateAdminPage(w, r, &updatedAdmin,
				util.NewI18nError(errors.New("you cannot change your permissions"),
					util.I18nErrorAdminSelfPerms,
				), false)
			return
		}
		if updatedAdmin.Status == 0 {
			s.renderAddUpdateAdminPage(w, r, &updatedAdmin,
				util.NewI18nError(errors.New("you cannot disable yourself"),
					util.I18nErrorAdminSelfDisable,
				), false)
			return
		}
		if updatedAdmin.Role != claims.Role {
			s.renderAddUpdateAdminPage(w, r, &updatedAdmin,
				util.NewI18nError(
					errors.New("you cannot add/change your role"),
					util.I18nErrorAdminSelfRole,
				), false)
			return
		}
		updatedAdmin.Filters.RequirePasswordChange = admin.Filters.RequirePasswordChange
		updatedAdmin.Filters.RequireTwoFactor = admin.Filters.RequireTwoFactor
	}
	err = dataprovider.UpdateAdmin(&updatedAdmin, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderAddUpdateAdminPage(w, r, &updatedAdmin, err, false)
		return
	}
	http.Redirect(w, r, webAdminsPath, http.StatusSeeOther)
}

// handleWebDefenderPage renders the defender (auto-blocklist) page.
func (s *httpdServer) handleWebDefenderPage(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	data := defenderHostsPage{
		basePage:         s.getBasePageData(util.I18nDefenderTitle, webDefenderPath, w, r),
		DefenderHostsURL: webDefenderHostsPath,
	}

	renderAdminTemplate(w, templateDefender,
data) +} + +func getAllUsers(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden) + return + } + + dataGetter := func(limit, offset int) ([]byte, int, error) { + results, err := dataprovider.GetUsers(limit, offset, dataprovider.OrderASC, claims.Role) + if err != nil { + return nil, 0, err + } + data, err := json.Marshal(results) + return data, len(results), err + } + + streamJSONArray(w, defaultQueryLimit, dataGetter) +} + +func (s *httpdServer) handleGetWebUsers(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) + return + } + data := s.getBasePageData(util.I18nUsersTitle, webUsersPath, w, r) + renderAdminTemplate(w, templateUsers, data) +} + +func (s *httpdServer) handleWebTemplateFolderGet(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + if r.URL.Query().Get("from") != "" { + name := r.URL.Query().Get("from") + folder, err := dataprovider.GetFolderByName(name) + if err == nil { + folder.FsConfig.SetEmptySecrets() + s.renderFolderPage(w, r, folder, folderPageModeTemplate, nil) + } else if errors.Is(err, util.ErrNotFound) { + s.renderNotFoundPage(w, r, err) + } else { + s.renderInternalServerErrorPage(w, r, err) + } + } else { + folder := vfs.BaseVirtualFolder{} + s.renderFolderPage(w, r, folder, folderPageModeTemplate, nil) + } +} + +func (s *httpdServer) handleWebTemplateFolderPost(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" 
{
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	templateFolder := vfs.BaseVirtualFolder{}
	err = r.ParseMultipartForm(maxRequestSize)
	if err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, http.StatusBadRequest, util.NewI18nError(err, util.I18nErrorInvalidForm), "")
		return
	}
	// remove any temp files written while parsing the multipart form
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}

	templateFolder.MappedPath = r.Form.Get("mapped_path")
	templateFolder.Description = r.Form.Get("description")
	fsConfig, err := getFsConfigFromPostFields(r)
	if err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, http.StatusBadRequest, err, "")
		return
	}
	templateFolder.FsConfig = fsConfig

	var dump dataprovider.BackupData

	// expand the template once per submitted folder entry; validate all
	// entries before anything is written to the provider.
	foldersFields := getFoldersForTemplate(r)
	for _, tmpl := range foldersFields {
		f := getFolderFromTemplate(templateFolder, tmpl)
		if err := dataprovider.ValidateFolder(&f); err != nil {
			s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, http.StatusBadRequest, err, "")
			return
		}
		dump.Folders = append(dump.Folders, f)
	}

	if len(dump.Folders) == 0 {
		s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, http.StatusBadRequest,
			util.NewI18nError(
				errors.New("no valid folder defined, unable to complete the requested action"),
				util.I18nErrorFolderTemplate,
			), "")
		return
	}
	if err = RestoreFolders(dump.Folders, "", 1, 0, claims.Username, ipAddr, claims.Role); err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateFolderTitle, getRespStatus(err), err, "")
		return
	}
	http.Redirect(w, r, webFoldersPath, http.StatusSeeOther)
}

// handleWebTemplateUserGet renders the user template page, optionally
// cloning an existing user (minus secrets and contact data) via "from".
func (s *httpdServer) handleWebTemplateUserGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w,
r.Body, maxRequestSize)
	tokenAdmin := getAdminFromToken(r)
	admin, err := dataprovider.AdminExists(tokenAdmin.Username)
	if err != nil {
		s.renderInternalServerErrorPage(w, r, fmt.Errorf("unable to get the admin %q: %w", tokenAdmin.Username, err))
		return
	}
	if r.URL.Query().Get("from") != "" {
		username := r.URL.Query().Get("from")
		user, err := dataprovider.UserExists(username, admin.Role)
		if err == nil {
			// strip secrets and personal data from the cloned user so the
			// template starts clean
			user.SetEmptySecrets()
			user.PublicKeys = nil
			user.Email = ""
			user.Filters.AdditionalEmails = nil
			user.Description = ""
			// apply the admin's default expiration only if the source user
			// has no expiration set
			if user.ExpirationDate == 0 && admin.Filters.Preferences.DefaultUsersExpiration > 0 {
				user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
			}
			s.renderUserPage(w, r, &user, userPageModeTemplate, nil, &admin)
		} else if errors.Is(err, util.ErrNotFound) {
			s.renderNotFoundPage(w, r, err)
		} else {
			s.renderInternalServerErrorPage(w, r, err)
		}
	} else {
		// fresh template: active user with full permissions on "/"
		user := dataprovider.User{BaseUser: sdk.BaseUser{
			Status: 1,
			Permissions: map[string][]string{
				"/": {dataprovider.PermAny},
			},
		}}
		if admin.Filters.Preferences.DefaultUsersExpiration > 0 {
			user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
		}
		s.renderUserPage(w, r, &user, userPageModeTemplate, nil, &admin)
	}
}

// handleWebTemplateUserPost creates one user per submitted template entry;
// every generated user is validated before any is persisted.
func (s *httpdServer) handleWebTemplateUserPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	templateUser, err := getUserFromPostFields(r)
	if err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateUserTitle, http.StatusBadRequest, err, "")
		return
	}
	ipAddr :=
util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}

	var dump dataprovider.BackupData

	// expand the template once per submitted user entry; validate all
	// entries before anything is written to the provider.
	userTmplFields := getUsersForTemplate(r)
	for _, tmpl := range userTmplFields {
		u := getUserFromTemplate(templateUser, tmpl)
		if err := dataprovider.ValidateUser(&u); err != nil {
			s.renderMessagePage(w, r, util.I18nTemplateUserTitle, http.StatusBadRequest, err, "")
			return
		}
		// role-bound admins may only create users within their own role
		if claims.Role != "" {
			u.Role = claims.Role
		}
		dump.Users = append(dump.Users, u)
	}

	if len(dump.Users) == 0 {
		s.renderMessagePage(w, r, util.I18nTemplateUserTitle,
			http.StatusBadRequest, util.NewI18nError(
				errors.New("no valid user defined, unable to complete the requested action"),
				util.I18nErrorUserTemplate,
			), "")
		return
	}
	if err = RestoreUsers(dump.Users, "", 1, 0, claims.Username, ipAddr, claims.Role); err != nil {
		s.renderMessagePage(w, r, util.I18nTemplateUserTitle, getRespStatus(err), err, "")
		return
	}
	http.Redirect(w, r, webUsersPath, http.StatusSeeOther)
}

// handleWebAddUserGet renders the "add user" page pre-filled with sensible
// defaults (active, full permissions on "/", admin's default expiration).
func (s *httpdServer) handleWebAddUserGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	tokenAdmin := getAdminFromToken(r)
	admin, err := dataprovider.AdminExists(tokenAdmin.Username)
	if err != nil {
		s.renderInternalServerErrorPage(w, r, fmt.Errorf("unable to get the admin %q: %w", tokenAdmin.Username, err))
		return
	}
	user := dataprovider.User{BaseUser: sdk.BaseUser{
		Status: 1,
		Permissions: map[string][]string{
			"/": {dataprovider.PermAny},
		}},
	}
	if admin.Filters.Preferences.DefaultUsersExpiration > 0 {
		user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(admin.Filters.Preferences.DefaultUsersExpiration)))
	}
	s.renderUserPage(w, r, &user, userPageModeAdd, nil, &admin)
}

// handleWebUpdateUserGet renders the edit page for an existing user.
func (s *httpdServer) handleWebUpdateUserGet(w
http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	username := getURLParam(r, "username")
	user, err := dataprovider.UserExists(username, claims.Role)
	if err == nil {
		s.renderUserPage(w, r, &user, userPageModeUpdate, nil, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebAddUserPost validates the submitted form and creates the user.
func (s *httpdServer) handleWebAddUserPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	user, err := getUserFromPostFields(r)
	if err != nil {
		s.renderUserPage(w, r, &user, userPageModeAdd, err, nil)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	// resolve placeholders (e.g. %username%) in the submitted fields
	user = getUserFromTemplate(user, userTemplateFields{
		Username:         user.Username,
		Password:         user.Password,
		PublicKeys:       user.PublicKeys,
		RequirePwdChange: user.Filters.RequirePasswordChange,
	})
	if claims.Role != "" {
		user.Role = claims.Role
	}
	// a brand-new user cannot carry pre-existing MFA state
	user.Filters.RecoveryCodes = nil
	user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{
		Enabled: false,
	}
	err = dataprovider.AddUser(&user, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderUserPage(w, r, &user, userPageModeAdd, err, nil)
		return
	}
	http.Redirect(w, r, webUsersPath, http.StatusSeeOther)
}

// handleWebUpdateUserPost applies the submitted form to an existing user,
// preserving server-side state the form does not carry.
func (s *httpdServer) handleWebUpdateUserPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	username := getURLParam(r, "username")
	user, err := dataprovider.UserExists(username, claims.Role)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	updatedUser, err := getUserFromPostFields(r)
	if err != nil {
		s.renderUserPage(w, r, &user, userPageModeUpdate, err, nil)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	// carry over server-side state the web form does not expose
	updatedUser.ID = user.ID
	updatedUser.Username = user.Username
	updatedUser.Filters.RecoveryCodes = user.Filters.RecoveryCodes
	updatedUser.Filters.TOTPConfig = user.Filters.TOTPConfig
	updatedUser.LastPasswordChange = user.LastPasswordChange
	updatedUser.SetEmptySecretsIfNil()
	// the form shows a redacted placeholder when the password is unchanged
	if updatedUser.Password == redactedSecret {
		updatedUser.Password = user.Password
	}
	updateEncryptedSecrets(&updatedUser.FsConfig, &user.FsConfig)

	updatedUser = getUserFromTemplate(updatedUser, userTemplateFields{
		Username:         updatedUser.Username,
		Password:         updatedUser.Password,
		PublicKeys:       updatedUser.PublicKeys,
		RequirePwdChange: updatedUser.Filters.RequirePasswordChange,
	})
	if claims.Role != "" {
		updatedUser.Role = claims.Role
	}

	err = dataprovider.UpdateUser(&updatedUser, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderUserPage(w, r, &updatedUser, userPageModeUpdate, err, nil)
		return
	}
	// optionally terminate the user's active sessions after the update
	if r.Form.Get("disconnect") != "" {
		disconnectUser(user.Username, claims.Username, claims.Role)
	}
	http.Redirect(w, r, webUsersPath, http.StatusSeeOther)
}

func (s *httpdServer)
handleWebGetStatus(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	data := statusPage{
		basePage: s.getBasePageData(util.I18nStatusTitle, webStatusPath, w, r),
		Status:   getServicesStatus(),
	}
	renderAdminTemplate(w, templateStatus, data)
}

// handleWebGetConnections renders the active sessions page.
func (s *httpdServer) handleWebGetConnections(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}

	data := s.getBasePageData(util.I18nSessionsTitle, webConnectionsPath, w, r)
	renderAdminTemplate(w, templateConnections, data)
}

// handleWebAddFolderGet renders an empty "add folder" page.
func (s *httpdServer) handleWebAddFolderGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderFolderPage(w, r, vfs.BaseVirtualFolder{}, folderPageModeAdd, nil)
}

// handleWebAddFolderPost validates the submitted form and creates the folder.
func (s *httpdServer) handleWebAddFolderPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	folder := vfs.BaseVirtualFolder{}
	err = r.ParseMultipartForm(maxRequestSize)
	if err != nil {
		s.renderFolderPage(w, r, folder, folderPageModeAdd, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	// remove any temp files written while parsing the multipart form
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	folder.MappedPath = strings.TrimSpace(r.Form.Get("mapped_path"))
	folder.Name = strings.TrimSpace(r.Form.Get("name"))
	folder.Description =
r.Form.Get("description")
	fsConfig, err := getFsConfigFromPostFields(r)
	if err != nil {
		s.renderFolderPage(w, r, folder, folderPageModeAdd, err)
		return
	}
	folder.FsConfig = fsConfig
	// resolve placeholders (e.g. %name%) in the submitted fields
	folder = getFolderFromTemplate(folder, folder.Name)

	err = dataprovider.AddFolder(&folder, claims.Username, ipAddr, claims.Role)
	if err == nil {
		http.Redirect(w, r, webFoldersPath, http.StatusSeeOther)
	} else {
		s.renderFolderPage(w, r, folder, folderPageModeAdd, err)
	}
}

// handleWebUpdateFolderGet renders the edit page for an existing folder.
func (s *httpdServer) handleWebUpdateFolderGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	name := getURLParam(r, "name")
	folder, err := dataprovider.GetFolderByName(name)
	if err == nil {
		s.renderFolderPage(w, r, folder, folderPageModeUpdate, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebUpdateFolderPost applies the submitted form to an existing
// folder, preserving its identity and re-using stored encrypted secrets.
func (s *httpdServer) handleWebUpdateFolderPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	name := getURLParam(r, "name")
	folder, err := dataprovider.GetFolderByName(name)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}

	err = r.ParseMultipartForm(maxRequestSize)
	if err != nil {
		s.renderFolderPage(w, r, folder, folderPageModeUpdate, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	// remove any temp files written while parsing the multipart form
	defer r.MultipartForm.RemoveAll() //nolint:errcheck

	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}

fsConfig, err := getFsConfigFromPostFields(r)
	if err != nil {
		s.renderFolderPage(w, r, folder, folderPageModeUpdate, err)
		return
	}
	updatedFolder := vfs.BaseVirtualFolder{
		MappedPath:  strings.TrimSpace(r.Form.Get("mapped_path")),
		Description: r.Form.Get("description"),
	}
	// keep the stored identity: name and ID are not editable from the form
	updatedFolder.ID = folder.ID
	updatedFolder.Name = folder.Name
	updatedFolder.FsConfig = fsConfig
	updatedFolder.FsConfig.SetEmptySecretsIfNil()
	// re-use the stored encrypted secrets when the form left them redacted
	updateEncryptedSecrets(&updatedFolder.FsConfig, &folder.FsConfig)

	updatedFolder = getFolderFromTemplate(updatedFolder, updatedFolder.Name)

	err = dataprovider.UpdateFolder(&updatedFolder, folder.Users, folder.Groups, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderFolderPage(w, r, updatedFolder, folderPageModeUpdate, err)
		return
	}
	http.Redirect(w, r, webFoldersPath, http.StatusSeeOther)
}

// getWebVirtualFolders fetches all folders page by page; on provider error
// it renders the error page itself and returns the error to the caller.
func (s *httpdServer) getWebVirtualFolders(w http.ResponseWriter, r *http.Request, limit int, minimal bool) ([]vfs.BaseVirtualFolder, error) {
	folders := make([]vfs.BaseVirtualFolder, 0, 50)
	for {
		// len(folders) doubles as the next page offset
		f, err := dataprovider.GetFolders(limit, len(folders), dataprovider.OrderASC, minimal)
		if err != nil {
			s.renderInternalServerErrorPage(w, r, err)
			return folders, err
		}
		folders = append(folders, f...)
		// a short page means we reached the end
		if len(f) < limit {
			break
		}
	}
	return folders, nil
}

// getAllFolders streams all folders as one JSON array, fetched page by page.
func getAllFolders(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	dataGetter := func(limit, offset int) ([]byte, int, error) {
		results, err := dataprovider.GetFolders(limit, offset, dataprovider.OrderASC, false)
		if err != nil {
			return nil, 0, err
		}
		data, err := json.Marshal(results)
		return data, len(results), err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleWebGetFolders renders the WebAdmin folders list page.
func (s *httpdServer) handleWebGetFolders(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	data := s.getBasePageData(util.I18nFoldersTitle, webFoldersPath, w, r)
	renderAdminTemplate(w, templateFolders, data)
}

// getWebGroups fetches all groups page by page; on provider error it renders
// the error page itself and returns the error to the caller.
func (s *httpdServer) getWebGroups(w http.ResponseWriter, r *http.Request, limit int, minimal bool) ([]dataprovider.Group, error) {
	groups := make([]dataprovider.Group, 0, 50)
	for {
		// len(groups) doubles as the next page offset
		f, err := dataprovider.GetGroups(limit, len(groups), dataprovider.OrderASC, minimal)
		if err != nil {
			s.renderInternalServerErrorPage(w, r, err)
			return groups, err
		}
		groups = append(groups, f...)
		// a short page means we reached the end
		if len(f) < limit {
			break
		}
	}
	return groups, nil
}

// getAllGroups streams all groups as one JSON array, fetched page by page.
func getAllGroups(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	dataGetter := func(limit, offset int) ([]byte, int, error) {
		results, err := dataprovider.GetGroups(limit, offset, dataprovider.OrderASC, false)
		if err != nil {
			return nil, 0, err
		}
		data, err := json.Marshal(results)
		return data, len(results), err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleWebGetGroups renders the WebAdmin groups list page.
func (s *httpdServer) handleWebGetGroups(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	data := s.getBasePageData(util.I18nGroupsTitle, webGroupsPath, w, r)
	renderAdminTemplate(w, templateGroups, data)
}

// handleWebAddGroupGet renders an empty "add group" page.
func (s *httpdServer) handleWebAddGroupGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderGroupPage(w, r, dataprovider.Group{}, genericPageModeAdd, nil)
}

// handleWebAddGroupPost validates the submitted form and creates the group.
func (s *httpdServer) handleWebAddGroupPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	group, err := getGroupFromPostFields(r)
	if err != nil {
		s.renderGroupPage(w, r, group, genericPageModeAdd, err)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	err = dataprovider.AddGroup(&group, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderGroupPage(w, r, group, genericPageModeAdd, err)
		return
	}
	http.Redirect(w, r, webGroupsPath, http.StatusSeeOther)
}

// handleWebUpdateGroupGet renders the edit page for an existing group.
func (s *httpdServer) handleWebUpdateGroupGet(w http.ResponseWriter, r
*http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	name := getURLParam(r, "name")
	group, err := dataprovider.GroupExists(name)
	if err == nil {
		s.renderGroupPage(w, r, group, genericPageModeUpdate, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebUpdateGroupPost applies the submitted form to an existing group,
// preserving its identity and re-using stored encrypted secrets.
func (s *httpdServer) handleWebUpdateGroupPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	name := getURLParam(r, "name")
	group, err := dataprovider.GroupExists(name)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	updatedGroup, err := getGroupFromPostFields(r)
	if err != nil {
		s.renderGroupPage(w, r, group, genericPageModeUpdate, err)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	// keep the stored identity: name and ID are not editable from the form
	updatedGroup.ID = group.ID
	updatedGroup.Name = group.Name
	updatedGroup.SetEmptySecretsIfNil()

	// re-use the stored encrypted secrets when the form left them redacted
	updateEncryptedSecrets(&updatedGroup.UserSettings.FsConfig, &group.UserSettings.FsConfig)

	err = dataprovider.UpdateGroup(&updatedGroup, group.Users, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderGroupPage(w, r, updatedGroup, genericPageModeUpdate, err)
		return
	}
	http.Redirect(w, r, webGroupsPath, http.StatusSeeOther)
}

// getWebEventActions fetches all event actions page by page; on provider
// error it renders the error page itself and returns the error to the caller.
func (s *httpdServer) getWebEventActions(w http.ResponseWriter, r *http.Request, limit int, minimal bool,
) ([]dataprovider.BaseEventAction, error) {
	actions := make([]dataprovider.BaseEventAction, 0, limit)
for {
		// len(actions) doubles as the next page offset
		res, err := dataprovider.GetEventActions(limit, len(actions), dataprovider.OrderASC, minimal)
		if err != nil {
			s.renderInternalServerErrorPage(w, r, err)
			return actions, err
		}
		actions = append(actions, res...)
		// a short page means we reached the end
		if len(res) < limit {
			break
		}
	}
	return actions, nil
}

// getAllActions streams all event actions as one JSON array, page by page.
func getAllActions(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	dataGetter := func(limit, offset int) ([]byte, int, error) {
		results, err := dataprovider.GetEventActions(limit, offset, dataprovider.OrderASC, false)
		if err != nil {
			return nil, 0, err
		}
		data, err := json.Marshal(results)
		return data, len(results), err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleWebGetEventActions renders the WebAdmin event actions list page.
func (s *httpdServer) handleWebGetEventActions(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	data := s.getBasePageData(util.I18nActionsTitle, webAdminEventActionsPath, w, r)
	renderAdminTemplate(w, templateEventActions, data)
}

// handleWebAddEventActionGet renders the "add event action" page with HTTP
// as the preselected action type.
func (s *httpdServer) handleWebAddEventActionGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	action := dataprovider.BaseEventAction{
		Type: dataprovider.ActionTypeHTTP,
	}
	s.renderEventActionPage(w, r, action, genericPageModeAdd, nil)
}

// handleWebAddEventActionPost validates the submitted form and creates the
// event action.
func (s *httpdServer) handleWebAddEventActionPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	action, err := getEventActionFromPostFields(r)
	if err != nil {
		s.renderEventActionPage(w, r, action, genericPageModeAdd, err)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r,
util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	if err = dataprovider.AddEventAction(&action, claims.Username, ipAddr, claims.Role); err != nil {
		s.renderEventActionPage(w, r, action, genericPageModeAdd, err)
		return
	}
	http.Redirect(w, r, webAdminEventActionsPath, http.StatusSeeOther)
}

// handleWebUpdateEventActionGet renders the edit page for an existing
// event action.
func (s *httpdServer) handleWebUpdateEventActionGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	name := getURLParam(r, "name")
	action, err := dataprovider.EventActionExists(name)
	if err == nil {
		s.renderEventActionPage(w, r, action, genericPageModeUpdate, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebUpdateEventActionPost applies the submitted form to an existing
// event action, preserving its identity and stored secrets.
func (s *httpdServer) handleWebUpdateEventActionPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	name := getURLParam(r, "name")
	action, err := dataprovider.EventActionExists(name)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	updatedAction, err := getEventActionFromPostFields(r)
	if err != nil {
		s.renderEventActionPage(w, r, updatedAction, genericPageModeUpdate, err)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	// keep the stored identity: name and ID are not editable from the form
	updatedAction.ID = action.ID
	updatedAction.Name = action.Name
	updatedAction.Options.SetEmptySecretsIfNil()
	switch updatedAction.Type {
	case dataprovider.ActionTypeHTTP:
		// keep the stored HTTP password when the form left it redacted
		if
updatedAction.Options.HTTPConfig.Password.IsNotPlainAndNotEmpty() {
			updatedAction.Options.HTTPConfig.Password = action.Options.HTTPConfig.Password
		}
	}
	err = dataprovider.UpdateEventAction(&updatedAction, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderEventActionPage(w, r, updatedAction, genericPageModeUpdate, err)
		return
	}
	http.Redirect(w, r, webAdminEventActionsPath, http.StatusSeeOther)
}

// getAllRules streams all event rules as one JSON array, page by page.
func getAllRules(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	dataGetter := func(limit, offset int) ([]byte, int, error) {
		results, err := dataprovider.GetEventRules(limit, offset, dataprovider.OrderASC)
		if err != nil {
			return nil, 0, err
		}
		data, err := json.Marshal(results)
		return data, len(results), err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleWebGetEventRules renders the WebAdmin event rules list page.
func (s *httpdServer) handleWebGetEventRules(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	data := s.getBasePageData(util.I18nRulesTitle, webAdminEventRulesPath, w, r)
	renderAdminTemplate(w, templateEventRules, data)
}

// handleWebAddEventRuleGet renders the "add event rule" page with an
// enabled, filesystem-event-triggered rule preselected.
func (s *httpdServer) handleWebAddEventRuleGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	rule := dataprovider.EventRule{
		Status:  1,
		Trigger: dataprovider.EventTriggerFsEvent,
	}
	s.renderEventRulePage(w, r, rule, genericPageModeAdd, nil)
}

// handleWebAddEventRulePost validates the submitted form and creates the
// event rule.
func (s *httpdServer) handleWebAddEventRulePost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	rule, err := getEventRuleFromPostFields(r)
	if err != nil {
		s.renderEventRulePage(w, r, rule, genericPageModeAdd, err)
		return
	}
	ipAddr :=
util.GetIPFromRemoteAddress(r.RemoteAddr)
	// NOTE(review): plain assignment here instead of the `if err := ...`
	// form used by the sibling handlers; behavior is identical.
	err = verifyCSRFToken(r, s.csrfTokenAuth)
	if err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	if err = dataprovider.AddEventRule(&rule, claims.Username, ipAddr, claims.Role); err != nil {
		s.renderEventRulePage(w, r, rule, genericPageModeAdd, err)
		return
	}
	http.Redirect(w, r, webAdminEventRulesPath, http.StatusSeeOther)
}

// handleWebUpdateEventRuleGet renders the edit page for an existing
// event rule.
func (s *httpdServer) handleWebUpdateEventRuleGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	name := getURLParam(r, "name")
	rule, err := dataprovider.EventRuleExists(name)
	if err == nil {
		s.renderEventRulePage(w, r, rule, genericPageModeUpdate, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebUpdateEventRulePost applies the submitted form to an existing
// event rule, preserving its identity.
func (s *httpdServer) handleWebUpdateEventRulePost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	name := getURLParam(r, "name")
	rule, err := dataprovider.EventRuleExists(name)
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}
	updatedRule, err := getEventRuleFromPostFields(r)
	if err != nil {
		s.renderEventRulePage(w, r, updatedRule, genericPageModeUpdate, err)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	// keep the stored identity: name and ID are not editable from the form
	updatedRule.ID = rule.ID
	updatedRule.Name = rule.Name
	err = dataprovider.UpdateEventRule(&updatedRule, claims.Username,
ipAddr, claims.Role)
	if err != nil {
		s.renderEventRulePage(w, r, updatedRule, genericPageModeUpdate, err)
		return
	}
	http.Redirect(w, r, webAdminEventRulesPath, http.StatusSeeOther)
}

// getWebRoles fetches all roles page by page; on provider error it renders
// the error page itself and returns the error to the caller.
func (s *httpdServer) getWebRoles(w http.ResponseWriter, r *http.Request, limit int, minimal bool) ([]dataprovider.Role, error) {
	roles := make([]dataprovider.Role, 0, 10)
	for {
		// len(roles) doubles as the next page offset
		res, err := dataprovider.GetRoles(limit, len(roles), dataprovider.OrderASC, minimal)
		if err != nil {
			s.renderInternalServerErrorPage(w, r, err)
			return roles, err
		}
		roles = append(roles, res...)
		// a short page means we reached the end
		if len(res) < limit {
			break
		}
	}
	return roles, nil
}

// getAllRoles streams all roles as one JSON array, fetched page by page.
func getAllRoles(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	dataGetter := func(limit, offset int) ([]byte, int, error) {
		results, err := dataprovider.GetRoles(limit, offset, dataprovider.OrderASC, false)
		if err != nil {
			return nil, 0, err
		}
		data, err := json.Marshal(results)
		return data, len(results), err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleWebGetRoles renders the WebAdmin roles list page.
func (s *httpdServer) handleWebGetRoles(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	data := s.getBasePageData(util.I18nRolesTitle, webAdminRolesPath, w, r)

	renderAdminTemplate(w, templateRoles, data)
}

// handleWebAddRoleGet renders an empty "add role" page.
func (s *httpdServer) handleWebAddRoleGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderRolePage(w, r, dataprovider.Role{}, genericPageModeAdd, nil)
}

// handleWebAddRolePost validates the submitted form and creates the role.
// NOTE(review): unlike the sibling Add*Post handlers, the form is parsed
// before the token claims are checked; consider validating auth first for
// consistency — behavior is otherwise equivalent.
func (s *httpdServer) handleWebAddRolePost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	role, err := getRoleFromPostFields(r)
	if err != nil {
		s.renderRolePage(w, r, role, genericPageModeAdd, err)
		return
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r,
util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	err = dataprovider.AddRole(&role, claims.Username, ipAddr, claims.Role)
	if err != nil {
		s.renderRolePage(w, r, role, genericPageModeAdd, err)
		return
	}
	http.Redirect(w, r, webAdminRolesPath, http.StatusSeeOther)
}

// handleWebUpdateRoleGet renders the edit page for an existing role.
func (s *httpdServer) handleWebUpdateRoleGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	role, err := dataprovider.RoleExists(getURLParam(r, "name"))
	if err == nil {
		s.renderRolePage(w, r, role, genericPageModeUpdate, nil)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
	} else {
		s.renderInternalServerErrorPage(w, r, err)
	}
}

// handleWebUpdateRolePost applies the submitted form to an existing role,
// preserving its identity.
func (s *httpdServer) handleWebUpdateRolePost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	role, err := dataprovider.RoleExists(getURLParam(r, "name"))
	if errors.Is(err, util.ErrNotFound) {
		s.renderNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderInternalServerErrorPage(w, r, err)
		return
	}

	updatedRole, err := getRoleFromPostFields(r)
	if err != nil {
		s.renderRolePage(w, r, role, genericPageModeUpdate, err)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	// keep the stored identity: name and ID are not editable from the form
	updatedRole.ID = role.ID
	updatedRole.Name = role.Name
	err = dataprovider.UpdateRole(&updatedRole,
claims.Username, ipAddr, claims.Role) + if err != nil { + s.renderRolePage(w, r, updatedRole, genericPageModeUpdate, err) + return + } + http.Redirect(w, r, webAdminRolesPath, http.StatusSeeOther) +} + +func (s *httpdServer) handleWebGetEvents(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + data := eventsPage{ + basePage: s.getBasePageData(util.I18nEventsTitle, webEventsPath, w, r), + FsEventsSearchURL: webEventsFsSearchPath, + ProviderEventsSearchURL: webEventsProviderSearchPath, + LogEventsSearchURL: webEventsLogSearchPath, + } + renderAdminTemplate(w, templateEvents, data) +} + +func (s *httpdServer) handleWebIPListsPage(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + rtlStatus, rtlProtocols := common.Config.GetRateLimitersStatus() + data := ipListsPage{ + basePage: s.getBasePageData(util.I18nIPListsTitle, webIPListsPath, w, r), + RateLimitersStatus: rtlStatus, + RateLimitersProtocols: strings.Join(rtlProtocols, ", "), + IsAllowListEnabled: common.Config.IsAllowListEnabled(), + } + + renderAdminTemplate(w, templateIPLists, data) +} + +func (s *httpdServer) handleWebAddIPListEntryGet(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + listType, _, err := getIPListPathParams(r) + if err != nil { + s.renderBadRequestPage(w, r, err) + return + } + s.renderIPListPage(w, r, dataprovider.IPListEntry{Type: listType}, genericPageModeAdd, nil) +} + +func (s *httpdServer) handleWebAddIPListEntryPost(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + listType, _, err := getIPListPathParams(r) + if err != nil { + s.renderBadRequestPage(w, r, err) + return + } + entry, err := getIPListEntryFromPostFields(r, listType) + if err != nil { + s.renderIPListPage(w, r, entry, genericPageModeAdd, err) + return + } + entry.Type = listType + claims, err := 
jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) + return + } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) + return + } + err = dataprovider.AddIPListEntry(&entry, claims.Username, ipAddr, claims.Role) + if err != nil { + s.renderIPListPage(w, r, entry, genericPageModeAdd, err) + return + } + http.Redirect(w, r, webIPListsPath, http.StatusSeeOther) +} + +func (s *httpdServer) handleWebUpdateIPListEntryGet(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + listType, ipOrNet, err := getIPListPathParams(r) + if err != nil { + s.renderBadRequestPage(w, r, err) + return + } + entry, err := dataprovider.IPListEntryExists(ipOrNet, listType) + if err == nil { + s.renderIPListPage(w, r, entry, genericPageModeUpdate, nil) + } else if errors.Is(err, util.ErrNotFound) { + s.renderNotFoundPage(w, r, err) + } else { + s.renderInternalServerErrorPage(w, r, err) + } +} + +func (s *httpdServer) handleWebUpdateIPListEntryPost(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) + return + } + listType, ipOrNet, err := getIPListPathParams(r) + if err != nil { + s.renderBadRequestPage(w, r, err) + return + } + entry, err := dataprovider.IPListEntryExists(ipOrNet, listType) + if errors.Is(err, util.ErrNotFound) { + s.renderNotFoundPage(w, r, err) + return + } else if err != nil { + s.renderInternalServerErrorPage(w, r, err) + return + } + updatedEntry, err := getIPListEntryFromPostFields(r, listType) + if err != nil { + 
s.renderIPListPage(w, r, entry, genericPageModeUpdate, err) + return + } + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) + return + } + updatedEntry.Type = listType + updatedEntry.IPOrNet = ipOrNet + err = dataprovider.UpdateIPListEntry(&updatedEntry, claims.Username, ipAddr, claims.Role) + if err != nil { + s.renderIPListPage(w, r, entry, genericPageModeUpdate, err) + return + } + http.Redirect(w, r, webIPListsPath, http.StatusSeeOther) +} + +func (s *httpdServer) handleWebConfigs(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + configs, err := dataprovider.GetConfigs() + if err != nil { + s.renderInternalServerErrorPage(w, r, err) + return + } + s.renderConfigsPage(w, r, configs, nil, 0) +} + +func (s *httpdServer) handleWebConfigsPost(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + s.renderForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) + return + } + configs, err := dataprovider.GetConfigs() + if err != nil { + s.renderInternalServerErrorPage(w, r, err) + return + } + err = r.ParseMultipartForm(maxRequestSize) + if err != nil { + s.renderBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) + return + } + defer r.MultipartForm.RemoveAll() //nolint:errcheck + + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) + return + } + var configSection int + switch r.Form.Get("form_action") { + case "sftp_submit": + configSection = 1 + sftpConfigs := getSFTPConfigsFromPostFields(r) + configs.SFTPD = sftpConfigs + case "acme_submit": 
+ configSection = 2 + acmeConfigs := getACMEConfigsFromPostFields(r) + configs.ACME = acmeConfigs + if err := acme.GetCertificatesForConfig(acmeConfigs, configurationDir); err != nil { + logger.Info(logSender, "", "unable to get ACME certificates: %v", err) + s.renderConfigsPage(w, r, configs, util.NewI18nError(err, util.I18nErrorACMEGeneric), configSection) + return + } + case "smtp_submit": + configSection = 3 + smtpConfigs := getSMTPConfigsFromPostFields(r) + updateSMTPSecrets(smtpConfigs, configs.SMTP) + configs.SMTP = smtpConfigs + case "branding_submit": + configSection = 4 + brandingConfigs, err := getBrandingConfigFromPostFields(r, configs.Branding) + configs.Branding = brandingConfigs + if err != nil { + logger.Info(logSender, "", "unable to get branding config: %v", err) + s.renderConfigsPage(w, r, configs, err, configSection) + return + } + default: + s.renderBadRequestPage(w, r, errors.New("unsupported form action")) + return + } + + err = dataprovider.UpdateConfigs(&configs, claims.Username, ipAddr, claims.Role) + if err != nil { + s.renderConfigsPage(w, r, configs, err, configSection) + return + } + postConfigsUpdate(configSection, configs) + s.renderMessagePage(w, r, util.I18nConfigsTitle, http.StatusOK, nil, util.I18nConfigsOK) +} + +func postConfigsUpdate(section int, configs dataprovider.Configs) { + switch section { + case 3: + err := configs.SMTP.TryDecrypt() + if err == nil { + smtp.Activate(configs.SMTP) + } else { + logger.Error(logSender, "", "unable to decrypt SMTP configuration, cannot activate configuration: %v", err) + } + case 4: + dbBrandingConfig.Set(configs.Branding) + } +} + +func (s *httpdServer) handleOAuth2TokenRedirect(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + + stateToken := r.URL.Query().Get("state") + + state, err := verifyOAuth2Token(s.csrfTokenAuth, stateToken, util.GetIPFromRemoteAddress(r.RemoteAddr)) + if err != nil { + s.renderMessagePage(w, r, 
util.I18nOAuth2ErrorTitle, http.StatusBadRequest, err, "") + return + } + + pendingAuth, err := oauth2Mgr.getPendingAuth(state) + if err != nil { + oauth2Mgr.removePendingAuth(state) + s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusInternalServerError, + util.NewI18nError(err, util.I18nOAuth2ErrorValidateState), "") + return + } + oauth2Mgr.removePendingAuth(state) + + oauth2Config := smtp.OAuth2Config{ + Provider: pendingAuth.Provider, + ClientID: pendingAuth.ClientID, + ClientSecret: pendingAuth.ClientSecret.GetPayload(), + } + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + cfg := oauth2Config.GetOAuth2() + cfg.RedirectURL = pendingAuth.RedirectURL + token, err := cfg.Exchange(ctx, r.URL.Query().Get("code"), oauth2.VerifierOption(pendingAuth.Verifier)) + if err != nil { + s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusInternalServerError, + util.NewI18nError(err, util.I18nOAuth2ErrTokenExchange), "") + return + } + if token.RefreshToken == "" { + errTxt := "the OAuth2 provider returned an empty token. " + + "Some providers only return the token when the user first authorizes. " + + "If you have already registered SFTPGo with this user in the past, revoke access and try again. 
" + + "This way you will invalidate the previous token" + s.renderMessagePage(w, r, util.I18nOAuth2ErrorTitle, http.StatusBadRequest, + util.NewI18nError(errors.New(errTxt), util.I18nOAuth2ErrNoRefreshToken), "") + return + } + s.renderMessagePageWithString(w, r, util.I18nOAuth2Title, http.StatusOK, nil, util.I18nOAuth2OK, + fmt.Sprintf("%q", token.RefreshToken)) +} + +func updateSMTPSecrets(newConfigs, currentConfigs *dataprovider.SMTPConfigs) { + if currentConfigs == nil { + currentConfigs = &dataprovider.SMTPConfigs{} + } + if newConfigs.Password.IsNotPlainAndNotEmpty() { + newConfigs.Password = currentConfigs.Password + } + if newConfigs.OAuth2.ClientSecret.IsNotPlainAndNotEmpty() { + newConfigs.OAuth2.ClientSecret = currentConfigs.OAuth2.ClientSecret + } + if newConfigs.OAuth2.RefreshToken.IsNotPlainAndNotEmpty() { + newConfigs.OAuth2.RefreshToken = currentConfigs.OAuth2.RefreshToken + } +} diff --git a/internal/httpd/webclient.go b/internal/httpd/webclient.go new file mode 100644 index 00000000..71aba2ce --- /dev/null +++ b/internal/httpd/webclient.go @@ -0,0 +1,2286 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 

package httpd

import (
	"bytes"
	"crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"html/template"
	"io"
	"math"
	"net/http"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"slices"
	"strconv"
	"strings"
	"time"

	"github.com/go-chi/render"
	"github.com/rs/xid"
	"github.com/sftpgo/sdk"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/jwt"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/mfa"
	"github.com/drakkan/sftpgo/v2/internal/smtp"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

// Template directory and file names used by the WebClient UI.
const (
	templateClientDir      = "webclient"
	templateClientBase     = "base.html"
	templateClientFiles    = "files.html"
	templateClientProfile  = "profile.html"
	templateClientMFA      = "mfa.html"
	templateClientEditFile = "editfile.html"
	templateClientShare    = "share.html"
	templateClientShares   = "shares.html"
	templateClientViewPDF  = "viewpdf.html"
	templateShareLogin     = "sharelogin.html"
	templateShareDownload  = "sharedownload.html"
	templateUploadToShare  = "shareupload.html"
)

// condResult is the result of an HTTP request precondition check.
// See https://tools.ietf.org/html/rfc7232 section 3.
type condResult int

const (
	condNone condResult = iota
	condTrue
	condFalse
)

var (
	// clientTemplates maps a template file name to its parsed template,
	// populated once by loadClientTemplates.
	clientTemplates = make(map[string]*template.Template)
	// unixEpochTime is used to detect "unset" modification times.
	unixEpochTime = time.Unix(0, 0)
)

// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0).
+func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(unixEpochTime) +} + +type baseClientPage struct { + commonBasePage + Title string + CurrentURL string + FilesURL string + SharesURL string + ShareURL string + ProfileURL string + PingURL string + ChangePwdURL string + LogoutURL string + LoginURL string + EditURL string + MFAURL string + CSRFToken string + LoggedUser *dataprovider.User + IsLoggedToShare bool + Branding UIBranding + Languages []string +} + +type dirMapping struct { + DirName string + Href string +} + +type viewPDFPage struct { + commonBasePage + Title string + URL string + Branding UIBranding + Languages []string +} + +type editFilePage struct { + baseClientPage + CurrentDir string + FileURL string + Path string + Name string + ReadOnly bool + Data string +} + +type filesPage struct { + baseClientPage + CurrentDir string + DirsURL string + FileActionsURL string + CheckExistURL string + DownloadURL string + ViewPDFURL string + FileURL string + TasksURL string + CanAddFiles bool + CanCreateDirs bool + CanRename bool + CanDelete bool + CanDownload bool + CanShare bool + CanCopy bool + ShareUploadBaseURL string + Error *util.I18nError + Paths []dirMapping + QuotaUsage *userQuotaUsage + KeepAliveInterval int +} + +type shareLoginPage struct { + commonBasePage + CurrentURL string + Error *util.I18nError + CSRFToken string + Title string + Branding UIBranding + Languages []string + CheckRedirect bool +} + +type shareDownloadPage struct { + baseClientPage + DownloadLink string +} + +type shareUploadPage struct { + baseClientPage + Share *dataprovider.Share + UploadBasePath string +} + +type clientMessagePage struct { + baseClientPage + Error *util.I18nError + Success string + Text string +} + +type clientProfilePage struct { + baseClientPage + PublicKeys []string + TLSCerts []string + CanSubmit bool + AllowAPIKeyAuth bool + Email string + AdditionalEmails []string + AdditionalEmailsString string + Description string + Error *util.I18nError 
+} + +type changeClientPasswordPage struct { + baseClientPage + Error *util.I18nError +} + +type clientMFAPage struct { + baseClientPage + TOTPConfigs []string + TOTPConfig dataprovider.UserTOTPConfig + GenerateTOTPURL string + ValidateTOTPURL string + SaveTOTPURL string + RecCodesURL string + Protocols []string + RequiredProtocols []string +} + +type clientSharesPage struct { + baseClientPage + BasePublicSharesURL string + BaseURL string +} + +type clientSharePage struct { + baseClientPage + Share *dataprovider.Share + Error *util.I18nError + IsAdd bool +} + +type userQuotaUsage struct { + QuotaSize int64 + QuotaFiles int + UsedQuotaSize int64 + UsedQuotaFiles int + UploadDataTransfer int64 + DownloadDataTransfer int64 + TotalDataTransfer int64 + UsedUploadDataTransfer int64 + UsedDownloadDataTransfer int64 +} + +func (u *userQuotaUsage) HasQuotaInfo() bool { + if dataprovider.GetQuotaTracking() == 0 { + return false + } + if u.HasDiskQuota() { + return true + } + return u.HasTranferQuota() +} + +func (u *userQuotaUsage) HasDiskQuota() bool { + if u.QuotaSize > 0 || u.UsedQuotaSize > 0 { + return true + } + return u.QuotaFiles > 0 || u.UsedQuotaFiles > 0 +} + +func (u *userQuotaUsage) HasTranferQuota() bool { + if u.TotalDataTransfer > 0 || u.UploadDataTransfer > 0 || u.DownloadDataTransfer > 0 { + return true + } + return u.UsedDownloadDataTransfer > 0 || u.UsedUploadDataTransfer > 0 +} + +func (u *userQuotaUsage) GetQuotaSize() string { + if u.QuotaSize > 0 { + return fmt.Sprintf("%s/%s", util.ByteCountIEC(u.UsedQuotaSize), util.ByteCountIEC(u.QuotaSize)) + } + if u.UsedQuotaSize > 0 { + return util.ByteCountIEC(u.UsedQuotaSize) + } + return "" +} + +func (u *userQuotaUsage) GetQuotaFiles() string { + if u.QuotaFiles > 0 { + return fmt.Sprintf("%d/%d", u.UsedQuotaFiles, u.QuotaFiles) + } + if u.UsedQuotaFiles > 0 { + return strconv.FormatInt(int64(u.UsedQuotaFiles), 10) + } + return "" +} + +func (u *userQuotaUsage) GetQuotaSizePercentage() int { + if 
u.QuotaSize > 0 { + return int(math.Round(100 * float64(u.UsedQuotaSize) / float64(u.QuotaSize))) + } + return 0 +} + +func (u *userQuotaUsage) GetQuotaFilesPercentage() int { + if u.QuotaFiles > 0 { + return int(math.Round(100 * float64(u.UsedQuotaFiles) / float64(u.QuotaFiles))) + } + return 0 +} + +func (u *userQuotaUsage) IsQuotaSizeLow() bool { + return u.GetQuotaSizePercentage() > 85 +} + +func (u *userQuotaUsage) IsQuotaFilesLow() bool { + return u.GetQuotaFilesPercentage() > 85 +} + +func (u *userQuotaUsage) IsDiskQuotaLow() bool { + return u.IsQuotaSizeLow() || u.IsQuotaFilesLow() +} + +func (u *userQuotaUsage) GetTotalTransferQuota() string { + total := u.UsedUploadDataTransfer + u.UsedDownloadDataTransfer + if u.TotalDataTransfer > 0 { + return fmt.Sprintf("%s/%s", util.ByteCountIEC(total), util.ByteCountIEC(u.TotalDataTransfer*1048576)) + } + if total > 0 { + return util.ByteCountIEC(total) + } + return "" +} + +func (u *userQuotaUsage) GetUploadTransferQuota() string { + if u.UploadDataTransfer > 0 { + return fmt.Sprintf("%s/%s", util.ByteCountIEC(u.UsedUploadDataTransfer), + util.ByteCountIEC(u.UploadDataTransfer*1048576)) + } + if u.UsedUploadDataTransfer > 0 { + return util.ByteCountIEC(u.UsedUploadDataTransfer) + } + return "" +} + +func (u *userQuotaUsage) GetDownloadTransferQuota() string { + if u.DownloadDataTransfer > 0 { + return fmt.Sprintf("%s/%s", util.ByteCountIEC(u.UsedDownloadDataTransfer), + util.ByteCountIEC(u.DownloadDataTransfer*1048576)) + } + if u.UsedDownloadDataTransfer > 0 { + return util.ByteCountIEC(u.UsedDownloadDataTransfer) + } + return "" +} + +func (u *userQuotaUsage) GetTotalTransferQuotaPercentage() int { + if u.TotalDataTransfer > 0 { + return int(math.Round(100 * float64(u.UsedDownloadDataTransfer+u.UsedUploadDataTransfer) / float64(u.TotalDataTransfer*1048576))) + } + return 0 +} + +func (u *userQuotaUsage) GetUploadTransferQuotaPercentage() int { + if u.UploadDataTransfer > 0 { + return int(math.Round(100 * 
float64(u.UsedUploadDataTransfer) / float64(u.UploadDataTransfer*1048576))) + } + return 0 +} + +func (u *userQuotaUsage) GetDownloadTransferQuotaPercentage() int { + if u.DownloadDataTransfer > 0 { + return int(math.Round(100 * float64(u.UsedDownloadDataTransfer) / float64(u.DownloadDataTransfer*1048576))) + } + return 0 +} + +func (u *userQuotaUsage) IsTotalTransferQuotaLow() bool { + if u.TotalDataTransfer > 0 { + return u.GetTotalTransferQuotaPercentage() > 85 + } + return false +} + +func (u *userQuotaUsage) IsUploadTransferQuotaLow() bool { + if u.UploadDataTransfer > 0 { + return u.GetUploadTransferQuotaPercentage() > 85 + } + return false +} + +func (u *userQuotaUsage) IsDownloadTransferQuotaLow() bool { + if u.DownloadDataTransfer > 0 { + return u.GetDownloadTransferQuotaPercentage() > 85 + } + return false +} + +func (u *userQuotaUsage) IsTransferQuotaLow() bool { + return u.IsTotalTransferQuotaLow() || u.IsUploadTransferQuotaLow() || u.IsDownloadTransferQuotaLow() +} + +func (u *userQuotaUsage) IsQuotaLow() bool { + return u.IsDiskQuotaLow() || u.IsTransferQuotaLow() +} + +func newUserQuotaUsage(u *dataprovider.User) *userQuotaUsage { + return &userQuotaUsage{ + QuotaSize: u.QuotaSize, + QuotaFiles: u.QuotaFiles, + UsedQuotaSize: u.UsedQuotaSize, + UsedQuotaFiles: u.UsedQuotaFiles, + TotalDataTransfer: u.TotalDataTransfer, + UploadDataTransfer: u.UploadDataTransfer, + DownloadDataTransfer: u.DownloadDataTransfer, + UsedUploadDataTransfer: u.UsedUploadDataTransfer, + UsedDownloadDataTransfer: u.UsedDownloadDataTransfer, + } +} + +func getFileObjectURL(baseDir, name, baseWebPath string) string { + return fmt.Sprintf("%v?path=%v&_=%v", baseWebPath, url.QueryEscape(path.Join(baseDir, name)), time.Now().UTC().Unix()) +} + +func getFileObjectModTime(t time.Time) int64 { + if isZeroTime(t) { + return 0 + } + return t.UnixMilli() +} + +func loadClientTemplates(templatesPath string) { + filesPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, 
templateCommonBase), + filepath.Join(templatesPath, templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateClientDir, templateClientFiles), + } + editFilePath := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateClientDir, templateClientEditFile), + } + sharesPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateClientDir, templateClientShares), + } + sharePaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateClientDir, templateClientShare), + } + profilePaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateClientDir, templateClientProfile), + } + changePwdPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateCommonDir, templateChangePwd), + } + loginPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateCommonLogin), + } + messagePaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateCommonDir, templateMessage), + } + mfaPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + 
filepath.Join(templatesPath, templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateClientDir, templateClientMFA), + } + twoFactorPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateTwoFactor), + } + twoFactorRecoveryPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateTwoFactorRecovery), + } + forgotPwdPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateForgotPassword), + } + resetPwdPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateCommonDir, templateResetPassword), + } + viewPDFPaths := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateClientDir, templateClientViewPDF), + } + shareLoginPath := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateCommonDir, templateCommonBaseLogin), + filepath.Join(templatesPath, templateClientDir, templateShareLogin), + } + shareUploadPath := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateClientDir, templateUploadToShare), + } + shareDownloadPath := []string{ + filepath.Join(templatesPath, templateCommonDir, templateCommonBase), + filepath.Join(templatesPath, 
templateClientDir, templateClientBase), + filepath.Join(templatesPath, templateClientDir, templateShareDownload), + } + + filesTmpl := util.LoadTemplate(nil, filesPaths...) + profileTmpl := util.LoadTemplate(nil, profilePaths...) + changePwdTmpl := util.LoadTemplate(nil, changePwdPaths...) + loginTmpl := util.LoadTemplate(nil, loginPaths...) + messageTmpl := util.LoadTemplate(nil, messagePaths...) + mfaTmpl := util.LoadTemplate(nil, mfaPaths...) + twoFactorTmpl := util.LoadTemplate(nil, twoFactorPaths...) + twoFactorRecoveryTmpl := util.LoadTemplate(nil, twoFactorRecoveryPaths...) + editFileTmpl := util.LoadTemplate(nil, editFilePath...) + shareLoginTmpl := util.LoadTemplate(nil, shareLoginPath...) + sharesTmpl := util.LoadTemplate(nil, sharesPaths...) + shareTmpl := util.LoadTemplate(nil, sharePaths...) + forgotPwdTmpl := util.LoadTemplate(nil, forgotPwdPaths...) + resetPwdTmpl := util.LoadTemplate(nil, resetPwdPaths...) + viewPDFTmpl := util.LoadTemplate(nil, viewPDFPaths...) + shareUploadTmpl := util.LoadTemplate(nil, shareUploadPath...) + shareDownloadTmpl := util.LoadTemplate(nil, shareDownloadPath...) 
+ + clientTemplates[templateClientFiles] = filesTmpl + clientTemplates[templateClientProfile] = profileTmpl + clientTemplates[templateChangePwd] = changePwdTmpl + clientTemplates[templateCommonLogin] = loginTmpl + clientTemplates[templateMessage] = messageTmpl + clientTemplates[templateClientMFA] = mfaTmpl + clientTemplates[templateTwoFactor] = twoFactorTmpl + clientTemplates[templateTwoFactorRecovery] = twoFactorRecoveryTmpl + clientTemplates[templateClientEditFile] = editFileTmpl + clientTemplates[templateClientShares] = sharesTmpl + clientTemplates[templateClientShare] = shareTmpl + clientTemplates[templateForgotPassword] = forgotPwdTmpl + clientTemplates[templateResetPassword] = resetPwdTmpl + clientTemplates[templateClientViewPDF] = viewPDFTmpl + clientTemplates[templateShareLogin] = shareLoginTmpl + clientTemplates[templateUploadToShare] = shareUploadTmpl + clientTemplates[templateShareDownload] = shareDownloadTmpl +} + +func (s *httpdServer) getBaseClientPageData(title, currentURL string, w http.ResponseWriter, r *http.Request) baseClientPage { + var csrfToken string + if currentURL != "" { + csrfToken = createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath) + } + + data := baseClientPage{ + commonBasePage: getCommonBasePage(r), + Title: title, + CurrentURL: currentURL, + FilesURL: webClientFilesPath, + SharesURL: webClientSharesPath, + ShareURL: webClientSharePath, + ProfileURL: webClientProfilePath, + PingURL: webClientPingPath, + ChangePwdURL: webChangeClientPwdPath, + LogoutURL: webClientLogoutPath, + EditURL: webClientEditFilePath, + MFAURL: webClientMFAPath, + CSRFToken: csrfToken, + LoggedUser: getUserFromToken(r), + IsLoggedToShare: false, + Branding: s.binding.webClientBranding(), + Languages: s.binding.languages(), + } + if !strings.HasPrefix(r.RequestURI, webClientPubSharesPath) { + data.LoginURL = webClientLoginPath + } + return data +} + +func (s *httpdServer) renderClientForgotPwdPage(w http.ResponseWriter, r *http.Request, err 
*util.I18nError) { + data := forgotPwdPage{ + commonBasePage: getCommonBasePage(r), + CurrentURL: webClientForgotPwdPath, + Error: err, + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath), + LoginURL: webClientLoginPath, + Title: util.I18nForgotPwdTitle, + Branding: s.binding.webClientBranding(), + Languages: s.binding.languages(), + } + renderClientTemplate(w, templateForgotPassword, data) +} + +func (s *httpdServer) renderClientResetPwdPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { + data := resetPwdPage{ + commonBasePage: getCommonBasePage(r), + CurrentURL: webClientResetPwdPath, + Error: err, + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath), + LoginURL: webClientLoginPath, + Title: util.I18nResetPwdTitle, + Branding: s.binding.webClientBranding(), + Languages: s.binding.languages(), + } + renderClientTemplate(w, templateResetPassword, data) +} + +func (s *httpdServer) renderShareLoginPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { + data := shareLoginPage{ + commonBasePage: getCommonBasePage(r), + Title: util.I18nShareLoginTitle, + CurrentURL: r.RequestURI, + Error: err, + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, rand.Text(), webBaseClientPath), + Branding: s.binding.webClientBranding(), + Languages: s.binding.languages(), + CheckRedirect: false, + } + renderClientTemplate(w, templateShareLogin, data) +} + +func renderClientTemplate(w http.ResponseWriter, tmplName string, data any) { + err := clientTemplates[tmplName].ExecuteTemplate(w, tmplName, data) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func (s *httpdServer) renderClientMessagePage(w http.ResponseWriter, r *http.Request, title string, statusCode int, err error, message string) { + data := clientMessagePage{ + baseClientPage: s.getBaseClientPageData(title, "", w, r), + Error: getI18nError(err), + Success: message, + } + w.WriteHeader(statusCode) + 
renderClientTemplate(w, templateMessage, data) +} + +func (s *httpdServer) renderClientInternalServerErrorPage(w http.ResponseWriter, r *http.Request, err error) { + s.renderClientMessagePage(w, r, util.I18nError500Title, http.StatusInternalServerError, + util.NewI18nError(err, util.I18nError500Message), "") +} + +func (s *httpdServer) renderClientBadRequestPage(w http.ResponseWriter, r *http.Request, err error) { + s.renderClientMessagePage(w, r, util.I18nError400Title, http.StatusBadRequest, + util.NewI18nError(err, util.I18nError400Message), "") +} + +func (s *httpdServer) renderClientForbiddenPage(w http.ResponseWriter, r *http.Request, err error) { + s.renderClientMessagePage(w, r, util.I18nError403Title, http.StatusForbidden, + util.NewI18nError(err, util.I18nError403Message), "") +} + +func (s *httpdServer) renderClientNotFoundPage(w http.ResponseWriter, r *http.Request, err error) { + s.renderClientMessagePage(w, r, util.I18nError404Title, http.StatusNotFound, + util.NewI18nError(err, util.I18nError404Message), "") +} + +func (s *httpdServer) renderClientTwoFactorPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { + data := twoFactorPage{ + commonBasePage: getCommonBasePage(r), + Title: util.I18n2FATitle, + CurrentURL: webClientTwoFactorPath, + Error: err, + CSRFToken: createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath), + RecoveryURL: webClientTwoFactorRecoveryPath, + Branding: s.binding.webClientBranding(), + Languages: s.binding.languages(), + } + if next := r.URL.Query().Get("next"); strings.HasPrefix(next, webClientFilesPath) { + data.CurrentURL += "?next=" + url.QueryEscape(next) + } + renderClientTemplate(w, templateTwoFactor, data) +} + +func (s *httpdServer) renderClientTwoFactorRecoveryPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) { + data := twoFactorPage{ + commonBasePage: getCommonBasePage(r), + Title: util.I18n2FATitle, + CurrentURL: webClientTwoFactorRecoveryPath, + Error: err, + CSRFToken: 
createCSRFToken(w, r, s.csrfTokenAuth, "", webBaseClientPath), + Branding: s.binding.webClientBranding(), + Languages: s.binding.languages(), + } + renderClientTemplate(w, templateTwoFactorRecovery, data) +} + +func (s *httpdServer) renderClientMFAPage(w http.ResponseWriter, r *http.Request) { + data := clientMFAPage{ + baseClientPage: s.getBaseClientPageData(util.I18n2FATitle, webClientMFAPath, w, r), + TOTPConfigs: mfa.GetAvailableTOTPConfigNames(), + GenerateTOTPURL: webClientTOTPGeneratePath, + ValidateTOTPURL: webClientTOTPValidatePath, + SaveTOTPURL: webClientTOTPSavePath, + RecCodesURL: webClientRecoveryCodesPath, + Protocols: dataprovider.MFAProtocols, + } + user, err := dataprovider.GetUserWithGroupSettings(data.LoggedUser.Username, "") + if err != nil { + s.renderClientInternalServerErrorPage(w, r, err) + return + } + data.TOTPConfig = user.Filters.TOTPConfig + data.RequiredProtocols = user.Filters.TwoFactorAuthProtocols + renderClientTemplate(w, templateClientMFA, data) +} + +func (s *httpdServer) renderEditFilePage(w http.ResponseWriter, r *http.Request, fileName, fileData string, readOnly bool) { + title := util.I18nViewFileTitle + if !readOnly { + title = util.I18nEditFileTitle + } + data := editFilePage{ + baseClientPage: s.getBaseClientPageData(title, webClientEditFilePath, w, r), + Path: fileName, + Name: path.Base(fileName), + CurrentDir: path.Dir(fileName), + FileURL: webClientFilePath, + ReadOnly: readOnly, + Data: fileData, + } + + renderClientTemplate(w, templateClientEditFile, data) +} + +func (s *httpdServer) renderAddUpdateSharePage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share, + err *util.I18nError, isAdd bool) { + currentURL := webClientSharePath + title := util.I18nShareAddTitle + if !isAdd { + currentURL = fmt.Sprintf("%v/%v", webClientSharePath, url.PathEscape(share.ShareID)) + title = util.I18nShareUpdateTitle + } + if share.IsPasswordHashed() { + share.Password = redactedSecret + } + data := clientSharePage{ + 
		baseClientPage: s.getBaseClientPageData(title, currentURL, w, r),
		Share:          share,
		Error:          err,
		IsAdd:          isAdd,
	}

	renderClientTemplate(w, templateClientShare, data)
}

// getDirMapping builds the breadcrumb entries (directory name + link) for
// dirName, ordered from the virtual root "/" down to the directory itself.
func getDirMapping(dirName, baseWebPath string) []dirMapping {
	paths := []dirMapping{}
	if dirName != "/" {
		paths = append(paths, dirMapping{
			DirName: path.Base(dirName),
			Href:    getFileObjectURL("/", dirName, baseWebPath),
		})
		for {
			dirName = path.Dir(dirName)
			if dirName == "/" || dirName == "." {
				break
			}
			// prepend so ancestors come before descendants in the result
			paths = append([]dirMapping{{
				DirName: path.Base(dirName),
				Href:    getFileObjectURL("/", dirName, baseWebPath)},
			}, paths...)
		}
	}
	return paths
}

// renderSharedFilesPage renders the file browser for a browsable public
// share. Most file actions are disabled; uploads are allowed only for
// read/write shares and downloads only for non write-only shares.
func (s *httpdServer) renderSharedFilesPage(w http.ResponseWriter, r *http.Request, dirName string,
	err *util.I18nError, share dataprovider.Share,
) {
	currentURL := path.Join(webClientPubSharesPath, share.ShareID, "browse")
	baseData := s.getBaseClientPageData(util.I18nSharedFilesTitle, currentURL, w, r)
	baseData.FilesURL = currentURL
	baseSharePath := path.Join(webClientPubSharesPath, share.ShareID)
	baseData.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout")
	// a non-empty password means the visitor authenticated against the share
	baseData.IsLoggedToShare = share.Password != ""

	data := filesPage{
		baseClientPage: baseData,
		Error:          err,
		CurrentDir:     url.QueryEscape(dirName),
		DownloadURL:    path.Join(baseSharePath, "partial"),
		// dirName must be escaped because the router expects the full path as single argument
		ShareUploadBaseURL: path.Join(baseSharePath, url.PathEscape(dirName)),
		ViewPDFURL:         path.Join(baseSharePath, "viewpdf"),
		DirsURL:            path.Join(baseSharePath, "dirs"),
		FileURL:            "",
		FileActionsURL:     "",
		CheckExistURL:      path.Join(baseSharePath, "browse", "exist"),
		TasksURL:           "",
		CanAddFiles:        share.Scope == dataprovider.ShareScopeReadWrite,
		CanCreateDirs:      false,
		CanRename:          false,
		CanDelete:          false,
		CanDownload:        share.Scope != dataprovider.ShareScopeWrite,
		CanShare:           false,
		CanCopy:            false,
		Paths:              getDirMapping(dirName, currentURL),
		QuotaUsage:         newUserQuotaUsage(&dataprovider.User{}),
		KeepAliveInterval:  int(cookieRefreshThreshold / time.Millisecond),
	}
	renderClientTemplate(w, templateClientFiles, data)
}

// renderShareDownloadPage renders the download landing page for a share.
func (s *httpdServer) renderShareDownloadPage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share,
	downloadLink string,
) {
	data := shareDownloadPage{
		baseClientPage: s.getBaseClientPageData(util.I18nShareDownloadTitle, "", w, r),
		DownloadLink:   downloadLink,
	}
	// a logout link only makes sense for password protected shares
	data.LogoutURL = ""
	if share.Password != "" {
		data.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout")
	}

	renderClientTemplate(w, templateShareDownload, data)
}

// renderUploadToSharePage renders the upload page for a write-only share.
func (s *httpdServer) renderUploadToSharePage(w http.ResponseWriter, r *http.Request, share *dataprovider.Share) {
	currentURL := path.Join(webClientPubSharesPath, share.ShareID, "upload")
	data := shareUploadPage{
		baseClientPage: s.getBaseClientPageData(util.I18nShareUploadTitle, currentURL, w, r),
		Share:          share,
		UploadBasePath: path.Join(webClientPubSharesPath, share.ShareID),
	}
	data.LogoutURL = ""
	if share.Password != "" {
		data.LogoutURL = path.Join(webClientPubSharesPath, share.ShareID, "logout")
	}
	renderClientTemplate(w, templateUploadToShare, data)
}

// renderFilesPage renders the authenticated user's file browser page; the
// action flags reflect the user's permissions for dirName.
func (s *httpdServer) renderFilesPage(w http.ResponseWriter, r *http.Request, dirName string,
	err *util.I18nError, user *dataprovider.User) {
	data := filesPage{
		baseClientPage: s.getBaseClientPageData(util.I18nFilesTitle, webClientFilesPath, w, r),
		Error:          err,
		CurrentDir:     url.QueryEscape(dirName),
		DownloadURL:    webClientDownloadZipPath,
		ViewPDFURL:     webClientViewPDFPath,
		DirsURL:        webClientDirsPath,
		FileURL:        webClientFilePath,
		FileActionsURL: webClientFileActionsPath,
		CheckExistURL:  webClientExistPath,
		TasksURL:       webClientTasksPath,
		CanAddFiles:    user.CanAddFilesFromWeb(dirName),
		CanCreateDirs:  user.CanAddDirsFromWeb(dirName),
		CanRename:      user.CanRenameFromWeb(dirName, dirName),
		CanDelete:
	user.CanDeleteFromWeb(dirName),
		CanDownload:        user.HasPerm(dataprovider.PermDownload, dirName),
		CanShare:           user.CanManageShares(),
		CanCopy:            user.CanCopyFromWeb(dirName, dirName),
		ShareUploadBaseURL: "",
		Paths:              getDirMapping(dirName, webClientFilesPath),
		QuotaUsage:         newUserQuotaUsage(user),
		KeepAliveInterval:  int(cookieRefreshThreshold / time.Millisecond),
	}
	renderClientTemplate(w, templateClientFiles, data)
}

// renderClientProfilePage renders the user profile page.
func (s *httpdServer) renderClientProfilePage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	data := clientProfilePage{
		baseClientPage: s.getBaseClientPageData(util.I18nProfileTitle, webClientProfilePath, w, r),
		Error:          err,
	}
	user, userMerged, errUser := dataprovider.GetUserVariants(data.LoggedUser.Username, "")
	if errUser != nil {
		s.renderClientInternalServerErrorPage(w, r, errUser)
		return
	}
	data.PublicKeys = user.PublicKeys
	data.TLSCerts = user.Filters.TLSCerts
	data.AllowAPIKeyAuth = user.Filters.AllowAPIKeyAuth
	data.Email = user.Email
	data.AdditionalEmails = user.Filters.AdditionalEmails
	data.AdditionalEmailsString = strings.Join(data.AdditionalEmails, ", ")
	data.Description = user.Description
	// permissions are evaluated on the user merged with group settings
	data.CanSubmit = userMerged.CanUpdateProfile()
	renderClientTemplate(w, templateClientProfile, data)
}

// renderClientChangePasswordPage renders the change password page.
func (s *httpdServer) renderClientChangePasswordPage(w http.ResponseWriter, r *http.Request, err *util.I18nError) {
	data := changeClientPasswordPage{
		baseClientPage: s.getBaseClientPageData(util.I18nChangePwdTitle, webChangeClientPwdPath, w, r),
		Error:          err,
	}

	renderClientTemplate(w, templateChangePwd, data)
}

// handleWebClientDownloadZip streams the files selected in the web UI as a
// single compressed archive. The file list arrives as a JSON array in the
// "files" form field.
func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxMultipartMem)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	if err := r.ParseForm(); err != nil {
		s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}

	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err),
			util.NewI18nError(err, util.I18nErrorGetUser), "")
		return
	}

	connID := xid.New().String()
	protocol := getProtocolFromRequest(r)
	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
	if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
		s.renderClientForbiddenPage(w, r, err)
		return
	}
	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
	connection := newConnection(baseConn, w, r)
	// register the connection so it is subject to the global limits and
	// visible in the active connections list for the whole download
	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	files := r.Form.Get("files")
	var filesList []string
	err = json.Unmarshal(util.StringToBytes(files), &filesList)
	if err != nil {
		s.renderClientBadRequestPage(w, r, err)
		return
	}

	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"",
		getCompressedFileName(connection.GetUsername(), filesList)))
	renderCompressedFiles(w, connection, name, filesList, nil)
}

// handleClientSharePartialDownload downloads the selected contents of a
// browsable share as a single compressed archive.
func (s *httpdServer) handleClientSharePartialDownload(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxMultipartMem)
	if err := r.ParseForm(); err != nil {
		s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		// checkPublicShare has already written the error response
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}
	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())

	// check the download quota before starting to stream anything
	transferQuota := connection.GetTransferQuota()
	if !transferQuota.HasDownloadSpace() {
		err = util.NewI18nError(connection.GetReadQuotaExceededError(), util.I18nErrorQuotaRead)
		connection.Log(logger.LevelInfo, "denying share read due to quota limits")
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getMappedStatusCode(err), err, "")
		return
	}
	files := r.Form.Get("files")
	var filesList []string
	err = json.Unmarshal(util.StringToBytes(files), &filesList)
	if err != nil {
		s.renderClientBadRequestPage(w, r, err)
		return
	}

	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"",
		getCompressedFileName(fmt.Sprintf("share-%s", share.Name), filesList)))
	renderCompressedFiles(w, connection, name, filesList, &share)
}

// handleShareGetDirContents streams, as a JSON array, the directory listing
// for a browsable public share. Only directories and regular files are
// returned.
func (s *httpdServer) handleShareGetDirContents(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nError500Message), getRespStatus(err))
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nError500Message), getRespStatus(err))
		return
	}
	if err = common.Connections.Add(connection); err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nError429Message), http.StatusTooManyRequests)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	lister, err := connection.ReadDir(name)
	if err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirListGeneric), getMappedStatusCode(err))
		return
	}
	defer lister.Close()

	// dataGetter returns one page of entries; a zero count signals the JSON
	// stream writer that the iteration is complete
	dataGetter := func(limit, offset int) ([]byte, int, error) {
		contents, err := lister.Next(limit)
		if errors.Is(err, io.EOF) {
			err = nil
		}
		if err != nil {
			return nil, 0, err
		}
		results := make([]map[string]any, 0, len(contents))
		for idx, info := range contents {
			// skip entries that are neither directories nor regular files
			if !info.Mode().IsDir() && !info.Mode().IsRegular() {
				continue
			}
			res := make(map[string]any)
			res["id"] = offset + idx + 1
			if info.IsDir() {
				res["type"] = "1"
				res["size"] = ""
			} else {
				res["type"] = "2"
				res["size"] = info.Size()
			}
			res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
			res["name"] = info.Name()
			res["url"] = getFileObjectURL(share.GetRelativePath(name), info.Name(),
				path.Join(webClientPubSharesPath, share.ShareID, "browse"))
			res["last_modified"] = getFileObjectModTime(info.ModTime())
			results = append(results, res)
		}
		data, err := json.Marshal(results)
		count := limit
		if len(results) == 0 {
			count = 0
		}
		return data, count, err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleClientUploadToShare renders the upload page for a write-only share
// and redirects read/write shares to the full file browser.
func (s *httpdServer) handleClientUploadToShare(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeWrite, dataprovider.ShareScopeReadWrite}
	share, _, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	// read/write shares use the full file browser instead of the simple
	// upload page
	if share.Scope == dataprovider.ShareScopeReadWrite {
		http.Redirect(w, r, path.Join(webClientPubSharesPath, share.ShareID, "browse"), http.StatusFound)
		return
	}
	s.renderUploadToSharePage(w, r, &share)
}

// handleShareGetFiles browses a public share: it renders the listing page
// for directories and streams regular files as downloads.
func (s *httpdServer) handleShareGetFiles(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}

	if err = common.Connections.Add(connection); err != nil {
		s.renderSharedFilesPage(w, r, path.Dir(share.GetRelativePath(name)),
			util.NewI18nError(err, util.I18nError429Message), share)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	var info os.FileInfo
	// the share root is virtual: synthesize its FileInfo instead of Stat-ing
	if name == "/" {
		info = vfs.NewFileInfo(name, true, 0, time.Unix(0, 0), false)
	} else {
		info, err = connection.Stat(name, 1)
	}
	if err != nil {
		s.renderSharedFilesPage(w, r, path.Dir(share.GetRelativePath(name)),
			util.NewI18nError(err, i18nFsMsg(getRespStatus(err))), share)
		return
	}
	if info.IsDir() {
		s.renderSharedFilesPage(w, r, share.GetRelativePath(name), nil, share)
		return
	}
	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
	if status, err := downloadFile(w, r, connection, name, info, false, &share); err != nil {
		// the download failed: undo the usage counter increment
		dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
		if status > 0 {
			s.renderSharedFilesPage(w, r, path.Dir(share.GetRelativePath(name)),
				util.NewI18nError(err, i18nFsMsg(getRespStatus(err))), share)
		}
	}
}

// handleShareViewPDF renders the PDF viewer page for a file within a share.
func (s *httpdServer) handleShareViewPDF(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, _, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	name := util.CleanPath(r.URL.Query().Get("path"))
	data := viewPDFPage{
		commonBasePage: getCommonBasePage(r),
		Title:          path.Base(name),
		// the timestamp query parameter defeats browser caching
		URL: fmt.Sprintf("%s?path=%s&_=%d", path.Join(webClientPubSharesPath, share.ShareID, "getpdf"),
			url.QueryEscape(name), time.Now().UTC().Unix()),
		Branding:  s.binding.webClientBranding(),
		Languages: s.binding.languages(),
	}
	renderClientTemplate(w, templateClientViewPDF, data)
}

// handleShareGetPDF serves the PDF content referenced by the share viewer
// page rendered by handleShareViewPDF.
func (s *httpdServer) handleShareGetPDF(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead, dataprovider.ShareScopeReadWrite}
	share, connection, err := s.checkPublicShare(w, r, validScopes)
	if err != nil {
		return
	}
	if err := validateBrowsableShare(share, connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}
	name, err := getBrowsableSharedPath(share.Paths[0], r)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, getRespStatus(err), err, "")
		return
	}

	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())
	info, err := connection.Stat(name, 1)
	if err != nil {
		status := getRespStatus(err)
		s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, status,
			util.NewI18nError(err, i18nFsMsg(status)), "")
		return
	}
	if info.IsDir() {
		s.renderClientBadRequestPage(w, r, util.NewI18nError(fmt.Errorf("%q is not a file", name), util.I18nErrorPDFMessage))
		return
	}
	connection.User.CheckFsRoot(connection.ID) //nolint:errcheck
	// refuse to serve content that does not look like a PDF
	if err := s.ensurePDF(w, r, name, connection); err != nil {
		return
	}
	dataprovider.UpdateShareLastUse(&share, 1) //nolint:errcheck
	if _, err := downloadFile(w, r, connection, name, info, true, &share); err != nil {
		// the download failed: undo the usage counter increment
		dataprovider.UpdateShareLastUse(&share, -1) //nolint:errcheck
	}
}

// handleClientGetDirContents streams the authenticated user's directory
// listing as a JSON array. With dirtree=1 only directories are returned.
func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, nil, util.I18nErrorDirList403, http.StatusForbidden)
		return
	}

	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		sendAPIResponse(w, r, nil, util.I18nErrorDirListUser, getRespStatus(err))
		return
	}

	connID := xid.New().String()
	protocol := getProtocolFromRequest(r)
	connectionID := fmt.Sprintf("%s_%s", protocol, connID)
	if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
		sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirList403), http.StatusForbidden)
		return
	}
	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
	connection := newConnection(baseConn, w, r)
	if err = common.Connections.Add(connection); err != nil {
		sendAPIResponse(w, r, err, util.I18nErrorDirList429, http.StatusTooManyRequests)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	lister, err := connection.ReadDir(name)
	if err != nil {
		statusCode := getMappedStatusCode(err)
		sendAPIResponse(w, r, err, i18nListDirMsg(statusCode), statusCode)
		return
	}
	defer lister.Close()

	dirTree := r.URL.Query().Get("dirtree") == "1"
	// dataGetter returns one page of entries; a zero count signals the JSON
	// stream writer that the iteration is complete
	dataGetter := func(limit, offset int) ([]byte, int, error) {
		contents, err := lister.Next(limit)
		if errors.Is(err, io.EOF) {
			err = nil
		}
		if err != nil {
			return nil, 0, err
		}
		results := make([]map[string]any, 0, len(contents))
		for idx, info := range contents {
			res := make(map[string]any)
			res["id"] = offset + idx + 1
			res["url"] = getFileObjectURL(name, info.Name(), webClientFilesPath)
			if info.IsDir() {
				res["type"] = "1"
				res["size"] = ""
				res["dir_path"] = url.QueryEscape(path.Join(name, info.Name()))
			} else {
				// in dirtree mode only directories are of interest
				if dirTree {
					continue
				}
				res["type"] = "2"
				if info.Mode()&os.ModeSymlink != 0 {
					res["size"] = ""
				} else {
					res["size"] = info.Size()
					// only files small enough for the browser editor get an edit link
					if info.Size() < httpdMaxEditFileSize {
						res["edit_url"] = strings.Replace(res["url"].(string), webClientFilesPath, webClientEditFilePath, 1)
					}
				}
			}
			res["meta"] = fmt.Sprintf("%v_%v", res["type"], info.Name())
			res["name"] = info.Name()
			res["last_modified"] = getFileObjectModTime(info.ModTime())
			results = append(results, res)
		}
		data, err := json.Marshal(results)
		count := limit
		if len(results) == 0 {
			count = 0
		}
		return data, count, err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleClientGetFiles renders the files page for directories and streams
// regular files as downloads.
func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}

	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
	s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err),
			util.NewI18nError(err, util.I18nErrorGetUser), "")
		return
	}

	connID := xid.New().String()
	protocol := getProtocolFromRequest(r)
	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
	if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
		s.renderClientForbiddenPage(w, r, err)
		return
	}
	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
	connection := newConnection(baseConn, w, r)
	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	var info os.FileInfo
	// the root is virtual: synthesize its FileInfo instead of Stat-ing
	if name == "/" {
		info = vfs.NewFileInfo(name, true, 0, time.Unix(0, 0), false)
	} else {
		info, err = connection.Stat(name, 0)
	}
	if err != nil {
		s.renderFilesPage(w, r, path.Dir(name), util.NewI18nError(err, i18nFsMsg(getRespStatus(err))), &user)
		return
	}
	if info.IsDir() {
		s.renderFilesPage(w, r, name, nil, &user)
		return
	}
	if status, err := downloadFile(w, r, connection, name, info, false, nil); err != nil && status != 0 {
		if status > 0 {
			if status == http.StatusRequestedRangeNotSatisfiable {
				s.renderClientMessagePage(w, r, util.I18nError416Title, status,
					util.NewI18nError(err, util.I18nError416Message), "")
				return
			}
			s.renderFilesPage(w, r, path.Dir(name), util.NewI18nError(err, i18nFsMsg(status)), &user)
		}
	}
}

// handleClientEditFile loads a text file and renders the in-browser editor,
// read-only when the user cannot write to the file's directory.
func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}

	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err),
			util.NewI18nError(err, util.I18nErrorGetUser), "")
		return
	}

	connID := xid.New().String()
	protocol := getProtocolFromRequest(r)
	connectionID := fmt.Sprintf("%v_%v", protocol, connID)
	if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
		s.renderClientForbiddenPage(w, r, err)
		return
	}
	baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user)
	connection := newConnection(baseConn, w, r)
	if err = common.Connections.Add(connection); err != nil {
		s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests,
			util.NewI18nError(err, util.I18nError429Message), "")
		return
	}
	defer common.Connections.Remove(connection.GetID())

	name := connection.User.GetCleanedPath(r.URL.Query().Get("path"))
	info, err := connection.Stat(name, 0)
	if err != nil {
		status := getRespStatus(err)
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, status, util.NewI18nError(err, i18nFsMsg(status)), "")
		return
	}
	if info.IsDir() {
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, http.StatusBadRequest,
			util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("The path %q does not point to a file", name)),
				util.I18nErrorEditDir,
			), "")
		return
	}
	// refuse to load files too big for the browser editor
	if info.Size() > httpdMaxEditFileSize {
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, http.StatusBadRequest,
			util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("The file size %v for %q exceeds the maximum allowed size",
					util.ByteCountIEC(info.Size()), name)),
				util.I18nErrorEditSize,
			), "")
		return
	}

	connection.User.CheckFsRoot(connection.ID) //nolint:errcheck
	reader, err := connection.getFileReader(name, 0, r.Method)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, getRespStatus(err),
			util.NewI18nError(err, util.I18nError500Message), "")
		return
	}
	defer reader.Close()

	// read the whole file into memory; the size was bounded above
	var b bytes.Buffer
	_, err = io.Copy(&b, reader)
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nErrorEditorTitle, getRespStatus(err),
			util.NewI18nError(err, util.I18nError500Message), "")
		return
	}

	s.renderEditFilePage(w, r, name, b.String(), !user.CanAddFilesFromWeb(path.Dir(name)))
}

// handleClientAddShareGet renders the page to create a new share, pre-filling
// the selected paths and the default/maximum expiration.
func (s *httpdServer) handleClientAddShareGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err),
			util.NewI18nError(err, util.I18nErrorGetUser), "")
		return
	}
	share := &dataprovider.Share{Scope: dataprovider.ShareScopeRead}
	// pre-fill the expiration: the default wins, otherwise fall back to the
	// maximum allowed expiration (both expressed in days)
	if user.Filters.DefaultSharesExpiration > 0 {
		share.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(user.Filters.DefaultSharesExpiration)))
	} else if user.Filters.MaxSharesExpiration > 0 {
		share.ExpiresAt = util.GetTimeAsMsSinceEpoch(time.Now().Add(24 * time.Hour * time.Duration(user.Filters.MaxSharesExpiration)))
	}
	dirName := "/"
	if _, ok := r.URL.Query()["path"]; ok {
		dirName = util.CleanPath(r.URL.Query().Get("path"))
	}

	if _, ok := r.URL.Query()["files"]; ok {
		files := r.URL.Query().Get("files")
		var filesList []string
		err := json.Unmarshal(util.StringToBytes(files), &filesList)
		if err != nil {
			s.renderClientBadRequestPage(w, r, err)
			return
		}
		for _, f := range filesList {
			if f != "" {
				share.Paths =
	append(share.Paths, path.Join(dirName, f))
			}
		}
	}

	s.renderAddUpdateSharePage(w, r, share, nil, true)
}

// handleClientUpdateShareGet renders the page to update an existing share.
func (s *httpdServer) handleClientUpdateShareGet(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	shareID := getURLParam(r, "id")
	share, err := dataprovider.ShareExists(shareID, claims.Username)
	if err == nil {
		s.renderAddUpdateSharePage(w, r, &share, nil, false)
	} else if errors.Is(err, util.ErrNotFound) {
		s.renderClientNotFoundPage(w, r, err)
	} else {
		s.renderClientInternalServerErrorPage(w, r, err)
	}
}

// handleClientAddSharePost validates the submitted form and creates a new
// share owned by the authenticated user.
func (s *httpdServer) handleClientAddSharePost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	share, err := getShareFromPostFields(r)
	if err != nil {
		s.renderAddUpdateSharePage(w, r, share, util.NewI18nError(err, util.I18nError500Message), true)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	// reset server-managed fields so the client cannot set them
	share.ID = 0
	share.ShareID = util.GenerateUniqueID()
	share.LastUseAt = 0
	share.Username = claims.Username
	if share.Password == "" {
		if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
			s.renderAddUpdateSharePage(w, r, share,
				util.NewI18nError(util.NewValidationError("You are not allowed to share files/folders without password"), util.I18nErrorShareNoPwd),
				true)
			return
		}
	}
	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		s.renderAddUpdateSharePage(w, r, share, util.NewI18nError(err, util.I18nErrorGetUser), true)
		return
	}
	if err := user.CheckMaxShareExpiration(util.GetTimeFromMsecSinceEpoch(share.ExpiresAt)); err != nil {
		s.renderAddUpdateSharePage(w, r, share, util.NewI18nError(
			err,
			util.I18nErrorShareExpirationOutOfRange,
			util.I18nErrorArgs(
				map[string]any{
					"val": time.Now().Add(24 * time.Hour * time.Duration(user.Filters.MaxSharesExpiration+1)).UnixMilli(),
					"formatParams": map[string]string{
						"year":  "numeric",
						"month": "numeric",
						"day":   "numeric",
					},
				},
			),
		), true)
		return
	}
	err = dataprovider.AddShare(share, claims.Username, ipAddr, claims.Role)
	if err == nil {
		http.Redirect(w, r, webClientSharesPath, http.StatusSeeOther)
	} else {
		s.renderAddUpdateSharePage(w, r, share, util.NewI18nError(err, util.I18nErrorShareGeneric), true)
	}
}

// handleClientUpdateSharePost validates the submitted form and updates an
// existing share owned by the authenticated user.
func (s *httpdServer) handleClientUpdateSharePost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	shareID := getURLParam(r, "id")
	share, err := dataprovider.ShareExists(shareID, claims.Username)
	if errors.Is(err, util.ErrNotFound) {
		s.renderClientNotFoundPage(w, r, err)
		return
	} else if err != nil {
		s.renderClientInternalServerErrorPage(w, r, err)
		return
	}
	updatedShare, err := getShareFromPostFields(r)
	if err != nil {
		s.renderAddUpdateSharePage(w, r, updatedShare, util.NewI18nError(err, util.I18nError500Message), false)
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	updatedShare.ShareID = shareID
	updatedShare.Username = claims.Username
	// the redacted placeholder means "keep the stored password"
	if updatedShare.Password == redactedSecret {
		updatedShare.Password = share.Password
	}
	if updatedShare.Password == "" {
		if slices.Contains(claims.Permissions, sdk.WebClientShareNoPasswordDisabled) {
			s.renderAddUpdateSharePage(w, r, updatedShare,
				util.NewI18nError(util.NewValidationError("You are not allowed to share files/folders without password"), util.I18nErrorShareNoPwd),
				false)
			return
		}
	}
	user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "")
	if err != nil {
		s.renderAddUpdateSharePage(w, r, updatedShare, util.NewI18nError(err, util.I18nErrorGetUser), false)
		return
	}
	if err := user.CheckMaxShareExpiration(util.GetTimeFromMsecSinceEpoch(updatedShare.ExpiresAt)); err != nil {
		s.renderAddUpdateSharePage(w, r, updatedShare, util.NewI18nError(
			err,
			util.I18nErrorShareExpirationOutOfRange,
			util.I18nErrorArgs(
				map[string]any{
					"val": time.Now().Add(24 * time.Hour * time.Duration(user.Filters.MaxSharesExpiration+1)).UnixMilli(),
					"formatParams": map[string]string{
						"year":  "numeric",
						"month": "numeric",
						"day":   "numeric",
					},
				},
			),
		), false)
		return
	}
	err = dataprovider.UpdateShare(updatedShare, claims.Username, ipAddr, claims.Role)
	if err == nil {
		http.Redirect(w, r, webClientSharesPath, http.StatusSeeOther)
	} else {
		s.renderAddUpdateSharePage(w, r, updatedShare, util.NewI18nError(err, util.I18nErrorShareGeneric), false)
	}
}

// getAllShares streams all the shares of the authenticated user as a JSON
// array.
func getAllShares(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		sendAPIResponse(w, r, nil, util.I18nErrorInvalidToken, http.StatusForbidden)
		return
	}

	dataGetter := func(limit, offset int) ([]byte, int, error) {
		shares, err := dataprovider.GetShares(limit, offset, dataprovider.OrderASC, claims.Username)
		if err != nil {
			return
	nil, 0, err
		}
		data, err := json.Marshal(shares)
		return data, len(shares), err
	}

	streamJSONArray(w, defaultQueryLimit, dataGetter)
}

// handleClientGetShares renders the shares list page.
func (s *httpdServer) handleClientGetShares(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	data := clientSharesPage{
		baseClientPage:      s.getBaseClientPageData(util.I18nSharesTitle, webClientSharesPath, w, r),
		BasePublicSharesURL: webClientPubSharesPath,
		BaseURL:             s.binding.BaseURL,
	}
	renderClientTemplate(w, templateClientShares, data)
}

// handleClientGetProfile renders the profile page.
func (s *httpdServer) handleClientGetProfile(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderClientProfilePage(w, r, nil)
}

// handleWebClientChangePwd renders the change password page.
func (s *httpdServer) handleWebClientChangePwd(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderClientChangePasswordPage(w, r, nil)
}

// handleWebClientProfilePost applies the submitted profile changes. Each
// section is honored only if the corresponding permission, evaluated on the
// user merged with group settings, allows it.
func (s *httpdServer) handleWebClientProfilePost(w http.ResponseWriter, r *http.Request) { //nolint:gocyclo
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	err := r.ParseForm()
	if err != nil {
		s.renderClientProfilePage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm))
		return
	}
	ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr)
	if err := verifyCSRFToken(r, s.csrfTokenAuth); err != nil {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF))
		return
	}
	claims, err := jwt.FromContext(r.Context())
	if err != nil || claims.Username == "" {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken))
		return
	}
	user, userMerged, err := dataprovider.GetUserVariants(claims.Username, "")
	if err != nil {
		s.renderClientProfilePage(w, r, util.NewI18nError(err, util.I18nErrorGetUser))
		return
	}
	if !userMerged.CanUpdateProfile() {
		s.renderClientForbiddenPage(w, r, util.NewI18nError(
			errors.New("you are not allowed to change anything"),
			util.I18nErrorNoPermissions,
		))
		return
	}
	if userMerged.CanManagePublicKeys() {
		// collect the indexed form fields, e.g. public_keys[0][public_key]
		for k := range r.Form {
			if hasPrefixAndSuffix(k, "public_keys[", "][public_key]") {
				r.Form.Add("public_keys", r.Form.Get(k))
			}
		}
		user.PublicKeys = r.Form["public_keys"]
	}
	if userMerged.CanManageTLSCerts() {
		for k := range r.Form {
			if hasPrefixAndSuffix(k, "tls_certs[", "][tls_cert]") {
				r.Form.Add("tls_certs", r.Form.Get(k))
			}
		}
		user.Filters.TLSCerts = r.Form["tls_certs"]
	}
	if userMerged.CanChangeAPIKeyAuth() {
		user.Filters.AllowAPIKeyAuth = r.Form.Get("allow_api_key_auth") != ""
	}
	if userMerged.CanChangeInfo() {
		user.Email = strings.TrimSpace(r.Form.Get("email"))
		user.Description = r.Form.Get("description")
		for k := range r.Form {
			if hasPrefixAndSuffix(k, "additional_emails[", "][additional_email]") {
				email := strings.TrimSpace(r.Form.Get(k))
				if email != "" {
					r.Form.Add("additional_emails", email)
				}
			}
		}
		user.Filters.AdditionalEmails = r.Form["additional_emails"]
	}
	err = dataprovider.UpdateUser(&user, dataprovider.ActionExecutorSelf, ipAddr, user.Role)
	if err != nil {
		s.renderClientProfilePage(w, r, util.NewI18nError(err, util.I18nError500Message))
		return
	}
	s.renderClientMessagePage(w, r, util.I18nProfileTitle, http.StatusOK, nil, util.I18nProfileUpdated)
}

// handleWebClientMFA renders the MFA configuration page.
func (s *httpdServer) handleWebClientMFA(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderClientMFAPage(w, r)
}

// handleWebClientTwoFactor renders the two-factor authentication page.
func (s *httpdServer) handleWebClientTwoFactor(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderClientTwoFactorPage(w, r, nil)
}

// handleWebClientTwoFactorRecovery renders the recovery code login page.
func (s *httpdServer) handleWebClientTwoFactorRecovery(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	s.renderClientTwoFactorRecoveryPage(w, r, nil)
}

// getShareFromPostFields parses the share form fields and returns the
// resulting share or a localized validation error.
func getShareFromPostFields(r *http.Request) (*dataprovider.Share, error) {
	share := &dataprovider.Share{}
	if err := r.ParseForm(); err != nil {
		return share, util.NewI18nError(err, util.I18nErrorInvalidForm)
	}
	// collect the indexed path fields, e.g. paths[0][path]
	for k := range r.Form {
		if hasPrefixAndSuffix(k, "paths[", "][path]") {
			r.Form.Add("paths", r.Form.Get(k))
		}
	}

	share.Name = strings.TrimSpace(r.Form.Get("name"))
	share.Description = r.Form.Get("description")
	for _, p := range r.Form["paths"] {
		if strings.TrimSpace(p) != "" {
			share.Paths = append(share.Paths, p)
		}
	}
	share.Password = strings.TrimSpace(r.Form.Get("password"))
	share.AllowFrom = getSliceFromDelimitedValues(r.Form.Get("allowed_ip"), ",")
	scope, err := strconv.Atoi(r.Form.Get("scope"))
	if err != nil {
		return share, util.NewI18nError(err, util.I18nErrorShareScope)
	}
	share.Scope = dataprovider.ShareScope(scope)
	maxTokens, err := strconv.Atoi(r.Form.Get("max_tokens"))
	if err != nil {
		return share, util.NewI18nError(err, util.I18nErrorShareMaxTokens)
	}
	share.MaxTokens = maxTokens
	// zero means the share never expires
	expirationDateMillis := int64(0)
	expirationDateString := strings.TrimSpace(r.Form.Get("expiration_date"))
	if expirationDateString != "" {
		expirationDate, err := time.Parse(webDateTimeFormat, expirationDateString)
		if err != nil {
			return share, util.NewI18nError(err, util.I18nErrorShareExpiration)
		}
		expirationDateMillis = util.GetTimeAsMsSinceEpoch(expirationDate)
	}
	share.ExpiresAt = expirationDateMillis
	return share, nil
}

// handleWebClientForgotPwd renders the forgot password page; it is available
// only when SMTP is configured.
func (s *httpdServer) handleWebClientForgotPwd(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)
	if !smtp.IsEnabled() {
		s.renderClientNotFoundPage(w, r, errors.New("this page does not exist"))
		return
	}
	s.renderClientForgotPwdPage(w, r, nil)
}

// handleWebClientForgotPwdPost handles the forgot password form submission.
func (s *httpdServer) handleWebClientForgotPwdPost(w http.ResponseWriter, r *http.Request) {
	r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize)

	err := r.ParseForm()
	if err != nil {
s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) + return + } + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderClientForbiddenPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) + return + } + username := strings.TrimSpace(r.Form.Get("username")) + err = handleForgotPassword(r, username, false) + if err != nil { + s.renderClientForgotPwdPage(w, r, util.NewI18nError(err, util.I18nErrorPwdResetGeneric)) + return + } + http.Redirect(w, r, webClientResetPwdPath, http.StatusFound) +} + +func (s *httpdServer) handleWebClientPasswordReset(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + if !smtp.IsEnabled() { + s.renderClientNotFoundPage(w, r, errors.New("this page does not exist")) + return + } + s.renderClientResetPwdPage(w, r, nil) +} + +func (s *httpdServer) handleClientViewPDF(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + name := r.URL.Query().Get("path") + if name == "" { + s.renderClientBadRequestPage(w, r, errors.New("no file specified")) + return + } + name = util.CleanPath(name) + data := viewPDFPage{ + commonBasePage: getCommonBasePage(r), + Title: path.Base(name), + URL: fmt.Sprintf("%s?path=%s&_=%d", webClientGetPDFPath, url.QueryEscape(name), time.Now().UTC().Unix()), + Branding: s.binding.webClientBranding(), + Languages: s.binding.languages(), + } + renderClientTemplate(w, templateClientViewPDF, data) +} + +func (s *httpdServer) handleClientGetPDF(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + s.renderClientForbiddenPage(w, r, util.NewI18nError(errInvalidTokenClaims, util.I18nErrorInvalidToken)) + return + } + name := r.URL.Query().Get("path") + if name == "" { + s.renderClientBadRequestPage(w, r, 
util.NewI18nError(errors.New("no file specified"), util.I18nError400Message)) + return + } + name = util.CleanPath(name) + user, err := dataprovider.GetUserWithGroupSettings(claims.Username, "") + if err != nil { + s.renderClientMessagePage(w, r, util.I18nError500Title, getRespStatus(err), + util.NewI18nError(err, util.I18nErrorGetUser), "") + return + } + + connID := xid.New().String() + protocol := getProtocolFromRequest(r) + connectionID := fmt.Sprintf("%v_%v", protocol, connID) + if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil { + s.renderClientForbiddenPage(w, r, err) + return + } + baseConn := common.NewBaseConnection(connID, protocol, util.GetHTTPLocalAddress(r), r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) + if err = common.Connections.Add(connection); err != nil { + s.renderClientMessagePage(w, r, util.I18nError429Title, http.StatusTooManyRequests, + util.NewI18nError(err, util.I18nError429Message), "") + return + } + defer common.Connections.Remove(connection.GetID()) + + info, err := connection.Stat(name, 0) + if err != nil { + status := getRespStatus(err) + s.renderClientMessagePage(w, r, util.I18nErrorPDFTitle, status, util.NewI18nError(err, i18nFsMsg(status)), "") + return + } + if info.IsDir() { + s.renderClientBadRequestPage(w, r, util.NewI18nError(fmt.Errorf("%q is not a file", name), util.I18nErrorPDFMessage)) + return + } + connection.User.CheckFsRoot(connection.ID) //nolint:errcheck + if err := s.ensurePDF(w, r, name, connection); err != nil { + return + } + downloadFile(w, r, connection, name, info, true, nil) //nolint:errcheck +} + +func (s *httpdServer) ensurePDF(w http.ResponseWriter, r *http.Request, name string, connection *Connection) error { + reader, err := connection.getFileReader(name, 0, r.Method) + if err != nil { + s.renderClientMessagePage(w, r, util.I18nErrorPDFTitle, + getRespStatus(err), util.NewI18nError(err, util.I18nError500Message), "") + return err + } + defer 
reader.Close() + + var b bytes.Buffer + _, err = io.CopyN(&b, reader, 128) + if err != nil { + s.renderClientMessagePage(w, r, util.I18nErrorPDFTitle, getRespStatus(err), + util.NewI18nError(err, util.I18nErrorPDFMessage), "") + return err + } + if ctype := http.DetectContentType(b.Bytes()); ctype != "application/pdf" { + connection.Log(logger.LevelDebug, "detected %q content type, expected PDF, file %q", ctype, name) + err := fmt.Errorf("the file %q does not look like a PDF", name) + s.renderClientBadRequestPage(w, r, util.NewI18nError(err, util.I18nErrorPDFMessage)) + return err + } + return nil +} + +func (s *httpdServer) handleClientShareLoginGet(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + s.renderShareLoginPage(w, r, nil) +} + +func (s *httpdServer) handleClientShareLoginPost(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + if err := r.ParseForm(); err != nil { + s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidForm)) + return + } + if err := verifyLoginCookieAndCSRFToken(r, s.csrfTokenAuth); err != nil { + s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCSRF)) + return + } + invalidateToken(r) + shareID := getURLParam(r, "id") + share, err := dataprovider.ShareExists(shareID, "") + if err != nil { + s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nErrorInvalidCredentials)) + return + } + match, err := share.CheckCredentials(strings.TrimSpace(r.Form.Get("share_password"))) + if !match || err != nil { + s.renderShareLoginPage(w, r, util.NewI18nError(dataprovider.ErrInvalidCredentials, util.I18nErrorInvalidCredentials)) + return + } + next := path.Clean(r.URL.Query().Get("next")) + baseShareURL := path.Join(webClientPubSharesPath, share.ShareID) + isRedirect, redirectTo := checkShareRedirectURL(next, baseShareURL) + c := &jwt.Claims{ + 
Username: shareID, + } + if isRedirect { + c.Ref = next + } + err = createAndSetCookie(w, r, c, s.tokenAuth, tokenAudienceWebShare, ipAddr) + if err != nil { + s.renderShareLoginPage(w, r, util.NewI18nError(err, util.I18nError500Message)) + return + } + if isRedirect { + http.Redirect(w, r, redirectTo, http.StatusFound) + return + } + s.renderClientMessagePage(w, r, util.I18nSharedFilesTitle, http.StatusOK, nil, util.I18nShareLoginOK) +} + +func (s *httpdServer) handleClientShareLogout(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + + shareID := getURLParam(r, "id") + ctx, claims, err := s.getShareClaims(r, shareID) + if err != nil { + s.renderClientMessagePage(w, r, util.I18nShareAccessErrorTitle, http.StatusForbidden, + util.NewI18nError(err, util.I18nErrorInvalidToken), "") + return + } + removeCookie(w, r.WithContext(ctx), webBaseClientPath) + + redirectURL := path.Join(webClientPubSharesPath, shareID, fmt.Sprintf("login?next=%s", url.QueryEscape(claims.Ref))) + http.Redirect(w, r, redirectURL, http.StatusFound) +} + +func (s *httpdServer) handleClientSharedFile(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeRead} + share, _, err := s.checkPublicShare(w, r, validScopes) + if err != nil { + return + } + query := "" + if r.URL.RawQuery != "" { + query = "?" 
+ r.URL.RawQuery + } + s.renderShareDownloadPage(w, r, &share, path.Join(webClientPubSharesPath, share.ShareID)+query) +} + +func (s *httpdServer) handleClientCheckExist(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + defer common.Connections.Remove(connection.GetID()) + + name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + + doCheckExist(w, r, connection, name) +} + +func (s *httpdServer) handleClientShareCheckExist(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + validScopes := []dataprovider.ShareScope{dataprovider.ShareScopeReadWrite} + share, connection, err := s.checkPublicShare(w, r, validScopes) + if err != nil { + return + } + if err := validateBrowsableShare(share, connection); err != nil { + sendAPIResponse(w, r, err, "", getRespStatus(err)) + return + } + name, err := getBrowsableSharedPath(share.Paths[0], r) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + + if err = common.Connections.Add(connection); err != nil { + sendAPIResponse(w, r, err, "Unable to add connection", http.StatusTooManyRequests) + return + } + defer common.Connections.Remove(connection.GetID()) + + doCheckExist(w, r, connection, name) +} + +type filesToCheck struct { + Files []string `json:"files"` +} + +func doCheckExist(w http.ResponseWriter, r *http.Request, connection *Connection, name string) { + var filesList filesToCheck + err := render.DecodeJSON(r.Body, &filesList) + if err != nil { + sendAPIResponse(w, r, err, "", http.StatusBadRequest) + return + } + if len(filesList.Files) == 0 { + sendAPIResponse(w, r, errors.New("files to be checked are mandatory"), "", http.StatusBadRequest) + return + } + + lister, err := connection.ListDir(name) + if err != nil { + sendAPIResponse(w, r, err, "Unable to get directory contents", 
getMappedStatusCode(err)) + return + } + defer lister.Close() + + dataGetter := func(limit, _ int) ([]byte, int, error) { + contents, err := lister.Next(limit) + if errors.Is(err, io.EOF) { + err = nil + } + if err != nil { + return nil, 0, err + } + existing := make([]map[string]any, 0) + for _, info := range contents { + if slices.Contains(filesList.Files, info.Name()) { + res := make(map[string]any) + res["name"] = info.Name() + if info.IsDir() { + res["type"] = "1" + res["size"] = "" + } else { + res["type"] = "2" + res["size"] = info.Size() + } + existing = append(existing, res) + } + } + data, err := json.Marshal(existing) + count := limit + if len(existing) == 0 { + count = 0 + } + return data, count, err + } + + streamJSONArray(w, defaultQueryLimit, dataGetter) +} + +func checkShareRedirectURL(next, base string) (bool, string) { + if !strings.HasPrefix(next, base) { + return false, "" + } + if next == base { + return true, path.Join(next, "download") + } + baseURL, err := url.Parse(base) + if err != nil { + return false, "" + } + nextURL, err := url.Parse(next) + if err != nil { + return false, "" + } + if nextURL.Path == baseURL.Path { + redirectURL := nextURL.JoinPath("download") + return true, redirectURL.String() + } + return true, next +} + +func getWebTask(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxLoginBodySize) + claims, err := jwt.FromContext(r.Context()) + if err != nil || claims.Username == "" { + sendAPIResponse(w, r, err, "Invalid token claims", http.StatusBadRequest) + return + } + taskID := getURLParam(r, "id") + + task, err := webTaskMgr.Get(taskID) + if err != nil { + sendAPIResponse(w, r, err, "Unable to get task", getMappedStatusCode(err)) + return + } + if task.User != claims.Username { + sendAPIResponse(w, r, nil, http.StatusText(http.StatusForbidden), http.StatusForbidden) + return + } + render.JSON(w, r, task) +} + +func taskDeleteDir(w http.ResponseWriter, r *http.Request) { + r.Body = 
http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + + name := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + task := webTaskData{ + ID: connection.GetID(), + User: connection.GetUsername(), + Path: name, + Timestamp: util.GetTimeAsMsSinceEpoch(time.Now()), + Status: 0, + } + if err := webTaskMgr.Add(task); err != nil { + common.Connections.Remove(connection.GetID()) + sendAPIResponse(w, r, nil, "Unable to create task", http.StatusInternalServerError) + return + } + go executeDeleteTask(connection, task) + sendAPIResponse(w, r, nil, task.ID, http.StatusAccepted) +} + +func taskRenameFsEntry(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + oldName := connection.User.GetCleanedPath(r.URL.Query().Get("path")) + newName := connection.User.GetCleanedPath(r.URL.Query().Get("target")) + task := webTaskData{ + ID: connection.GetID(), + User: connection.GetUsername(), + Path: oldName, + Target: newName, + Timestamp: util.GetTimeAsMsSinceEpoch(time.Now()), + Status: 0, + } + if err := webTaskMgr.Add(task); err != nil { + common.Connections.Remove(connection.GetID()) + sendAPIResponse(w, r, nil, "Unable to create task", http.StatusInternalServerError) + return + } + go executeRenameTask(connection, task) + sendAPIResponse(w, r, nil, task.ID, http.StatusAccepted) +} + +func taskCopyFsEntry(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxRequestSize) + connection, err := getUserConnection(w, r) + if err != nil { + return + } + source := r.URL.Query().Get("path") + target := r.URL.Query().Get("target") + copyFromSource := strings.HasSuffix(source, "/") + copyInTarget := strings.HasSuffix(target, "/") + source = connection.User.GetCleanedPath(source) + target = connection.User.GetCleanedPath(target) + if copyFromSource { + source 
+= "/" + } + if copyInTarget { + target += "/" + } + task := webTaskData{ + ID: connection.GetID(), + User: connection.GetUsername(), + Path: source, + Target: target, + Timestamp: util.GetTimeAsMsSinceEpoch(time.Now()), + Status: 0, + } + if err := webTaskMgr.Add(task); err != nil { + common.Connections.Remove(connection.GetID()) + sendAPIResponse(w, r, nil, "Unable to create task", http.StatusInternalServerError) + return + } + go executeCopyTask(connection, task) + sendAPIResponse(w, r, nil, task.ID, http.StatusAccepted) +} + +func executeDeleteTask(conn *Connection, task webTaskData) { + done := make(chan bool) + + defer func() { + close(done) + common.Connections.Remove(conn.GetID()) + }() + + go keepAliveTask(task, done, 2*time.Minute) + + status := http.StatusOK + if err := conn.RemoveAll(task.Path); err != nil { + status = getMappedStatusCode(err) + } + + task.Timestamp = util.GetTimeAsMsSinceEpoch(time.Now()) + task.Status = status + err := webTaskMgr.Add(task) + conn.Log(logger.LevelDebug, "delete task finished, status: %d, update task err: %v", status, err) +} + +func executeRenameTask(conn *Connection, task webTaskData) { + done := make(chan bool) + + defer func() { + close(done) + common.Connections.Remove(conn.GetID()) + }() + + go keepAliveTask(task, done, 2*time.Minute) + + status := http.StatusOK + + if !conn.IsSameResource(task.Path, task.Target) { + if err := conn.Copy(task.Path, task.Target); err != nil { + status = getMappedStatusCode(err) + task.Timestamp = util.GetTimeAsMsSinceEpoch(time.Now()) + task.Status = status + err = webTaskMgr.Add(task) + conn.Log(logger.LevelDebug, "copy step for rename task finished, status: %d, update task err: %v", status, err) + return + } + if err := conn.RemoveAll(task.Path); err != nil { + status = getMappedStatusCode(err) + } + } else { + if err := conn.Rename(task.Path, task.Target); err != nil { + status = getMappedStatusCode(err) + } + } + + task.Timestamp = util.GetTimeAsMsSinceEpoch(time.Now()) + 
task.Status = status + err := webTaskMgr.Add(task) + conn.Log(logger.LevelDebug, "rename task finished, status: %d, update task err: %v", status, err) +} + +func executeCopyTask(conn *Connection, task webTaskData) { + done := make(chan bool) + + defer func() { + close(done) + common.Connections.Remove(conn.GetID()) + }() + + go keepAliveTask(task, done, 2*time.Minute) + + status := http.StatusOK + if err := conn.Copy(task.Path, task.Target); err != nil { + status = getMappedStatusCode(err) + } + + task.Timestamp = util.GetTimeAsMsSinceEpoch(time.Now()) + task.Status = status + err := webTaskMgr.Add(task) + conn.Log(logger.LevelDebug, "copy task finished, status: %d, update task err: %v", status, err) +} + +func keepAliveTask(task webTaskData, done chan bool, interval time.Duration) { + ticker := time.NewTicker(interval) + defer func() { + ticker.Stop() + }() + + for { + select { + case <-done: + return + case <-ticker.C: + task.Timestamp = util.GetTimeAsMsSinceEpoch(time.Now()) + err := webTaskMgr.Add(task) + logger.Debug(logSender, task.ID, "task timestamp updated, err: %v", err) + } + } +} diff --git a/internal/httpd/webtask.go b/internal/httpd/webtask.go new file mode 100644 index 00000000..0c328c06 --- /dev/null +++ b/internal/httpd/webtask.go @@ -0,0 +1,108 @@ +// Copyright (C) 2024 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + webTaskMgr webTaskManager +) + +func newWebTaskManager(isShared int) webTaskManager { + if isShared == 1 { + logger.Info(logSender, "", "using provider task manager") + return &dbTaskManager{} + } + logger.Info(logSender, "", "using memory task manager") + return &memoryTaskManager{} +} + +type webTaskManager interface { + Add(data webTaskData) error + Get(ID string) (webTaskData, error) + Cleanup() +} + +type webTaskData struct { + ID string `json:"id"` + User string `json:"user"` + Path string `json:"path"` + Target string `json:"target"` + Timestamp int64 `json:"ts"` + Status int `json:"status"` // 0 in progress or http status code (200 ok, 403 and so on) +} + +type memoryTaskManager struct { + tasks sync.Map +} + +func (m *memoryTaskManager) Add(data webTaskData) error { + m.tasks.Store(data.ID, &data) + return nil +} + +func (m *memoryTaskManager) Get(ID string) (webTaskData, error) { + data, ok := m.tasks.Load(ID) + if !ok { + return webTaskData{}, util.NewRecordNotFoundError(fmt.Sprintf("task for ID %q not found", ID)) + } + return *data.(*webTaskData), nil +} + +func (m *memoryTaskManager) Cleanup() { + m.tasks.Range(func(key, value any) bool { + data := value.(*webTaskData) + if data.Timestamp < util.GetTimeAsMsSinceEpoch(time.Now().Add(-5*time.Minute)) { + m.tasks.Delete(key) + } + return true + }) +} + +type dbTaskManager struct{} + +func (m *dbTaskManager) Add(data webTaskData) error { + session := dataprovider.Session{ + Key: data.ID, + Data: data, + Type: dataprovider.SessionTypeWebTask, + Timestamp: data.Timestamp, + } + return dataprovider.AddSharedSession(session) +} + +func (m *dbTaskManager) Get(ID string) (webTaskData, error) { + sess, err := dataprovider.GetSharedSession(ID, dataprovider.SessionTypeWebTask) + 
if err != nil { + return webTaskData{}, err + } + d := sess.Data.([]byte) + var data webTaskData + err = json.Unmarshal(d, &data) + return data, err +} + +func (m *dbTaskManager) Cleanup() { + dataprovider.CleanupSharedSessions(dataprovider.SessionTypeWebTask, time.Now().Add(-5*time.Minute)) //nolint:errcheck +} diff --git a/internal/httpd/webtask_test.go b/internal/httpd/webtask_test.go new file mode 100644 index 00000000..f4e2b949 --- /dev/null +++ b/internal/httpd/webtask_test.go @@ -0,0 +1,133 @@ +// Copyright (C) 2024 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package httpd + +import ( + "testing" + "time" + + "github.com/rs/xid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func TestMemoryWebTaskManager(t *testing.T) { + mgr := newWebTaskManager(0) + m, ok := mgr.(*memoryTaskManager) + require.True(t, ok) + task := webTaskData{ + ID: xid.New().String(), + User: defeaultUsername, + Timestamp: time.Now().Add(-1 * time.Hour).UnixMilli(), + Status: 0, + } + task1 := webTaskData{ + ID: xid.New().String(), + User: defeaultUsername, + Timestamp: time.Now().UnixMilli(), + Status: 0, + } + err := m.Add(task) + require.NoError(t, err) + err = m.Add(task1) + require.NoError(t, err) + taskGet, err := m.Get(task.ID) + require.NoError(t, err) + require.Equal(t, task, taskGet) + m.Cleanup() + _, err = m.Get(task.ID) + require.ErrorIs(t, err, util.ErrNotFound) + taskGet, err = m.Get(task1.ID) + require.NoError(t, err) + require.Equal(t, task1, taskGet) + task1.Timestamp = time.Now().Add(-1 * time.Hour).UnixMilli() + err = m.Add(task1) + require.NoError(t, err) + m.Cleanup() + _, err = m.Get(task.ID) + require.ErrorIs(t, err, util.ErrNotFound) + // test keep alive task + oldMgr := webTaskMgr + webTaskMgr = mgr + + done := make(chan bool) + go keepAliveTask(task, done, 50*time.Millisecond) + + time.Sleep(120 * time.Millisecond) + close(done) + taskGet, err = m.Get(task.ID) + require.NoError(t, err) + assert.Greater(t, taskGet.Timestamp, task.Timestamp) + m.Cleanup() + _, err = m.Get(task.ID) + require.NoError(t, err) + err = m.Add(task) + require.NoError(t, err) + m.Cleanup() + _, err = m.Get(task.ID) + require.ErrorIs(t, err, util.ErrNotFound) + + webTaskMgr = oldMgr +} + +func TestDbWebTaskManager(t *testing.T) { + if !isSharedProviderSupported() { + t.Skip("this test it is not available with this provider") + } + mgr := newWebTaskManager(1) + m, ok := mgr.(*dbTaskManager) + require.True(t, ok) + + task := webTaskData{ + ID: xid.New().String(), + 
User: defeaultUsername, + Timestamp: time.Now().Add(-1 * time.Hour).UnixMilli(), + Status: 0, + } + err := m.Add(task) + require.NoError(t, err) + taskGet, err := m.Get(task.ID) + require.NoError(t, err) + require.Equal(t, task, taskGet) + m.Cleanup() + _, err = m.Get(task.ID) + require.ErrorIs(t, err, util.ErrNotFound) + err = m.Add(task) + require.NoError(t, err) + // test keep alive task + oldMgr := webTaskMgr + webTaskMgr = mgr + + done := make(chan bool) + go keepAliveTask(task, done, 50*time.Millisecond) + + time.Sleep(120 * time.Millisecond) + close(done) + taskGet, err = m.Get(task.ID) + require.NoError(t, err) + assert.Greater(t, taskGet.Timestamp, task.Timestamp) + m.Cleanup() + _, err = m.Get(task.ID) + require.NoError(t, err) + err = m.Add(task) + require.NoError(t, err) + m.Cleanup() + _, err = m.Get(task.ID) + require.ErrorIs(t, err, util.ErrNotFound) + + webTaskMgr = oldMgr +} diff --git a/internal/httpdtest/httpdtest.go b/internal/httpdtest/httpdtest.go new file mode 100644 index 00000000..bacfa675 --- /dev/null +++ b/internal/httpdtest/httpdtest.go @@ -0,0 +1,2980 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package httpdtest provides utilities for testing the supported REST API. 
+package httpdtest + +import ( + "bytes" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "path" + "slices" + "strconv" + "strings" + + "github.com/go-chi/render" + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/httpd" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/version" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + tokenPath = "/api/v2/token" + activeConnectionsPath = "/api/v2/connections" + quotasBasePath = "/api/v2/quotas" + quotaScanPath = "/api/v2/quotas/users/scans" + quotaScanVFolderPath = "/api/v2/quotas/folders/scans" + userPath = "/api/v2/users" + groupPath = "/api/v2/groups" + versionPath = "/api/v2/version" + folderPath = "/api/v2/folders" + serverStatusPath = "/api/v2/status" + dumpDataPath = "/api/v2/dumpdata" + loadDataPath = "/api/v2/loaddata" + defenderHosts = "/api/v2/defender/hosts" + adminPath = "/api/v2/admins" + adminPwdPath = "/api/v2/admin/changepwd" + apiKeysPath = "/api/v2/apikeys" + retentionChecksPath = "/api/v2/retention/users/checks" + eventActionsPath = "/api/v2/eventactions" + eventRulesPath = "/api/v2/eventrules" + rolesPath = "/api/v2/roles" + ipListsPath = "/api/v2/iplists" +) + +const ( + defaultTokenAuthUser = "admin" + defaultTokenAuthPass = "password" +) + +var ( + httpBaseURL = "http://127.0.0.1:8080" + jwtToken = "" +) + +// SetBaseURL sets the base url to use for HTTP requests. 
+// Default URL is "http://127.0.0.1:8080" +func SetBaseURL(url string) { + httpBaseURL = url +} + +// SetJWTToken sets the JWT token to use +func SetJWTToken(token string) { + jwtToken = token +} + +func sendHTTPRequest(method, url string, body io.Reader, contentType, token string) (*http.Response, error) { + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + if contentType != "" { + req.Header.Set("Content-Type", "application/json") + } + if token != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token)) + } + return httpclient.GetHTTPClient().Do(req) +} + +func buildURLRelativeToBase(paths ...string) string { + // we need to use path.Join and not filepath.Join + // since filepath.Join will use backslash separator on Windows + p := path.Join(paths...) + return fmt.Sprintf("%s/%s", strings.TrimRight(httpBaseURL, "/"), strings.TrimLeft(p, "/")) +} + +// GetToken tries to return a JWT token +func GetToken(username, password string) (string, map[string]any, error) { + req, err := http.NewRequest(http.MethodGet, buildURLRelativeToBase(tokenPath), nil) + if err != nil { + return "", nil, err + } + req.SetBasicAuth(username, password) + resp, err := httpclient.GetHTTPClient().Do(req) + if err != nil { + return "", nil, err + } + defer resp.Body.Close() + + err = checkResponse(resp.StatusCode, http.StatusOK) + if err != nil { + return "", nil, err + } + responseHolder := make(map[string]any) + err = render.DecodeJSON(resp.Body, &responseHolder) + if err != nil { + return "", nil, err + } + return responseHolder["access_token"].(string), responseHolder, nil +} + +func getDefaultToken() string { + if jwtToken != "" { + return jwtToken + } + token, _, err := GetToken(defaultTokenAuthUser, defaultTokenAuthPass) + if err != nil { + return "" + } + return token +} + +// AddUser adds a new user and checks the received HTTP Status code against expectedStatusCode. 
+func AddUser(user dataprovider.User, expectedStatusCode int) (dataprovider.User, []byte, error) { + var newUser dataprovider.User + var body []byte + userAsJSON, _ := json.Marshal(user) + resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(userPath), bytes.NewBuffer(userAsJSON), + "application/json", getDefaultToken()) + if err != nil { + return newUser, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if expectedStatusCode != http.StatusCreated { + body, _ = getResponseBody(resp) + return newUser, body, err + } + if err == nil { + err = render.DecodeJSON(resp.Body, &newUser) + } else { + body, _ = getResponseBody(resp) + } + if err == nil { + err = checkUser(&user, &newUser) + } + return newUser, body, err +} + +// UpdateUserWithJSON update a user using the provided JSON as POST body +func UpdateUserWithJSON(user dataprovider.User, expectedStatusCode int, disconnect string, userAsJSON []byte) (dataprovider.User, []byte, error) { + var newUser dataprovider.User + var body []byte + url, err := addUpdateUserQueryParams(buildURLRelativeToBase(userPath, url.PathEscape(user.Username)), disconnect) + if err != nil { + return user, body, err + } + resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json", + getDefaultToken()) + if err != nil { + return user, body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + err = checkResponse(resp.StatusCode, expectedStatusCode) + if expectedStatusCode != http.StatusOK { + return newUser, body, err + } + if err == nil { + newUser, body, err = GetUserByUsername(user.Username, expectedStatusCode) + } + if err == nil { + err = checkUser(&user, &newUser) + } + return newUser, body, err +} + +// UpdateUser updates an existing user and checks the received HTTP Status code against expectedStatusCode. 
+func UpdateUser(user dataprovider.User, expectedStatusCode int, disconnect string) (dataprovider.User, []byte, error) { + userAsJSON, _ := json.Marshal(user) + return UpdateUserWithJSON(user, expectedStatusCode, disconnect, userAsJSON) +} + +// RemoveUser removes an existing user and checks the received HTTP Status code against expectedStatusCode. +func RemoveUser(user dataprovider.User, expectedStatusCode int) ([]byte, error) { + var body []byte + resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(userPath, url.PathEscape(user.Username)), + nil, "", getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + +// GetUserByUsername gets a user by username and checks the received HTTP Status code against expectedStatusCode. +func GetUserByUsername(username string, expectedStatusCode int) (dataprovider.User, []byte, error) { + var user dataprovider.User + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(userPath, url.PathEscape(username)), + nil, "", getDefaultToken()) + if err != nil { + return user, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &user) + } else { + body, _ = getResponseBody(resp) + } + return user, body, err +} + +// GetUsers returns a list of users and checks the received HTTP Status code against expectedStatusCode. +// The number of results can be limited specifying a limit. +// Some results can be skipped specifying an offset. 
func GetUsers(limit, offset int64, expectedStatusCode int) ([]dataprovider.User, []byte, error) {
	var users []dataprovider.User
	var body []byte
	// the url variable shadows the net/url package for the rest of the function
	url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(userPath), limit, offset)
	if err != nil {
		return users, body, err
	}
	resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken())
	if err != nil {
		return users, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &users)
	} else {
		// on unexpected or non-200 responses return the raw body for inspection
		body, _ = getResponseBody(resp)
	}
	return users, body, err
}

// AddGroup adds a new group and checks the received HTTP Status code against expectedStatusCode.
func AddGroup(group dataprovider.Group, expectedStatusCode int) (dataprovider.Group, []byte, error) {
	var newGroup dataprovider.Group
	var body []byte
	asJSON, _ := json.Marshal(group)
	resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(groupPath), bytes.NewBuffer(asJSON),
		"application/json", getDefaultToken())
	if err != nil {
		return newGroup, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusCreated {
		body, _ = getResponseBody(resp)
		return newGroup, body, err
	}
	if err == nil {
		err = render.DecodeJSON(resp.Body, &newGroup)
	} else {
		body, _ = getResponseBody(resp)
	}
	if err == nil {
		// TLS certs are cleared before the comparison — presumably they are
		// not echoed back by the API; confirm against checkGroup
		group.UserSettings.Filters.TLSCerts = nil
		err = checkGroup(group, newGroup)
	}
	return newGroup, body, err
}

// UpdateGroup updates an existing group and checks the received HTTP Status code against expectedStatusCode
func UpdateGroup(group dataprovider.Group, expectedStatusCode int) (dataprovider.Group, []byte, error) {
	var newGroup dataprovider.Group
	var body []byte

	asJSON, _ := json.Marshal(group)
	resp, err := sendHTTPRequest(http.MethodPut,
		buildURLRelativeToBase(groupPath, url.PathEscape(group.Name)),
		bytes.NewBuffer(asJSON), "application/json", getDefaultToken())
	if err != nil {
		return newGroup, body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusOK {
		return newGroup, body, err
	}
	if err == nil {
		// re-fetch the group to validate the stored state after the update
		newGroup, body, err = GetGroupByName(group.Name, expectedStatusCode)
	}
	if err == nil {
		err = checkGroup(group, newGroup)
	}
	return newGroup, body, err
}

// RemoveGroup removes an existing group and checks the received HTTP Status code against expectedStatusCode.
func RemoveGroup(group dataprovider.Group, expectedStatusCode int) ([]byte, error) {
	var body []byte
	resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(groupPath, url.PathEscape(group.Name)),
		nil, "", getDefaultToken())
	if err != nil {
		return body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	return body, checkResponse(resp.StatusCode, expectedStatusCode)
}

// GetGroupByName gets a group by name and checks the received HTTP Status code against expectedStatusCode.
func GetGroupByName(name string, expectedStatusCode int) (dataprovider.Group, []byte, error) {
	var group dataprovider.Group
	var body []byte
	resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(groupPath, url.PathEscape(name)),
		nil, "", getDefaultToken())
	if err != nil {
		return group, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &group)
	} else {
		body, _ = getResponseBody(resp)
	}
	return group, body, err
}

// GetGroups returns a list of groups and checks the received HTTP Status code against expectedStatusCode.
// The number of results can be limited specifying a limit.
// Some results can be skipped specifying an offset.
func GetGroups(limit, offset int64, expectedStatusCode int) ([]dataprovider.Group, []byte, error) {
	var groups []dataprovider.Group
	var body []byte
	// the url variable shadows the net/url package for the rest of the function
	url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(groupPath), limit, offset)
	if err != nil {
		return groups, body, err
	}
	resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken())
	if err != nil {
		return groups, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &groups)
	} else {
		body, _ = getResponseBody(resp)
	}
	return groups, body, err
}

// AddRole adds a new role and checks the received HTTP Status code against expectedStatusCode.
func AddRole(role dataprovider.Role, expectedStatusCode int) (dataprovider.Role, []byte, error) {
	var newRole dataprovider.Role
	var body []byte
	asJSON, _ := json.Marshal(role)
	resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(rolesPath), bytes.NewBuffer(asJSON),
		"application/json", getDefaultToken())
	if err != nil {
		return newRole, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	// a non-201 expectation short-circuits with the raw response body
	if expectedStatusCode != http.StatusCreated {
		body, _ = getResponseBody(resp)
		return newRole, body, err
	}
	if err == nil {
		err = render.DecodeJSON(resp.Body, &newRole)
	} else {
		body, _ = getResponseBody(resp)
	}
	if err == nil {
		err = checkRole(role, newRole)
	}
	return newRole, body, err
}

// UpdateRole updates an existing role and checks the received HTTP Status code against expectedStatusCode
func UpdateRole(role dataprovider.Role, expectedStatusCode int) (dataprovider.Role, []byte, error) {
	var newRole dataprovider.Role
	var body []byte

	asJSON, _ := json.Marshal(role)
	resp, err := sendHTTPRequest(http.MethodPut,
		buildURLRelativeToBase(rolesPath, url.PathEscape(role.Name)),
		bytes.NewBuffer(asJSON), "application/json", getDefaultToken())
	if err != nil {
		return newRole, body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusOK {
		return newRole, body, err
	}
	if err == nil {
		// re-fetch the role to validate the stored state after the update
		newRole, body, err = GetRoleByName(role.Name, expectedStatusCode)
	}
	if err == nil {
		err = checkRole(role, newRole)
	}
	return newRole, body, err
}

// RemoveRole removes an existing role and checks the received HTTP Status code against expectedStatusCode.
func RemoveRole(role dataprovider.Role, expectedStatusCode int) ([]byte, error) {
	var body []byte
	resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(rolesPath, url.PathEscape(role.Name)),
		nil, "", getDefaultToken())
	if err != nil {
		return body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	return body, checkResponse(resp.StatusCode, expectedStatusCode)
}

// GetRoleByName gets a role by name and checks the received HTTP Status code against expectedStatusCode.
func GetRoleByName(name string, expectedStatusCode int) (dataprovider.Role, []byte, error) {
	var role dataprovider.Role
	var body []byte
	resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(rolesPath, url.PathEscape(name)),
		nil, "", getDefaultToken())
	if err != nil {
		return role, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &role)
	} else {
		body, _ = getResponseBody(resp)
	}
	return role, body, err
}

// GetRoles returns a list of roles and checks the received HTTP Status code against expectedStatusCode.
// The number of results can be limited specifying a limit.
// Some results can be skipped specifying an offset.
func GetRoles(limit, offset int64, expectedStatusCode int) ([]dataprovider.Role, []byte, error) {
	var roles []dataprovider.Role
	var body []byte
	// the url variable shadows the net/url package for the rest of the function
	url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(rolesPath), limit, offset)
	if err != nil {
		return roles, body, err
	}
	resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken())
	if err != nil {
		return roles, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &roles)
	} else {
		body, _ = getResponseBody(resp)
	}
	return roles, body, err
}

// AddIPListEntry adds a new IP list entry and checks the received HTTP Status code against expectedStatusCode.
func AddIPListEntry(entry dataprovider.IPListEntry, expectedStatusCode int) (dataprovider.IPListEntry, []byte, error) {
	var newEntry dataprovider.IPListEntry
	var body []byte

	asJSON, _ := json.Marshal(entry)
	// the list type is part of the URL path
	resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(ipListsPath, strconv.Itoa(int(entry.Type))),
		bytes.NewBuffer(asJSON), "application/json", getDefaultToken())
	if err != nil {
		return newEntry, body, err
	}
	defer resp.Body.Close()

	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusCreated {
		body, _ = getResponseBody(resp)
		return newEntry, body, err
	}
	if err == nil {
		// the creation response body is not decoded: the entry is re-fetched instead
		newEntry, body, err = GetIPListEntry(entry.IPOrNet, entry.Type, http.StatusOK)
	}
	if err == nil {
		err = checkIPListEntry(entry, newEntry)
	}
	return newEntry, body, err
}

// UpdateIPListEntry updates an existing IP list entry and checks the received HTTP Status code against expectedStatusCode
func UpdateIPListEntry(entry dataprovider.IPListEntry, expectedStatusCode int) (dataprovider.IPListEntry, []byte, error) {
	var newEntry dataprovider.IPListEntry
	var body []byte

	asJSON, _ := json.Marshal(entry)
	resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(ipListsPath, fmt.Sprintf("%d", entry.Type),
		url.PathEscape(entry.IPOrNet)), bytes.NewBuffer(asJSON),
		"application/json", getDefaultToken())
	if err != nil {
		return newEntry, body, err
	}
	defer resp.Body.Close()

	body, _ = getResponseBody(resp)
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusOK {
		return newEntry, body, err
	}
	if err == nil {
		// re-fetch the entry to validate the stored state after the update
		newEntry, body, err = GetIPListEntry(entry.IPOrNet, entry.Type, http.StatusOK)
	}
	if err == nil {
		err = checkIPListEntry(entry, newEntry)
	}
	return newEntry, body, err
}

// RemoveIPListEntry removes an existing IP list entry and checks the received HTTP Status code against expectedStatusCode.
func RemoveIPListEntry(entry dataprovider.IPListEntry, expectedStatusCode int) ([]byte, error) {
	var body []byte
	resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(ipListsPath, fmt.Sprintf("%d", entry.Type),
		url.PathEscape(entry.IPOrNet)), nil, "", getDefaultToken())
	if err != nil {
		return body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	return body, checkResponse(resp.StatusCode, expectedStatusCode)
}

// GetIPListEntry returns an IP list entry matching the specified parameters, if exists,
// and checks the received HTTP Status code against expectedStatusCode.
+func GetIPListEntry(ipOrNet string, listType dataprovider.IPListType, expectedStatusCode int, +) (dataprovider.IPListEntry, []byte, error) { + var entry dataprovider.IPListEntry + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(ipListsPath, fmt.Sprintf("%d", listType), url.PathEscape(ipOrNet)), + nil, "", getDefaultToken()) + if err != nil { + return entry, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &entry) + } else { + body, _ = getResponseBody(resp) + } + return entry, body, err +} + +// GetIPListEntries returns a list of IP list entries and checks the received HTTP Status code against expectedStatusCode. +func GetIPListEntries(listType dataprovider.IPListType, filter, from, order string, limit int64, + expectedStatusCode int, +) ([]dataprovider.IPListEntry, []byte, error) { + var entries []dataprovider.IPListEntry + var body []byte + + url, err := url.Parse(buildURLRelativeToBase(ipListsPath, strconv.Itoa(int(listType)))) + if err != nil { + return entries, body, err + } + q := url.Query() + q.Add("filter", filter) + q.Add("from", from) + q.Add("order", order) + if limit > 0 { + q.Add("limit", strconv.FormatInt(limit, 10)) + } + url.RawQuery = q.Encode() + resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) + if err != nil { + return entries, body, err + } + defer resp.Body.Close() + + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &entries) + } else { + body, _ = getResponseBody(resp) + } + return entries, body, err +} + +// AddAdmin adds a new admin and checks the received HTTP Status code against expectedStatusCode. 
func AddAdmin(admin dataprovider.Admin, expectedStatusCode int) (dataprovider.Admin, []byte, error) {
	var newAdmin dataprovider.Admin
	var body []byte
	asJSON, _ := json.Marshal(admin)
	resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(adminPath), bytes.NewBuffer(asJSON),
		"application/json", getDefaultToken())
	if err != nil {
		return newAdmin, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	// a non-201 expectation short-circuits with the raw response body
	if expectedStatusCode != http.StatusCreated {
		body, _ = getResponseBody(resp)
		return newAdmin, body, err
	}
	if err == nil {
		err = render.DecodeJSON(resp.Body, &newAdmin)
	} else {
		body, _ = getResponseBody(resp)
	}
	if err == nil {
		// verify that the admin returned by the server matches the requested one
		err = checkAdmin(&admin, &newAdmin)
	}
	return newAdmin, body, err
}

// UpdateAdmin updates an existing admin and checks the received HTTP Status code against expectedStatusCode
func UpdateAdmin(admin dataprovider.Admin, expectedStatusCode int) (dataprovider.Admin, []byte, error) {
	var newAdmin dataprovider.Admin
	var body []byte

	asJSON, _ := json.Marshal(admin)
	resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(adminPath, url.PathEscape(admin.Username)),
		bytes.NewBuffer(asJSON), "application/json", getDefaultToken())
	if err != nil {
		return newAdmin, body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusOK {
		return newAdmin, body, err
	}
	if err == nil {
		// re-fetch the admin to validate the stored state after the update
		newAdmin, body, err = GetAdminByUsername(admin.Username, expectedStatusCode)
	}
	if err == nil {
		err = checkAdmin(&admin, &newAdmin)
	}
	return newAdmin, body, err
}

// RemoveAdmin removes an existing admin and checks the received HTTP Status code against expectedStatusCode.
+func RemoveAdmin(admin dataprovider.Admin, expectedStatusCode int) ([]byte, error) { + var body []byte + resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(adminPath, url.PathEscape(admin.Username)), + nil, "", getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + +// GetAdminByUsername gets an admin by username and checks the received HTTP Status code against expectedStatusCode. +func GetAdminByUsername(username string, expectedStatusCode int) (dataprovider.Admin, []byte, error) { + var admin dataprovider.Admin + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(adminPath, url.PathEscape(username)), + nil, "", getDefaultToken()) + if err != nil { + return admin, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &admin) + } else { + body, _ = getResponseBody(resp) + } + return admin, body, err +} + +// GetAdmins returns a list of admins and checks the received HTTP Status code against expectedStatusCode. +// The number of results can be limited specifying a limit. +// Some results can be skipped specifying an offset. 
+func GetAdmins(limit, offset int64, expectedStatusCode int) ([]dataprovider.Admin, []byte, error) { + var admins []dataprovider.Admin + var body []byte + url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(adminPath), limit, offset) + if err != nil { + return admins, body, err + } + resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) + if err != nil { + return admins, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &admins) + } else { + body, _ = getResponseBody(resp) + } + return admins, body, err +} + +// ChangeAdminPassword changes the password for an existing admin +func ChangeAdminPassword(currentPassword, newPassword string, expectedStatusCode int) ([]byte, error) { + var body []byte + + pwdChange := make(map[string]string) + pwdChange["current_password"] = currentPassword + pwdChange["new_password"] = newPassword + + asJSON, _ := json.Marshal(&pwdChange) + resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(adminPwdPath), + bytes.NewBuffer(asJSON), "application/json", getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + + err = checkResponse(resp.StatusCode, expectedStatusCode) + body, _ = getResponseBody(resp) + + return body, err +} + +// GetAPIKeys returns a list of API keys and checks the received HTTP Status code against expectedStatusCode. +// The number of results can be limited specifying a limit. +// Some results can be skipped specifying an offset. 
func GetAPIKeys(limit, offset int64, expectedStatusCode int) ([]dataprovider.APIKey, []byte, error) {
	var apiKeys []dataprovider.APIKey
	var body []byte
	// the url variable shadows the net/url package for the rest of the function
	url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(apiKeysPath), limit, offset)
	if err != nil {
		return apiKeys, body, err
	}
	resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken())
	if err != nil {
		return apiKeys, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &apiKeys)
	} else {
		body, _ = getResponseBody(resp)
	}
	return apiKeys, body, err
}

// AddAPIKey adds a new API key and checks the received HTTP Status code against expectedStatusCode.
func AddAPIKey(apiKey dataprovider.APIKey, expectedStatusCode int) (dataprovider.APIKey, []byte, error) {
	var newAPIKey dataprovider.APIKey
	var body []byte
	asJSON, _ := json.Marshal(apiKey)
	resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(apiKeysPath), bytes.NewBuffer(asJSON),
		"application/json", getDefaultToken())
	if err != nil {
		return newAPIKey, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusCreated {
		body, _ = getResponseBody(resp)
		return newAPIKey, body, err
	}
	if err != nil {
		body, _ = getResponseBody(resp)
		return newAPIKey, body, err
	}
	// the creation response carries the generated key; the stored fields are
	// fetched back using the object ID from the X-Object-ID response header
	response := make(map[string]string)
	err = render.DecodeJSON(resp.Body, &response)
	if err == nil {
		newAPIKey, body, err = GetAPIKeyByID(resp.Header.Get("X-Object-ID"), http.StatusOK)
	}
	if err == nil {
		err = checkAPIKey(&apiKey, &newAPIKey)
	}
	// copy the key from the creation response into the fetched object
	newAPIKey.Key = response["key"]

	return newAPIKey, body, err
}

// UpdateAPIKey updates an existing API key and checks the received HTTP Status code against expectedStatusCode
func UpdateAPIKey(apiKey dataprovider.APIKey, expectedStatusCode int) (dataprovider.APIKey, []byte, error) {
	var newAPIKey dataprovider.APIKey
	var body []byte

	asJSON, _ := json.Marshal(apiKey)
	resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(apiKeysPath, url.PathEscape(apiKey.KeyID)),
		bytes.NewBuffer(asJSON), "application/json", getDefaultToken())
	if err != nil {
		return newAPIKey, body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusOK {
		return newAPIKey, body, err
	}
	if err == nil {
		// re-fetch the key to validate the stored state after the update
		newAPIKey, body, err = GetAPIKeyByID(apiKey.KeyID, expectedStatusCode)
	}
	if err == nil {
		err = checkAPIKey(&apiKey, &newAPIKey)
	}
	return newAPIKey, body, err
}

// RemoveAPIKey removes an existing API key and checks the received HTTP Status code against expectedStatusCode.
func RemoveAPIKey(apiKey dataprovider.APIKey, expectedStatusCode int) ([]byte, error) {
	var body []byte
	resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(apiKeysPath, url.PathEscape(apiKey.KeyID)),
		nil, "", getDefaultToken())
	if err != nil {
		return body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	return body, checkResponse(resp.StatusCode, expectedStatusCode)
}

// GetAPIKeyByID gets an API key by ID and checks the received HTTP Status code against expectedStatusCode.
func GetAPIKeyByID(keyID string, expectedStatusCode int) (dataprovider.APIKey, []byte, error) {
	var apiKey dataprovider.APIKey
	var body []byte
	resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(apiKeysPath, url.PathEscape(keyID)),
		nil, "", getDefaultToken())
	if err != nil {
		return apiKey, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &apiKey)
	} else {
		// on unexpected or non-200 responses return the raw body for inspection
		body, _ = getResponseBody(resp)
	}
	return apiKey, body, err
}

// AddEventAction adds a new event action
func AddEventAction(action dataprovider.BaseEventAction, expectedStatusCode int) (dataprovider.BaseEventAction, []byte, error) {
	var newAction dataprovider.BaseEventAction
	var body []byte
	asJSON, _ := json.Marshal(action)
	resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(eventActionsPath), bytes.NewBuffer(asJSON),
		"application/json", getDefaultToken())
	if err != nil {
		return newAction, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	// a non-201 expectation short-circuits with the raw response body
	if expectedStatusCode != http.StatusCreated {
		body, _ = getResponseBody(resp)
		return newAction, body, err
	}
	if err == nil {
		err = render.DecodeJSON(resp.Body, &newAction)
	} else {
		body, _ = getResponseBody(resp)
	}
	if err == nil {
		err = checkEventAction(action, newAction)
	}
	return newAction, body, err
}

// UpdateEventAction updates an existing event action
func UpdateEventAction(action dataprovider.BaseEventAction, expectedStatusCode int) (dataprovider.BaseEventAction, []byte, error) {
	var newAction dataprovider.BaseEventAction
	var body []byte

	asJSON, _ := json.Marshal(action)
	resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(eventActionsPath, url.PathEscape(action.Name)),
		bytes.NewBuffer(asJSON), "application/json", getDefaultToken())
	if err != nil {
		return newAction, body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusOK {
		return newAction, body, err
	}
	if err == nil {
		// re-fetch the action to validate the stored state after the update
		newAction, body, err = GetEventActionByName(action.Name, expectedStatusCode)
	}
	if err == nil {
		err = checkEventAction(action, newAction)
	}
	return newAction, body, err
}

// RemoveEventAction removes an existing action and checks the received HTTP Status code against expectedStatusCode.
func RemoveEventAction(action dataprovider.BaseEventAction, expectedStatusCode int) ([]byte, error) {
	var body []byte
	resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(eventActionsPath, url.PathEscape(action.Name)),
		nil, "", getDefaultToken())
	if err != nil {
		return body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	return body, checkResponse(resp.StatusCode, expectedStatusCode)
}

// GetEventActionByName gets an event action by name and checks the received HTTP Status code against expectedStatusCode.
func GetEventActionByName(name string, expectedStatusCode int) (dataprovider.BaseEventAction, []byte, error) {
	var action dataprovider.BaseEventAction
	var body []byte
	resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(eventActionsPath, url.PathEscape(name)),
		nil, "", getDefaultToken())
	if err != nil {
		return action, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &action)
	} else {
		body, _ = getResponseBody(resp)
	}
	return action, body, err
}

// GetEventActions returns a list of event actions and checks the received HTTP Status code against expectedStatusCode.
// The number of results can be limited specifying a limit.
// Some results can be skipped specifying an offset.
func GetEventActions(limit, offset int64, expectedStatusCode int) ([]dataprovider.BaseEventAction, []byte, error) {
	var actions []dataprovider.BaseEventAction
	var body []byte
	// the url variable shadows the net/url package for the rest of the function
	url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(eventActionsPath), limit, offset)
	if err != nil {
		return actions, body, err
	}
	resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken())
	if err != nil {
		return actions, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &actions)
	} else {
		body, _ = getResponseBody(resp)
	}
	return actions, body, err
}

// AddEventRule adds a new event rule
func AddEventRule(rule dataprovider.EventRule, expectedStatusCode int) (dataprovider.EventRule, []byte, error) {
	var newRule dataprovider.EventRule
	var body []byte
	asJSON, _ := json.Marshal(rule)
	resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(eventRulesPath), bytes.NewBuffer(asJSON),
		"application/json", getDefaultToken())
	if err != nil {
		return newRule, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	// a non-201 expectation short-circuits with the raw response body
	if expectedStatusCode != http.StatusCreated {
		body, _ = getResponseBody(resp)
		return newRule, body, err
	}
	if err == nil {
		err = render.DecodeJSON(resp.Body, &newRule)
	} else {
		body, _ = getResponseBody(resp)
	}
	if err == nil {
		err = checkEventRule(rule, newRule)
	}
	return newRule, body, err
}

// UpdateEventRule updates an existing event rule
func UpdateEventRule(rule dataprovider.EventRule, expectedStatusCode int) (dataprovider.EventRule, []byte, error) {
	var newRule dataprovider.EventRule
	var body []byte

	asJSON, _ := json.Marshal(rule)
	resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(eventRulesPath, url.PathEscape(rule.Name)),
		bytes.NewBuffer(asJSON), "application/json", getDefaultToken())
	if err != nil {
		return newRule, body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if expectedStatusCode != http.StatusOK {
		return newRule, body, err
	}
	if err == nil {
		// re-fetch the rule to validate the stored state after the update
		newRule, body, err = GetEventRuleByName(rule.Name, expectedStatusCode)
	}
	if err == nil {
		err = checkEventRule(rule, newRule)
	}
	return newRule, body, err
}

// RemoveEventRule removes an existing rule and checks the received HTTP Status code against expectedStatusCode.
func RemoveEventRule(rule dataprovider.EventRule, expectedStatusCode int) ([]byte, error) {
	var body []byte
	resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(eventRulesPath, url.PathEscape(rule.Name)),
		nil, "", getDefaultToken())
	if err != nil {
		return body, err
	}
	defer resp.Body.Close()
	body, _ = getResponseBody(resp)
	return body, checkResponse(resp.StatusCode, expectedStatusCode)
}

// GetEventRuleByName gets an event rule by name and checks the received HTTP Status code against expectedStatusCode.
func GetEventRuleByName(name string, expectedStatusCode int) (dataprovider.EventRule, []byte, error) {
	var rule dataprovider.EventRule
	var body []byte
	resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(eventRulesPath, url.PathEscape(name)),
		nil, "", getDefaultToken())
	if err != nil {
		return rule, body, err
	}
	defer resp.Body.Close()
	err = checkResponse(resp.StatusCode, expectedStatusCode)
	if err == nil && expectedStatusCode == http.StatusOK {
		err = render.DecodeJSON(resp.Body, &rule)
	} else {
		body, _ = getResponseBody(resp)
	}
	return rule, body, err
}

// GetEventRules returns a list of event rules and checks the received HTTP Status code against expectedStatusCode.
// The number of results can be limited specifying a limit.
// Some results can be skipped specifying an offset.
+func GetEventRules(limit, offset int64, expectedStatusCode int) ([]dataprovider.EventRule, []byte, error) { + var rules []dataprovider.EventRule + var body []byte + url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(eventRulesPath), limit, offset) + if err != nil { + return rules, body, err + } + resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) + if err != nil { + return rules, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &rules) + } else { + body, _ = getResponseBody(resp) + } + return rules, body, err +} + +// RunOnDemandRule executes the specified on demand rule +func RunOnDemandRule(name string, expectedStatusCode int) ([]byte, error) { + resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(eventRulesPath, "run", url.PathEscape(name)), + nil, "application/json", getDefaultToken()) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + b, err := getResponseBody(resp) + if err != nil { + return b, err + } + if err := checkResponse(resp.StatusCode, expectedStatusCode); err != nil { + return b, err + } + return b, nil +} + +// GetQuotaScans gets active quota scans for users and checks the received HTTP Status code against expectedStatusCode. 
+func GetQuotaScans(expectedStatusCode int) ([]common.ActiveQuotaScan, []byte, error) { + var quotaScans []common.ActiveQuotaScan + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanPath), nil, "", getDefaultToken()) + if err != nil { + return quotaScans, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, "aScans) + } else { + body, _ = getResponseBody(resp) + } + return quotaScans, body, err +} + +// StartQuotaScan starts a new quota scan for the given user and checks the received HTTP Status code against expectedStatusCode. +func StartQuotaScan(user dataprovider.User, expectedStatusCode int) ([]byte, error) { + var body []byte + resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(quotasBasePath, "users", user.Username, "scan"), + nil, "", getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + +// UpdateQuotaUsage updates the user used quota limits and checks the received +// HTTP Status code against expectedStatusCode. 
+func UpdateQuotaUsage(user dataprovider.User, mode string, expectedStatusCode int) ([]byte, error) { + var body []byte + userAsJSON, _ := json.Marshal(user) + url, err := addModeQueryParam(buildURLRelativeToBase(quotasBasePath, "users", user.Username, "usage"), mode) + if err != nil { + return body, err + } + resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json", + getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + +// UpdateTransferQuotaUsage updates the user used transfer quota limits and checks the received +// HTTP Status code against expectedStatusCode. +func UpdateTransferQuotaUsage(user dataprovider.User, mode string, expectedStatusCode int) ([]byte, error) { + var body []byte + userAsJSON, _ := json.Marshal(user) + url, err := addModeQueryParam(buildURLRelativeToBase(quotasBasePath, "users", user.Username, "transfer-usage"), mode) + if err != nil { + return body, err + } + resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(userAsJSON), "application/json", + getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + +// GetRetentionChecks returns the active retention checks +func GetRetentionChecks(expectedStatusCode int) ([]common.ActiveRetentionChecks, []byte, error) { + var checks []common.ActiveRetentionChecks + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(retentionChecksPath), nil, "", getDefaultToken()) + if err != nil { + return checks, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &checks) + } else { + body, _ = 
getResponseBody(resp) + } + return checks, body, err +} + +// GetConnections returns status and stats for active SFTP/SCP connections +func GetConnections(expectedStatusCode int) ([]common.ConnectionStatus, []byte, error) { + var connections []common.ConnectionStatus + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(activeConnectionsPath), nil, "", getDefaultToken()) + if err != nil { + return connections, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &connections) + } else { + body, _ = getResponseBody(resp) + } + return connections, body, err +} + +// CloseConnection closes an active connection identified by connectionID +func CloseConnection(connectionID string, expectedStatusCode int) ([]byte, error) { + var body []byte + resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(activeConnectionsPath, connectionID), + nil, "", getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + body, _ = getResponseBody(resp) + return body, err +} + +// AddFolder adds a new folder and checks the received HTTP Status code against expectedStatusCode +func AddFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) { + var newFolder vfs.BaseVirtualFolder + var body []byte + folderAsJSON, _ := json.Marshal(folder) + resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(folderPath), bytes.NewBuffer(folderAsJSON), + "application/json", getDefaultToken()) + if err != nil { + return newFolder, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if expectedStatusCode != http.StatusCreated { + body, _ = getResponseBody(resp) + return newFolder, body, err + } + if err == nil { + err = 
render.DecodeJSON(resp.Body, &newFolder) + } else { + body, _ = getResponseBody(resp) + } + if err == nil { + err = checkFolder(&folder, &newFolder) + } + return newFolder, body, err +} + +// UpdateFolder updates an existing folder and checks the received HTTP Status code against expectedStatusCode. +func UpdateFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) { + var updatedFolder vfs.BaseVirtualFolder + var body []byte + + folderAsJSON, _ := json.Marshal(folder) + resp, err := sendHTTPRequest(http.MethodPut, buildURLRelativeToBase(folderPath, url.PathEscape(folder.Name)), + bytes.NewBuffer(folderAsJSON), "application/json", getDefaultToken()) + if err != nil { + return updatedFolder, body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + + err = checkResponse(resp.StatusCode, expectedStatusCode) + if expectedStatusCode != http.StatusOK { + return updatedFolder, body, err + } + if err == nil { + updatedFolder, body, err = GetFolderByName(folder.Name, expectedStatusCode) + } + if err == nil { + err = checkFolder(&folder, &updatedFolder) + } + return updatedFolder, body, err +} + +// RemoveFolder removes an existing user and checks the received HTTP Status code against expectedStatusCode. +func RemoveFolder(folder vfs.BaseVirtualFolder, expectedStatusCode int) ([]byte, error) { + var body []byte + resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(folderPath, url.PathEscape(folder.Name)), + nil, "", getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + +// GetFolderByName gets a folder by name and checks the received HTTP Status code against expectedStatusCode. 
+func GetFolderByName(name string, expectedStatusCode int) (vfs.BaseVirtualFolder, []byte, error) { + var folder vfs.BaseVirtualFolder + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(folderPath, url.PathEscape(name)), + nil, "", getDefaultToken()) + if err != nil { + return folder, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &folder) + } else { + body, _ = getResponseBody(resp) + } + return folder, body, err +} + +// GetFolders returns a list of folders and checks the received HTTP Status code against expectedStatusCode. +// The number of results can be limited specifying a limit. +// Some results can be skipped specifying an offset. +// The results can be filtered specifying a folder path, the folder path filter is an exact match +func GetFolders(limit int64, offset int64, expectedStatusCode int) ([]vfs.BaseVirtualFolder, []byte, error) { + var folders []vfs.BaseVirtualFolder + var body []byte + url, err := addLimitAndOffsetQueryParams(buildURLRelativeToBase(folderPath), limit, offset) + if err != nil { + return folders, body, err + } + resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) + if err != nil { + return folders, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &folders) + } else { + body, _ = getResponseBody(resp) + } + return folders, body, err +} + +// GetFoldersQuotaScans gets active quota scans for folders and checks the received HTTP Status code against expectedStatusCode. 
+func GetFoldersQuotaScans(expectedStatusCode int) ([]common.ActiveVirtualFolderQuotaScan, []byte, error) { + var quotaScans []common.ActiveVirtualFolderQuotaScan + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(quotaScanVFolderPath), nil, "", getDefaultToken()) + if err != nil { + return quotaScans, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, "aScans) + } else { + body, _ = getResponseBody(resp) + } + return quotaScans, body, err +} + +// StartFolderQuotaScan start a new quota scan for the given folder and checks the received HTTP Status code against expectedStatusCode. +func StartFolderQuotaScan(folder vfs.BaseVirtualFolder, expectedStatusCode int) ([]byte, error) { + var body []byte + resp, err := sendHTTPRequest(http.MethodPost, buildURLRelativeToBase(quotasBasePath, "folders", folder.Name, "scan"), + nil, "", getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + +// UpdateFolderQuotaUsage updates the folder used quota limits and checks the received HTTP Status code against expectedStatusCode. 
+func UpdateFolderQuotaUsage(folder vfs.BaseVirtualFolder, mode string, expectedStatusCode int) ([]byte, error) { + var body []byte + folderAsJSON, _ := json.Marshal(folder) + url, err := addModeQueryParam(buildURLRelativeToBase(quotasBasePath, "folders", folder.Name, "usage"), mode) + if err != nil { + return body, err + } + resp, err := sendHTTPRequest(http.MethodPut, url.String(), bytes.NewBuffer(folderAsJSON), "", getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + +// GetVersion returns version details +func GetVersion(expectedStatusCode int) (version.Info, []byte, error) { + var appVersion version.Info + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(versionPath), nil, "", getDefaultToken()) + if err != nil { + return appVersion, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &appVersion) + } else { + body, _ = getResponseBody(resp) + } + return appVersion, body, err +} + +// GetStatus returns the server status +func GetStatus(expectedStatusCode int) (httpd.ServicesStatus, []byte, error) { + var response httpd.ServicesStatus + var body []byte + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(serverStatusPath), nil, "", getDefaultToken()) + if err != nil { + return response, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && (expectedStatusCode == http.StatusOK) { + err = render.DecodeJSON(resp.Body, &response) + } else { + body, _ = getResponseBody(resp) + } + return response, body, err +} + +// GetDefenderHosts returns hosts that are banned or for which some violations have been detected +func GetDefenderHosts(expectedStatusCode int) 
([]dataprovider.DefenderEntry, []byte, error) { + var response []dataprovider.DefenderEntry + var body []byte + url, err := url.Parse(buildURLRelativeToBase(defenderHosts)) + if err != nil { + return response, body, err + } + resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) + if err != nil { + return response, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &response) + } else { + body, _ = getResponseBody(resp) + } + return response, body, err +} + +// GetDefenderHostByIP returns the host with the given IP, if it exists +func GetDefenderHostByIP(ip string, expectedStatusCode int) (dataprovider.DefenderEntry, []byte, error) { + var host dataprovider.DefenderEntry + var body []byte + id := hex.EncodeToString([]byte(ip)) + resp, err := sendHTTPRequest(http.MethodGet, buildURLRelativeToBase(defenderHosts, id), + nil, "", getDefaultToken()) + if err != nil { + return host, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &host) + } else { + body, _ = getResponseBody(resp) + } + return host, body, err +} + +// RemoveDefenderHostByIP removes the host with the given IP from the defender list +func RemoveDefenderHostByIP(ip string, expectedStatusCode int) ([]byte, error) { + var body []byte + id := hex.EncodeToString([]byte(ip)) + resp, err := sendHTTPRequest(http.MethodDelete, buildURLRelativeToBase(defenderHosts, id), nil, "", getDefaultToken()) + if err != nil { + return body, err + } + defer resp.Body.Close() + body, _ = getResponseBody(resp) + return body, checkResponse(resp.StatusCode, expectedStatusCode) +} + +// Dumpdata requests a backup to outputFile. 
+// outputFile is relative to the configured backups_path +func Dumpdata(outputFile, outputData, indent string, expectedStatusCode int, scopes ...string) (map[string]any, []byte, error) { + var response map[string]any + var body []byte + url, err := url.Parse(buildURLRelativeToBase(dumpDataPath)) + if err != nil { + return response, body, err + } + q := url.Query() + if outputData != "" { + q.Add("output-data", outputData) + } + if outputFile != "" { + q.Add("output-file", outputFile) + } + if indent != "" { + q.Add("indent", indent) + } + if len(scopes) > 0 { + q.Add("scopes", strings.Join(scopes, ",")) + } + url.RawQuery = q.Encode() + resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) + if err != nil { + return response, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &response) + } else { + body, _ = getResponseBody(resp) + } + return response, body, err +} + +// Loaddata restores a backup. 
+func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[string]any, []byte, error) { + var response map[string]any + var body []byte + url, err := url.Parse(buildURLRelativeToBase(loadDataPath)) + if err != nil { + return response, body, err + } + q := url.Query() + q.Add("input-file", inputFile) + if scanQuota != "" { + q.Add("scan-quota", scanQuota) + } + if mode != "" { + q.Add("mode", mode) + } + url.RawQuery = q.Encode() + resp, err := sendHTTPRequest(http.MethodGet, url.String(), nil, "", getDefaultToken()) + if err != nil { + return response, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &response) + } else { + body, _ = getResponseBody(resp) + } + return response, body, err +} + +// LoaddataFromPostBody restores a backup +func LoaddataFromPostBody(data []byte, scanQuota, mode string, expectedStatusCode int) (map[string]any, []byte, error) { + var response map[string]any + var body []byte + url, err := url.Parse(buildURLRelativeToBase(loadDataPath)) + if err != nil { + return response, body, err + } + q := url.Query() + if scanQuota != "" { + q.Add("scan-quota", scanQuota) + } + if mode != "" { + q.Add("mode", mode) + } + url.RawQuery = q.Encode() + resp, err := sendHTTPRequest(http.MethodPost, url.String(), bytes.NewReader(data), "", getDefaultToken()) + if err != nil { + return response, body, err + } + defer resp.Body.Close() + err = checkResponse(resp.StatusCode, expectedStatusCode) + if err == nil && expectedStatusCode == http.StatusOK { + err = render.DecodeJSON(resp.Body, &response) + } else { + body, _ = getResponseBody(resp) + } + return response, body, err +} + +func checkResponse(actual int, expected int) error { + if expected != actual { + return fmt.Errorf("wrong status code: got %v want %v", actual, expected) + } + return nil +} + +func getResponseBody(resp *http.Response) 
([]byte, error) { + return io.ReadAll(resp.Body) +} + +func checkEventAction(expected, actual dataprovider.BaseEventAction) error { + if expected.ID <= 0 { + if actual.ID <= 0 { + return errors.New("actual action ID must be > 0") + } + } else { + if actual.ID != expected.ID { + return errors.New("action ID mismatch") + } + } + if dataprovider.ConvertName(expected.Name) != actual.Name { + return errors.New("name mismatch") + } + if expected.Description != actual.Description { + return errors.New("description mismatch") + } + if expected.Type != actual.Type { + return errors.New("type mismatch") + } + if expected.Options.PwdExpirationConfig.Threshold != actual.Options.PwdExpirationConfig.Threshold { + return errors.New("password expiration threshold mismatch") + } + if expected.Options.UserInactivityConfig.DisableThreshold != actual.Options.UserInactivityConfig.DisableThreshold { + return errors.New("user inactivity disable threshold mismatch") + } + if expected.Options.UserInactivityConfig.DeleteThreshold != actual.Options.UserInactivityConfig.DeleteThreshold { + return errors.New("user inactivity delete threshold mismatch") + } + if err := compareEventActionIDPConfigFields(expected.Options.IDPConfig, actual.Options.IDPConfig); err != nil { + return err + } + if err := compareEventActionCmdConfigFields(expected.Options.CmdConfig, actual.Options.CmdConfig); err != nil { + return err + } + if err := compareEventActionEmailConfigFields(expected.Options.EmailConfig, actual.Options.EmailConfig); err != nil { + return err + } + if err := compareEventActionDataRetentionFields(expected.Options.RetentionConfig, actual.Options.RetentionConfig); err != nil { + return err + } + if err := compareEventActionFsConfigFields(expected.Options.FsConfig, actual.Options.FsConfig); err != nil { + return err + } + return compareEventActionHTTPConfigFields(expected.Options.HTTPConfig, actual.Options.HTTPConfig) +} + +func checkEventSchedules(expected, actual []dataprovider.Schedule) error 
{ + if len(expected) != len(actual) { + return errors.New("schedules mismatch") + } + for _, ex := range expected { + found := false + for _, ac := range actual { + if ac.DayOfMonth == ex.DayOfMonth && ac.DayOfWeek == ex.DayOfWeek && ac.Hours == ex.Hours && ac.Month == ex.Month { + found = true + break + } + } + if !found { + return errors.New("schedules content mismatch") + } + } + return nil +} + +func compareConditionPatternOptions(expected, actual []dataprovider.ConditionPattern) error { + if len(expected) != len(actual) { + return errors.New("condition pattern mismatch") + } + for _, ex := range expected { + found := false + for _, ac := range actual { + if ac.Pattern == ex.Pattern && ac.InverseMatch == ex.InverseMatch { + found = true + break + } + } + if !found { + return errors.New("condition pattern content mismatch") + } + } + return nil +} + +func checkEventConditionOptions(expected, actual dataprovider.ConditionOptions) error { //nolint:gocyclo + if err := compareConditionPatternOptions(expected.Names, actual.Names); err != nil { + return errors.New("condition names mismatch") + } + if err := compareConditionPatternOptions(expected.GroupNames, actual.GroupNames); err != nil { + return errors.New("condition group names mismatch") + } + if err := compareConditionPatternOptions(expected.RoleNames, actual.RoleNames); err != nil { + return errors.New("condition role names mismatch") + } + if err := compareConditionPatternOptions(expected.FsPaths, actual.FsPaths); err != nil { + return errors.New("condition fs_paths mismatch") + } + if len(expected.Protocols) != len(actual.Protocols) { + return errors.New("condition protocols mismatch") + } + for _, v := range expected.Protocols { + if !slices.Contains(actual.Protocols, v) { + return errors.New("condition protocols content mismatch") + } + } + if len(expected.EventStatuses) != len(actual.EventStatuses) { + return errors.New("condition statuses mismatch") + } + for _, v := range expected.EventStatuses { + if 
!slices.Contains(actual.EventStatuses, v) { + return errors.New("condition statuses content mismatch") + } + } + if len(expected.ProviderObjects) != len(actual.ProviderObjects) { + return errors.New("condition provider objects mismatch") + } + for _, v := range expected.ProviderObjects { + if !slices.Contains(actual.ProviderObjects, v) { + return errors.New("condition provider objects content mismatch") + } + } + if expected.MinFileSize != actual.MinFileSize { + return errors.New("condition min file size mismatch") + } + if expected.MaxFileSize != actual.MaxFileSize { + return errors.New("condition max file size mismatch") + } + return nil +} + +func checkEventConditions(expected, actual dataprovider.EventConditions) error { + if len(expected.FsEvents) != len(actual.FsEvents) { + return errors.New("fs events mismatch") + } + for _, v := range expected.FsEvents { + if !slices.Contains(actual.FsEvents, v) { + return errors.New("fs events content mismatch") + } + } + if len(expected.ProviderEvents) != len(actual.ProviderEvents) { + return errors.New("provider events mismatch") + } + for _, v := range expected.ProviderEvents { + if !slices.Contains(actual.ProviderEvents, v) { + return errors.New("provider events content mismatch") + } + } + if err := checkEventConditionOptions(expected.Options, actual.Options); err != nil { + return err + } + if expected.IDPLoginEvent != actual.IDPLoginEvent { + return errors.New("IDP login event mismatch") + } + + return checkEventSchedules(expected.Schedules, actual.Schedules) +} + +func checkEventRuleActions(expected, actual []dataprovider.EventAction) error { + if len(expected) != len(actual) { + return errors.New("actions mismatch") + } + for _, ex := range expected { + found := false + for _, ac := range actual { + if ex.Name == ac.Name && ex.Order == ac.Order && ex.Options.ExecuteSync == ac.Options.ExecuteSync && + ex.Options.IsFailureAction == ac.Options.IsFailureAction && ex.Options.StopOnFailure == ac.Options.StopOnFailure { 
+ found = true + break + } + } + if !found { + return errors.New("actions contents mismatch") + } + } + return nil +} + +func checkEventRule(expected, actual dataprovider.EventRule) error { + if expected.ID <= 0 { + if actual.ID <= 0 { + return errors.New("actual group ID must be > 0") + } + } else { + if actual.ID != expected.ID { + return errors.New("group ID mismatch") + } + } + if dataprovider.ConvertName(expected.Name) != actual.Name { + return errors.New("name mismatch") + } + if expected.Status != actual.Status { + return errors.New("status mismatch") + } + if expected.Description != actual.Description { + return errors.New("description mismatch") + } + if actual.CreatedAt == 0 { + return errors.New("created_at unset") + } + if actual.UpdatedAt == 0 { + return errors.New("updated_at unset") + } + if expected.Trigger != actual.Trigger { + return errors.New("trigger mismatch") + } + if err := checkEventConditions(expected.Conditions, actual.Conditions); err != nil { + return err + } + return checkEventRuleActions(expected.Actions, actual.Actions) +} + +func checkIPListEntry(expected, actual dataprovider.IPListEntry) error { + if expected.IPOrNet != actual.IPOrNet { + return errors.New("ipornet mismatch") + } + if expected.Description != actual.Description { + return errors.New("description mismatch") + } + if expected.Type != actual.Type { + return errors.New("type mismatch") + } + if expected.Mode != actual.Mode { + return errors.New("mode mismatch") + } + if expected.Protocols != actual.Protocols { + return errors.New("protocols mismatch") + } + if actual.CreatedAt == 0 { + return errors.New("created_at unset") + } + if actual.UpdatedAt == 0 { + return errors.New("updated_at unset") + } + return nil +} + +func checkRole(expected, actual dataprovider.Role) error { + if expected.ID <= 0 { + if actual.ID <= 0 { + return errors.New("actual role ID must be > 0") + } + } else { + if actual.ID != expected.ID { + return errors.New("role ID mismatch") + } + } + if 
dataprovider.ConvertName(expected.Name) != actual.Name { + return errors.New("name mismatch") + } + if expected.Description != actual.Description { + return errors.New("description mismatch") + } + if actual.CreatedAt == 0 { + return errors.New("created_at unset") + } + if actual.UpdatedAt == 0 { + return errors.New("updated_at unset") + } + return nil +} + +func checkGroup(expected, actual dataprovider.Group) error { + if expected.ID <= 0 { + if actual.ID <= 0 { + return errors.New("actual group ID must be > 0") + } + } else { + if actual.ID != expected.ID { + return errors.New("group ID mismatch") + } + } + if dataprovider.ConvertName(expected.Name) != actual.Name { + return errors.New("name mismatch") + } + if expected.Description != actual.Description { + return errors.New("description mismatch") + } + if actual.CreatedAt == 0 { + return errors.New("created_at unset") + } + if actual.UpdatedAt == 0 { + return errors.New("updated_at unset") + } + if err := compareEqualGroupSettingsFields(expected.UserSettings.BaseGroupUserSettings, + actual.UserSettings.BaseGroupUserSettings); err != nil { + return err + } + if err := compareVirtualFolders(expected.VirtualFolders, actual.VirtualFolders); err != nil { + return err + } + if err := compareUserFilters(expected.UserSettings.Filters, actual.UserSettings.Filters); err != nil { + return err + } + return compareFsConfig(&expected.UserSettings.FsConfig, &actual.UserSettings.FsConfig) +} + +func checkFolder(expected *vfs.BaseVirtualFolder, actual *vfs.BaseVirtualFolder) error { + if expected.ID <= 0 { + if actual.ID <= 0 { + return errors.New("actual folder ID must be > 0") + } + } else { + if actual.ID != expected.ID { + return errors.New("folder ID mismatch") + } + } + if dataprovider.ConvertName(expected.Name) != actual.Name { + return errors.New("name mismatch") + } + if expected.MappedPath != actual.MappedPath { + return errors.New("mapped path mismatch") + } + if expected.Description != actual.Description { + return 
errors.New("description mismatch") + } + return compareFsConfig(&expected.FsConfig, &actual.FsConfig) +} + +func checkAPIKey(expected, actual *dataprovider.APIKey) error { + if actual.Key != "" { + return errors.New("key must not be visible") + } + if actual.KeyID == "" { + return errors.New("actual key_id cannot be empty") + } + if expected.Name != actual.Name { + return errors.New("name mismatch") + } + if expected.Scope != actual.Scope { + return errors.New("scope mismatch") + } + if actual.CreatedAt == 0 { + return errors.New("created_at cannot be 0") + } + if actual.UpdatedAt == 0 { + return errors.New("updated_at cannot be 0") + } + if expected.ExpiresAt != actual.ExpiresAt { + return errors.New("expires_at mismatch") + } + if expected.Description != actual.Description { + return errors.New("description mismatch") + } + if expected.User != actual.User { + return errors.New("user mismatch") + } + if expected.Admin != actual.Admin { + return errors.New("admin mismatch") + } + + return nil +} + +func checkAdmin(expected, actual *dataprovider.Admin) error { + if actual.Password != "" { + return errors.New("admin password must not be visible") + } + if expected.ID <= 0 { + if actual.ID <= 0 { + return errors.New("actual admin ID must be > 0") + } + } else { + if actual.ID != expected.ID { + return errors.New("admin ID mismatch") + } + } + if expected.CreatedAt > 0 { + if expected.CreatedAt != actual.CreatedAt { + return fmt.Errorf("created_at mismatch %v != %v", expected.CreatedAt, actual.CreatedAt) + } + } + if err := compareAdminEqualFields(expected, actual); err != nil { + return err + } + if len(expected.Permissions) != len(actual.Permissions) { + return errors.New("permissions mismatch") + } + for _, p := range expected.Permissions { + if !slices.Contains(actual.Permissions, p) { + return errors.New("permissions content mismatch") + } + } + if err := compareAdminFilters(expected.Filters, actual.Filters); err != nil { + return err + } + return 
compareAdminGroups(expected, actual) +} + +func compareAdminFilters(expected, actual dataprovider.AdminFilters) error { + if expected.AllowAPIKeyAuth != actual.AllowAPIKeyAuth { + return errors.New("allow_api_key_auth mismatch") + } + if len(expected.AllowList) != len(actual.AllowList) { + return errors.New("allow list mismatch") + } + for _, v := range expected.AllowList { + if !slices.Contains(actual.AllowList, v) { + return errors.New("allow list content mismatch") + } + } + if expected.Preferences.HideUserPageSections != actual.Preferences.HideUserPageSections { + return errors.New("hide user page sections mismatch") + } + if expected.Preferences.DefaultUsersExpiration != actual.Preferences.DefaultUsersExpiration { + return errors.New("default users expiration mismatch") + } + if expected.RequirePasswordChange != actual.RequirePasswordChange { + return errors.New("require password change mismatch") + } + if expected.RequireTwoFactor != actual.RequireTwoFactor { + return errors.New("require two factor mismatch") + } + return nil +} + +func compareAdminEqualFields(expected *dataprovider.Admin, actual *dataprovider.Admin) error { + if dataprovider.ConvertName(expected.Username) != actual.Username { + return errors.New("sername mismatch") + } + if expected.Email != actual.Email { + return errors.New("email mismatch") + } + if expected.Status != actual.Status { + return errors.New("status mismatch") + } + if expected.Description != actual.Description { + return errors.New("description mismatch") + } + if expected.AdditionalInfo != actual.AdditionalInfo { + return errors.New("additional info mismatch") + } + if expected.Role != actual.Role { + return errors.New("role mismatch") + } + return nil +} + +func checkUser(expected *dataprovider.User, actual *dataprovider.User) error { + if actual.Password != "" { + return errors.New("user password must not be visible") + } + if expected.ID <= 0 { + if actual.ID <= 0 { + return errors.New("actual user ID must be > 0") + } + 
} else { + if actual.ID != expected.ID { + return errors.New("user ID mismatch") + } + } + if expected.CreatedAt > 0 { + if expected.CreatedAt != actual.CreatedAt { + return fmt.Errorf("created_at mismatch %v != %v", expected.CreatedAt, actual.CreatedAt) + } + } + + if expected.Email != actual.Email { + return errors.New("email mismatch") + } + if !slices.Equal(expected.Filters.AdditionalEmails, actual.Filters.AdditionalEmails) { + return errors.New("additional emails mismatch") + } + if expected.Filters.RequirePasswordChange != actual.Filters.RequirePasswordChange { + return errors.New("require_password_change mismatch") + } + if err := compareUserPermissions(expected.Permissions, actual.Permissions); err != nil { + return err + } + if err := compareUserFilters(expected.Filters.BaseUserFilters, actual.Filters.BaseUserFilters); err != nil { + return err + } + if err := compareFsConfig(&expected.FsConfig, &actual.FsConfig); err != nil { + return err + } + if err := compareUserGroups(expected, actual); err != nil { + return err + } + if err := compareVirtualFolders(expected.VirtualFolders, actual.VirtualFolders); err != nil { + return err + } + return compareEqualsUserFields(expected, actual) +} + +func compareUserPermissions(expected map[string][]string, actual map[string][]string) error { + if len(expected) != len(actual) { + return errors.New("permissions mismatch") + } + for dir, perms := range expected { + if actualPerms, ok := actual[dir]; ok { + for _, v := range actualPerms { + if !slices.Contains(perms, v) { + return errors.New("permissions contents mismatch") + } + } + } else { + return errors.New("permissions directories mismatch") + } + } + return nil +} + +func compareAdminGroups(expected *dataprovider.Admin, actual *dataprovider.Admin) error { + if len(actual.Groups) != len(expected.Groups) { + return errors.New("groups len mismatch") + } + for _, g := range actual.Groups { + found := false + for _, g1 := range expected.Groups { + if g1.Name == g.Name { 
+ found = true + if g1.Options.AddToUsersAs != g.Options.AddToUsersAs { + return fmt.Errorf("add to users as field mismatch for group %s", g.Name) + } + } + } + if !found { + return errors.New("groups mismatch") + } + } + return nil +} + +func compareUserGroups(expected *dataprovider.User, actual *dataprovider.User) error { + if len(actual.Groups) != len(expected.Groups) { + return errors.New("groups len mismatch") + } + for _, g := range actual.Groups { + found := false + for _, g1 := range expected.Groups { + if g1.Name == g.Name { + found = true + if g1.Type != g.Type { + return fmt.Errorf("type mismatch for group %s", g.Name) + } + } + } + if !found { + return errors.New("groups mismatch") + } + } + return nil +} + +func compareVirtualFolders(expected []vfs.VirtualFolder, actual []vfs.VirtualFolder) error { + if len(actual) != len(expected) { + return errors.New("virtual folders len mismatch") + } + for _, v := range actual { + found := false + for _, v1 := range expected { + if path.Clean(v.VirtualPath) == path.Clean(v1.VirtualPath) { + if dataprovider.ConvertName(v1.Name) != v.Name { + return errors.New("virtual folder name mismatch") + } + if v.QuotaSize != v1.QuotaSize { + return errors.New("vfolder quota size mismatch") + } + if (v.QuotaFiles) != (v1.QuotaFiles) { + return errors.New("vfolder quota files mismatch") + } + found = true + break + } + } + if !found { + return errors.New("virtual folders mismatch") + } + } + return nil +} + +func compareFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.Provider != actual.Provider { + return errors.New("fs provider mismatch") + } + if expected.OSConfig.ReadBufferSize != actual.OSConfig.ReadBufferSize { + return fmt.Errorf("read buffer size mismatch") + } + if expected.OSConfig.WriteBufferSize != actual.OSConfig.WriteBufferSize { + return fmt.Errorf("write buffer size mismatch") + } + if err := compareS3Config(expected, actual); err != nil { + return err + } + if err := 
compareGCSConfig(expected, actual); err != nil { + return err + } + if err := compareAzBlobConfig(expected, actual); err != nil { + return err + } + if err := checkEncryptedSecret(expected.CryptConfig.Passphrase, actual.CryptConfig.Passphrase); err != nil { + return err + } + if expected.CryptConfig.ReadBufferSize != actual.CryptConfig.ReadBufferSize { + return fmt.Errorf("crypt read buffer size mismatch") + } + if expected.CryptConfig.WriteBufferSize != actual.CryptConfig.WriteBufferSize { + return fmt.Errorf("crypt write buffer size mismatch") + } + if err := compareSFTPFsConfig(expected, actual); err != nil { + return err + } + return compareHTTPFsConfig(expected, actual) +} + +func compareS3Config(expected *vfs.Filesystem, actual *vfs.Filesystem) error { //nolint:gocyclo + if expected.S3Config.Bucket != actual.S3Config.Bucket { + return errors.New("fs S3 bucket mismatch") + } + if expected.S3Config.Region != actual.S3Config.Region { + return errors.New("fs S3 region mismatch") + } + if expected.S3Config.AccessKey != actual.S3Config.AccessKey { + return errors.New("fs S3 access key mismatch") + } + if expected.S3Config.RoleARN != actual.S3Config.RoleARN { + return errors.New("fs S3 role ARN mismatch") + } + if err := checkEncryptedSecret(expected.S3Config.AccessSecret, actual.S3Config.AccessSecret); err != nil { + return fmt.Errorf("fs S3 access secret mismatch: %v", err) + } + if err := checkEncryptedSecret(expected.S3Config.SSECustomerKey, actual.S3Config.SSECustomerKey); err != nil { + return fmt.Errorf("fs S3 SSE customer key mismatch: %v", err) + } + if expected.S3Config.Endpoint != actual.S3Config.Endpoint { + return errors.New("fs S3 endpoint mismatch") + } + if expected.S3Config.StorageClass != actual.S3Config.StorageClass { + return errors.New("fs S3 storage class mismatch") + } + if expected.S3Config.ACL != actual.S3Config.ACL { + return errors.New("fs S3 ACL mismatch") + } + if expected.S3Config.UploadPartSize != actual.S3Config.UploadPartSize { + 
return errors.New("fs S3 upload part size mismatch") + } + if expected.S3Config.UploadConcurrency != actual.S3Config.UploadConcurrency { + return errors.New("fs S3 upload concurrency mismatch") + } + if expected.S3Config.DownloadPartSize != actual.S3Config.DownloadPartSize { + return errors.New("fs S3 download part size mismatch") + } + if expected.S3Config.DownloadConcurrency != actual.S3Config.DownloadConcurrency { + return errors.New("fs S3 download concurrency mismatch") + } + if expected.S3Config.ForcePathStyle != actual.S3Config.ForcePathStyle { + return errors.New("fs S3 force path style mismatch") + } + if expected.S3Config.SkipTLSVerify != actual.S3Config.SkipTLSVerify { + return errors.New("fs S3 skip TLS verify mismatch") + } + if expected.S3Config.DownloadPartMaxTime != actual.S3Config.DownloadPartMaxTime { + return errors.New("fs S3 download part max time mismatch") + } + if expected.S3Config.UploadPartMaxTime != actual.S3Config.UploadPartMaxTime { + return errors.New("fs S3 upload part max time mismatch") + } + if expected.S3Config.KeyPrefix != actual.S3Config.KeyPrefix && + expected.S3Config.KeyPrefix+"/" != actual.S3Config.KeyPrefix { + return errors.New("fs S3 key prefix mismatch") + } + return nil +} + +func compareGCSConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.GCSConfig.Bucket != actual.GCSConfig.Bucket { + return errors.New("GCS bucket mismatch") + } + if expected.GCSConfig.StorageClass != actual.GCSConfig.StorageClass { + return errors.New("GCS storage class mismatch") + } + if expected.GCSConfig.ACL != actual.GCSConfig.ACL { + return errors.New("GCS ACL mismatch") + } + if expected.GCSConfig.KeyPrefix != actual.GCSConfig.KeyPrefix && + expected.GCSConfig.KeyPrefix+"/" != actual.GCSConfig.KeyPrefix { + return errors.New("GCS key prefix mismatch") + } + if expected.GCSConfig.AutomaticCredentials != actual.GCSConfig.AutomaticCredentials { + return errors.New("GCS automatic credentials mismatch") + } + if 
expected.GCSConfig.UploadPartSize != actual.GCSConfig.UploadPartSize { + return errors.New("GCS upload part size mismatch") + } + if expected.GCSConfig.UploadPartMaxTime != actual.GCSConfig.UploadPartMaxTime { + return errors.New("GCS upload part max time mismatch") + } + return nil +} + +func compareHTTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.HTTPConfig.Endpoint != actual.HTTPConfig.Endpoint { + return errors.New("HTTPFs endpoint mismatch") + } + if expected.HTTPConfig.Username != actual.HTTPConfig.Username { + return errors.New("HTTPFs username mismatch") + } + if expected.HTTPConfig.SkipTLSVerify != actual.HTTPConfig.SkipTLSVerify { + return errors.New("HTTPFs skip_tls_verify mismatch") + } + if expected.HTTPConfig.EqualityCheckMode != actual.HTTPConfig.EqualityCheckMode { + return errors.New("HTTPFs equality_check_mode mismatch") + } + if err := checkEncryptedSecret(expected.HTTPConfig.Password, actual.HTTPConfig.Password); err != nil { + return fmt.Errorf("HTTPFs password mismatch: %v", err) + } + if err := checkEncryptedSecret(expected.HTTPConfig.APIKey, actual.HTTPConfig.APIKey); err != nil { + return fmt.Errorf("HTTPFs API key mismatch: %v", err) + } + return nil +} + +func compareSFTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.SFTPConfig.Endpoint != actual.SFTPConfig.Endpoint { + return errors.New("SFTPFs endpoint mismatch") + } + if expected.SFTPConfig.Username != actual.SFTPConfig.Username { + return errors.New("SFTPFs username mismatch") + } + if expected.SFTPConfig.DisableCouncurrentReads != actual.SFTPConfig.DisableCouncurrentReads { + return errors.New("SFTPFs disable_concurrent_reads mismatch") + } + if expected.SFTPConfig.BufferSize != actual.SFTPConfig.BufferSize { + return errors.New("SFTPFs buffer_size mismatch") + } + if expected.SFTPConfig.EqualityCheckMode != actual.SFTPConfig.EqualityCheckMode { + return errors.New("SFTPFs equality_check_mode mismatch") + } + if 
err := checkEncryptedSecret(expected.SFTPConfig.Password, actual.SFTPConfig.Password); err != nil { + return fmt.Errorf("SFTPFs password mismatch: %v", err) + } + if err := checkEncryptedSecret(expected.SFTPConfig.PrivateKey, actual.SFTPConfig.PrivateKey); err != nil { + return fmt.Errorf("SFTPFs private key mismatch: %v", err) + } + if err := checkEncryptedSecret(expected.SFTPConfig.KeyPassphrase, actual.SFTPConfig.KeyPassphrase); err != nil { + return fmt.Errorf("SFTPFs private key passphrase mismatch: %v", err) + } + if expected.SFTPConfig.Prefix != actual.SFTPConfig.Prefix { + if expected.SFTPConfig.Prefix != "" && actual.SFTPConfig.Prefix != "/" { + return errors.New("SFTPFs prefix mismatch") + } + } + if len(expected.SFTPConfig.Fingerprints) != len(actual.SFTPConfig.Fingerprints) { + return errors.New("SFTPFs fingerprints mismatch") + } + for _, value := range actual.SFTPConfig.Fingerprints { + if !slices.Contains(expected.SFTPConfig.Fingerprints, value) { + return errors.New("SFTPFs fingerprints mismatch") + } + } + return nil +} + +func compareAzBlobConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error { + if expected.AzBlobConfig.Container != actual.AzBlobConfig.Container { + return errors.New("azure Blob container mismatch") + } + if expected.AzBlobConfig.AccountName != actual.AzBlobConfig.AccountName { + return errors.New("azure Blob account name mismatch") + } + if err := checkEncryptedSecret(expected.AzBlobConfig.AccountKey, actual.AzBlobConfig.AccountKey); err != nil { + return fmt.Errorf("azure Blob account key mismatch: %v", err) + } + if expected.AzBlobConfig.Endpoint != actual.AzBlobConfig.Endpoint { + return errors.New("azure Blob endpoint mismatch") + } + if err := checkEncryptedSecret(expected.AzBlobConfig.SASURL, actual.AzBlobConfig.SASURL); err != nil { + return fmt.Errorf("azure Blob SAS URL mismatch: %v", err) + } + if expected.AzBlobConfig.UploadPartSize != actual.AzBlobConfig.UploadPartSize { + return errors.New("azure Blob 
upload part size mismatch") + } + if expected.AzBlobConfig.UploadConcurrency != actual.AzBlobConfig.UploadConcurrency { + return errors.New("azure Blob upload concurrency mismatch") + } + if expected.AzBlobConfig.DownloadPartSize != actual.AzBlobConfig.DownloadPartSize { + return errors.New("azure Blob download part size mismatch") + } + if expected.AzBlobConfig.DownloadConcurrency != actual.AzBlobConfig.DownloadConcurrency { + return errors.New("azure Blob download concurrency mismatch") + } + if expected.AzBlobConfig.KeyPrefix != actual.AzBlobConfig.KeyPrefix && + expected.AzBlobConfig.KeyPrefix+"/" != actual.AzBlobConfig.KeyPrefix { + return errors.New("azure Blob key prefix mismatch") + } + if expected.AzBlobConfig.UseEmulator != actual.AzBlobConfig.UseEmulator { + return errors.New("azure Blob use emulator mismatch") + } + if expected.AzBlobConfig.AccessTier != actual.AzBlobConfig.AccessTier { + return errors.New("azure Blob access tier mismatch") + } + return nil +} + +func areSecretEquals(expected, actual *kms.Secret) bool { + if expected == nil && actual == nil { + return true + } + if expected != nil && expected.IsEmpty() && actual == nil { + return true + } + if actual != nil && actual.IsEmpty() && expected == nil { + return true + } + return false +} + +func checkEncryptedSecret(expected, actual *kms.Secret) error { + if areSecretEquals(expected, actual) { + return nil + } + if expected == nil && actual != nil && !actual.IsEmpty() { + return errors.New("secret mismatch") + } + if actual == nil && expected != nil && !expected.IsEmpty() { + return errors.New("secret mismatch") + } + if expected.IsPlain() && actual.IsEncrypted() { + if actual.GetPayload() == "" { + return errors.New("invalid secret payload") + } + if actual.GetAdditionalData() != "" { + return errors.New("invalid secret additional data") + } + if actual.GetKey() != "" { + return errors.New("invalid secret key") + } + } else { + if expected.GetStatus() != actual.GetStatus() || 
expected.GetPayload() != actual.GetPayload() { + return errors.New("secret mismatch") + } + } + return nil +} + +func compareUserFilterSubStructs(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { + for _, IPMask := range expected.AllowedIP { + if !slices.Contains(actual.AllowedIP, IPMask) { + return errors.New("allowed IP contents mismatch") + } + } + for _, IPMask := range expected.DeniedIP { + if !slices.Contains(actual.DeniedIP, IPMask) { + return errors.New("denied IP contents mismatch") + } + } + for _, method := range expected.DeniedLoginMethods { + if !slices.Contains(actual.DeniedLoginMethods, method) { + return errors.New("denied login methods contents mismatch") + } + } + for _, protocol := range expected.DeniedProtocols { + if !slices.Contains(actual.DeniedProtocols, protocol) { + return errors.New("denied protocols contents mismatch") + } + } + for _, options := range expected.WebClient { + if !slices.Contains(actual.WebClient, options) { + return errors.New("web client options contents mismatch") + } + } + + if len(expected.TLSCerts) != len(actual.TLSCerts) { + return errors.New("TLS certs mismatch") + } + for _, cert := range expected.TLSCerts { + if !slices.Contains(actual.TLSCerts, cert) { + return errors.New("TLS certs content mismatch") + } + } + + return compareUserFiltersEqualFields(expected, actual) +} + +func compareUserFiltersEqualFields(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { + if expected.Hooks.ExternalAuthDisabled != actual.Hooks.ExternalAuthDisabled { + return errors.New("external_auth_disabled hook mismatch") + } + if expected.Hooks.PreLoginDisabled != actual.Hooks.PreLoginDisabled { + return errors.New("pre_login_disabled hook mismatch") + } + if expected.Hooks.CheckPasswordDisabled != actual.Hooks.CheckPasswordDisabled { + return errors.New("check_password_disabled hook mismatch") + } + if expected.DisableFsChecks != actual.DisableFsChecks { + return errors.New("disable_fs_checks mismatch") + } 
+ if expected.StartDirectory != actual.StartDirectory { + return errors.New("start_directory mismatch") + } + return nil +} + +func compareBaseUserFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { //nolint:gocyclo + if len(expected.AllowedIP) != len(actual.AllowedIP) { + return errors.New("allowed IP mismatch") + } + if len(expected.DeniedIP) != len(actual.DeniedIP) { + return errors.New("denied IP mismatch") + } + if len(expected.DeniedLoginMethods) != len(actual.DeniedLoginMethods) { + return errors.New("denied login methods mismatch") + } + if len(expected.DeniedProtocols) != len(actual.DeniedProtocols) { + return errors.New("denied protocols mismatch") + } + if expected.MaxUploadFileSize != actual.MaxUploadFileSize { + return errors.New("max upload file size mismatch") + } + if expected.TLSUsername != actual.TLSUsername { + return errors.New("TLSUsername mismatch") + } + if len(expected.WebClient) != len(actual.WebClient) { + return errors.New("WebClient filter mismatch") + } + if expected.AllowAPIKeyAuth != actual.AllowAPIKeyAuth { + return errors.New("allow_api_key_auth mismatch") + } + if expected.ExternalAuthCacheTime != actual.ExternalAuthCacheTime { + return errors.New("external_auth_cache_time mismatch") + } + if expected.FTPSecurity != actual.FTPSecurity { + return errors.New("ftp_security mismatch") + } + if expected.IsAnonymous != actual.IsAnonymous { + return errors.New("is_anonymous mismatch") + } + if expected.DefaultSharesExpiration != actual.DefaultSharesExpiration { + return errors.New("default_shares_expiration mismatch") + } + if expected.MaxSharesExpiration != actual.MaxSharesExpiration { + return errors.New("max_shares_expiration mismatch") + } + if expected.PasswordExpiration != actual.PasswordExpiration { + return errors.New("password_expiration mismatch") + } + if expected.PasswordStrength != actual.PasswordStrength { + return errors.New("password_strength mismatch") + } + return nil +} + +func 
compareUserFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { + if err := compareBaseUserFilters(expected, actual); err != nil { + return err + } + if err := compareUserFilterSubStructs(expected, actual); err != nil { + return err + } + if err := compareUserBandwidthLimitFilters(expected, actual); err != nil { + return err + } + if err := compareAccessTimeFilters(expected, actual); err != nil { + return err + } + return compareUserFilePatternsFilters(expected, actual) +} + +func checkFilterMatch(expected []string, actual []string) bool { + if len(expected) != len(actual) { + return false + } + for _, e := range expected { + if !slices.Contains(actual, strings.ToLower(e)) { + return false + } + } + return true +} + +func compareAccessTimeFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { + if len(expected.AccessTime) != len(actual.AccessTime) { + return errors.New("access time filters mismatch") + } + + for idx, p := range expected.AccessTime { + if actual.AccessTime[idx].DayOfWeek != p.DayOfWeek { + return errors.New("access time day of week mismatch") + } + if actual.AccessTime[idx].From != p.From { + return errors.New("access time from mismatch") + } + if actual.AccessTime[idx].To != p.To { + return errors.New("access time to mismatch") + } + } + + return nil +} + +func compareUserBandwidthLimitFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { + if len(expected.BandwidthLimits) != len(actual.BandwidthLimits) { + return errors.New("bandwidth limits filters mismatch") + } + + for idx, l := range expected.BandwidthLimits { + if actual.BandwidthLimits[idx].UploadBandwidth != l.UploadBandwidth { + return errors.New("bandwidth filters upload_bandwidth mismatch") + } + if actual.BandwidthLimits[idx].DownloadBandwidth != l.DownloadBandwidth { + return errors.New("bandwidth filters download_bandwidth mismatch") + } + if len(actual.BandwidthLimits[idx].Sources) != len(l.Sources) { + return 
errors.New("bandwidth filters sources mismatch") + } + for _, source := range actual.BandwidthLimits[idx].Sources { + if !slices.Contains(l.Sources, source) { + return errors.New("bandwidth filters source mismatch") + } + } + } + + return nil +} + +func compareUserFilePatternsFilters(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error { + if len(expected.FilePatterns) != len(actual.FilePatterns) { + return errors.New("file patterns mismatch") + } + for _, f := range expected.FilePatterns { + found := false + for _, f1 := range actual.FilePatterns { + if path.Clean(f.Path) == path.Clean(f1.Path) && f.DenyPolicy == f1.DenyPolicy { + if !checkFilterMatch(f.AllowedPatterns, f1.AllowedPatterns) || + !checkFilterMatch(f.DeniedPatterns, f1.DeniedPatterns) { + return errors.New("file patterns contents mismatch") + } + found = true + } + } + if !found { + return errors.New("file patterns contents mismatch") + } + } + return nil +} + +func compareRenameConfigs(expected, actual []dataprovider.RenameConfig) error { + if len(expected) != len(actual) { + return errors.New("rename configs mismatch") + } + for _, ex := range expected { + found := false + for _, ac := range actual { + if ac.Key == ex.Key && ac.Value == ex.Value && ac.UpdateModTime == ex.UpdateModTime { + found = true + break + } + } + if !found { + return errors.New("rename configs mismatch") + } + } + return nil +} + +func compareKeyValues(expected, actual []dataprovider.KeyValue) error { + if len(expected) != len(actual) { + return errors.New("key values mismatch") + } + for _, ex := range expected { + found := false + for _, ac := range actual { + if ac.Key == ex.Key && ac.Value == ex.Value { + found = true + break + } + } + if !found { + return errors.New("key values mismatch") + } + } + return nil +} + +func compareHTTPparts(expected, actual []dataprovider.HTTPPart) error { + for _, p1 := range expected { + found := false + for _, p2 := range actual { + if p1.Name == p2.Name { + found = true + if 
err := compareKeyValues(p1.Headers, p2.Headers); err != nil { + return fmt.Errorf("http headers mismatch for part %q", p1.Name) + } + if p1.Body != p2.Body || p1.Filepath != p2.Filepath { + return fmt.Errorf("http part %q mismatch", p1.Name) + } + } + } + if !found { + return fmt.Errorf("expected http part %q not found", p1.Name) + } + } + return nil +} + +func compareEventActionHTTPConfigFields(expected, actual dataprovider.EventActionHTTPConfig) error { + if expected.Endpoint != actual.Endpoint { + return errors.New("http endpoint mismatch") + } + if expected.Username != actual.Username { + return errors.New("http username mismatch") + } + if err := checkEncryptedSecret(expected.Password, actual.Password); err != nil { + return err + } + if err := compareKeyValues(expected.Headers, actual.Headers); err != nil { + return errors.New("http headers mismatch") + } + if expected.Timeout != actual.Timeout { + return errors.New("http timeout mismatch") + } + if expected.SkipTLSVerify != actual.SkipTLSVerify { + return errors.New("http skip TLS verify mismatch") + } + if expected.Method != actual.Method { + return errors.New("http method mismatch") + } + if err := compareKeyValues(expected.QueryParameters, actual.QueryParameters); err != nil { + return errors.New("http query parameters mismatch") + } + if expected.Body != actual.Body { + return errors.New("http body mismatch") + } + if len(expected.Parts) != len(actual.Parts) { + return errors.New("http parts mismatch") + } + return compareHTTPparts(expected.Parts, actual.Parts) +} + +func compareEventActionEmailConfigFields(expected, actual dataprovider.EventActionEmailConfig) error { + if len(expected.Recipients) != len(actual.Recipients) { + return errors.New("email recipients mismatch") + } + for _, v := range expected.Recipients { + if !slices.Contains(actual.Recipients, v) { + return errors.New("email recipients content mismatch") + } + } + if len(expected.Bcc) != len(actual.Bcc) { + return errors.New("email bcc 
mismatch") + } + for _, v := range expected.Bcc { + if !slices.Contains(actual.Bcc, v) { + return errors.New("email bcc content mismatch") + } + } + if expected.Subject != actual.Subject { + return errors.New("email subject mismatch") + } + if expected.ContentType != actual.ContentType { + return errors.New("email content type mismatch") + } + if expected.Body != actual.Body { + return errors.New("email body mismatch") + } + if len(expected.Attachments) != len(actual.Attachments) { + return errors.New("email attachments mismatch") + } + for _, v := range expected.Attachments { + if !slices.Contains(actual.Attachments, v) { + return errors.New("email attachments content mismatch") + } + } + return nil +} + +func compareEventActionFsCompressFields(expected, actual dataprovider.EventActionFsCompress) error { + if expected.Name != actual.Name { + return errors.New("fs compress name mismatch") + } + if len(expected.Paths) != len(actual.Paths) { + return errors.New("fs compress paths mismatch") + } + for _, v := range expected.Paths { + if !slices.Contains(actual.Paths, v) { + return errors.New("fs compress paths content mismatch") + } + } + return nil +} + +func compareEventActionFsConfigFields(expected, actual dataprovider.EventActionFilesystemConfig) error { + if expected.Type != actual.Type { + return errors.New("fs type mismatch") + } + if err := compareRenameConfigs(expected.Renames, actual.Renames); err != nil { + return errors.New("fs renames mismatch") + } + if err := compareKeyValues(expected.Copy, actual.Copy); err != nil { + return errors.New("fs copy mismatch") + } + if len(expected.Deletes) != len(actual.Deletes) { + return errors.New("fs deletes mismatch") + } + for _, v := range expected.Deletes { + if !slices.Contains(actual.Deletes, v) { + return errors.New("fs deletes content mismatch") + } + } + if len(expected.MkDirs) != len(actual.MkDirs) { + return errors.New("fs mkdirs mismatch") + } + for _, v := range expected.MkDirs { + if 
!slices.Contains(actual.MkDirs, v) { + return errors.New("fs mkdir content mismatch") + } + } + if len(expected.Exist) != len(actual.Exist) { + return errors.New("fs exist mismatch") + } + for _, v := range expected.Exist { + if !slices.Contains(actual.Exist, v) { + return errors.New("fs exist content mismatch") + } + } + return compareEventActionFsCompressFields(expected.Compress, actual.Compress) +} + +func compareEventActionIDPConfigFields(expected, actual dataprovider.EventActionIDPAccountCheck) error { + if expected.Mode != actual.Mode { + return errors.New("mode mismatch") + } + if expected.TemplateAdmin != actual.TemplateAdmin { + return errors.New("admin template mismatch") + } + if expected.TemplateUser != actual.TemplateUser { + return errors.New("user template mismatch") + } + return nil +} + +func compareEventActionCmdConfigFields(expected, actual dataprovider.EventActionCommandConfig) error { + if expected.Cmd != actual.Cmd { + return errors.New("command mismatch") + } + if expected.Timeout != actual.Timeout { + return errors.New("cmd timeout mismatch") + } + if len(expected.Args) != len(actual.Args) { + return errors.New("cmd args mismatch") + } + for _, v := range expected.Args { + if !slices.Contains(actual.Args, v) { + return errors.New("cmd args content mismatch") + } + } + if err := compareKeyValues(expected.EnvVars, actual.EnvVars); err != nil { + return errors.New("cmd env vars mismatch") + } + return nil +} + +func compareEventActionDataRetentionFields(expected, actual dataprovider.EventActionDataRetentionConfig) error { + if len(expected.Folders) != len(actual.Folders) { + return errors.New("retention folders mismatch") + } + for _, f1 := range expected.Folders { + found := false + for _, f2 := range actual.Folders { + if f1.Path == f2.Path { + found = true + if f1.Retention != f2.Retention { + return fmt.Errorf("retention mismatch for folder %s", f1.Path) + } + if f1.DeleteEmptyDirs != f2.DeleteEmptyDirs { + return 
fmt.Errorf("delete_empty_dirs mismatch for folder %s", f1.Path) + } + break + } + } + if !found { + return errors.New("retention folders mismatch") + } + } + return nil +} + +func compareEqualGroupSettingsFields(expected sdk.BaseGroupUserSettings, actual sdk.BaseGroupUserSettings) error { + if expected.HomeDir != actual.HomeDir { + return errors.New("home dir mismatch") + } + if expected.MaxSessions != actual.MaxSessions { + return errors.New("MaxSessions mismatch") + } + if expected.QuotaSize != actual.QuotaSize { + return errors.New("QuotaSize mismatch") + } + if expected.QuotaFiles != actual.QuotaFiles { + return errors.New("QuotaFiles mismatch") + } + if expected.UploadBandwidth != actual.UploadBandwidth { + return errors.New("UploadBandwidth mismatch") + } + if expected.DownloadBandwidth != actual.DownloadBandwidth { + return errors.New("DownloadBandwidth mismatch") + } + if expected.UploadDataTransfer != actual.UploadDataTransfer { + return errors.New("upload_data_transfer mismatch") + } + if expected.DownloadDataTransfer != actual.DownloadDataTransfer { + return errors.New("download_data_transfer mismatch") + } + if expected.TotalDataTransfer != actual.TotalDataTransfer { + return errors.New("total_data_transfer mismatch") + } + if expected.ExpiresIn != actual.ExpiresIn { + return errors.New("expires_in mismatch") + } + return compareUserPermissions(expected.Permissions, actual.Permissions) +} + +func compareEqualsUserFields(expected *dataprovider.User, actual *dataprovider.User) error { + if dataprovider.ConvertName(expected.Username) != actual.Username { + return errors.New("username mismatch") + } + if expected.HomeDir != actual.HomeDir { + return errors.New("home dir mismatch") + } + if expected.UID != actual.UID { + return errors.New("UID mismatch") + } + if expected.GID != actual.GID { + return errors.New("GID mismatch") + } + if expected.MaxSessions != actual.MaxSessions { + return errors.New("MaxSessions mismatch") + } + if len(expected.Permissions) 
!= len(actual.Permissions) { + return errors.New("permissions mismatch") + } + if expected.UploadBandwidth != actual.UploadBandwidth { + return errors.New("UploadBandwidth mismatch") + } + if expected.DownloadBandwidth != actual.DownloadBandwidth { + return errors.New("DownloadBandwidth mismatch") + } + if expected.Status != actual.Status { + return errors.New("status mismatch") + } + if expected.ExpirationDate != actual.ExpirationDate { + return errors.New("ExpirationDate mismatch") + } + if expected.AdditionalInfo != actual.AdditionalInfo { + return errors.New("AdditionalInfo mismatch") + } + if expected.Description != actual.Description { + return errors.New("description mismatch") + } + if expected.Role != actual.Role { + return errors.New("role mismatch") + } + return compareQuotaUserFields(expected, actual) +} + +func compareQuotaUserFields(expected *dataprovider.User, actual *dataprovider.User) error { + if expected.QuotaSize != actual.QuotaSize { + return errors.New("QuotaSize mismatch") + } + if expected.QuotaFiles != actual.QuotaFiles { + return errors.New("QuotaFiles mismatch") + } + if expected.UploadDataTransfer != actual.UploadDataTransfer { + return errors.New("upload_data_transfer mismatch") + } + if expected.DownloadDataTransfer != actual.DownloadDataTransfer { + return errors.New("download_data_transfer mismatch") + } + if expected.TotalDataTransfer != actual.TotalDataTransfer { + return errors.New("total_data_transfer mismatch") + } + return nil +} + +func addLimitAndOffsetQueryParams(rawurl string, limit, offset int64) (*url.URL, error) { + url, err := url.Parse(rawurl) + if err != nil { + return nil, err + } + q := url.Query() + if limit > 0 { + q.Add("limit", strconv.FormatInt(limit, 10)) + } + if offset > 0 { + q.Add("offset", strconv.FormatInt(offset, 10)) + } + url.RawQuery = q.Encode() + return url, err +} + +func addModeQueryParam(rawurl, mode string) (*url.URL, error) { + url, err := url.Parse(rawurl) + if err != nil { + return nil, err 
+ } + q := url.Query() + if len(mode) > 0 { + q.Add("mode", mode) + } + url.RawQuery = q.Encode() + return url, err +} + +func addUpdateUserQueryParams(rawurl, disconnect string) (*url.URL, error) { + url, err := url.Parse(rawurl) + if err != nil { + return nil, err + } + q := url.Query() + if disconnect != "" { + q.Add("disconnect", disconnect) + } + url.RawQuery = q.Encode() + return url, err +} diff --git a/internal/httpdtest/httpfsimpl.go b/internal/httpdtest/httpfsimpl.go new file mode 100644 index 00000000..2389fe21 --- /dev/null +++ b/internal/httpdtest/httpfsimpl.go @@ -0,0 +1,545 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. 
+ +package httpdtest + +import ( + "context" + "errors" + "fmt" + "io" + "mime" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/render" + "github.com/shirou/gopsutil/v3/disk" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + statPath = "/api/v1/stat" + openPath = "/api/v1/open" + createPath = "/api/v1/create" + renamePath = "/api/v1/rename" + removePath = "/api/v1/remove" + mkdirPath = "/api/v1/mkdir" + chmodPath = "/api/v1/chmod" + chtimesPath = "/api/v1/chtimes" + truncatePath = "/api/v1/truncate" + readdirPath = "/api/v1/readdir" + dirsizePath = "/api/v1/dirsize" + mimetypePath = "/api/v1/mimetype" + statvfsPath = "/api/v1/statvfs" +) + +// HTTPFsCallbacks defines additional callbacks to customize the HTTPfs responses +type HTTPFsCallbacks struct { + Readdir func(string) []os.FileInfo +} + +// StartTestHTTPFs starts a test HTTP service that implements httpfs +// and listens on the specified port +func StartTestHTTPFs(port int, callbacks *HTTPFsCallbacks) error { + fs := httpFsImpl{ + port: port, + callbacks: callbacks, + } + + return fs.Run() +} + +// StartTestHTTPFsOverUnixSocket starts a test HTTP service that implements httpfs +// and listens on the specified UNIX domain socket path +func StartTestHTTPFsOverUnixSocket(socketPath string) error { + fs := httpFsImpl{ + unixSocketPath: socketPath, + } + return fs.Run() +} + +type httpFsImpl struct { + router *chi.Mux + basePath string + port int + unixSocketPath string + callbacks *HTTPFsCallbacks +} + +type apiResponse struct { + Error string `json:"error,omitempty"` + Message string `json:"message,omitempty"` +} + +func (fs *httpFsImpl) sendAPIResponse(w http.ResponseWriter, r *http.Request, err error, message string, code int) { + var errorString string + if err != nil { + errorString = err.Error() + } + resp := apiResponse{ + Error: errorString, + Message: message, + } + 
ctx := context.WithValue(r.Context(), render.StatusCtxKey, code) + render.JSON(w, r.WithContext(ctx), resp) +} + +func (fs *httpFsImpl) getUsername(r *http.Request) (string, error) { + username, _, ok := r.BasicAuth() + if !ok || username == "" { + return "", os.ErrPermission + } + rootPath := filepath.Join(fs.basePath, username) + _, err := os.Stat(rootPath) + if errors.Is(err, os.ErrNotExist) { + err = os.MkdirAll(rootPath, os.ModePerm) + if err != nil { + return username, err + } + } + return username, nil +} + +func (fs *httpFsImpl) getRespStatus(err error) int { + if errors.Is(err, os.ErrPermission) { + return http.StatusForbidden + } + if errors.Is(err, os.ErrNotExist) { + return http.StatusNotFound + } + + return http.StatusInternalServerError +} + +func (fs *httpFsImpl) stat(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + info, err := os.Stat(fsPath) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + render.JSON(w, r, getStatFromInfo(info)) +} + +func (fs *httpFsImpl) open(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + var offset int64 + if r.URL.Query().Has("offset") { + offset, err = strconv.ParseInt(r.URL.Query().Get("offset"), 10, 64) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + f, err := os.Open(fsPath) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + defer f.Close() + + if offset > 0 { + _, err = f.Seek(offset, io.SeekStart) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + 
} + } + ctype := mime.TypeByExtension(filepath.Ext(name)) + if ctype != "" { + ctype = "application/octet-stream" + } + w.Header().Set("Content-Type", ctype) + _, err = io.Copy(w, f) + if err != nil { + panic(http.ErrAbortHandler) + } +} + +func (fs *httpFsImpl) create(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + flags := os.O_RDWR | os.O_CREATE | os.O_TRUNC + if r.URL.Query().Has("flags") { + openFlags, err := strconv.ParseInt(r.URL.Query().Get("flags"), 10, 32) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + if openFlags > 0 { + flags = int(openFlags) + } + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + f, err := os.OpenFile(fsPath, flags, 0666) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + defer f.Close() + + _, err = io.Copy(f, r.Body) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + fs.sendAPIResponse(w, r, nil, "upload OK", http.StatusOK) +} + +func (fs *httpFsImpl) rename(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + target := r.URL.Query().Get("target") + if target == "" { + fs.sendAPIResponse(w, r, nil, "target path cannot be empty", http.StatusBadRequest) + return + } + name := getNameURLParam(r) + sourcePath := filepath.Join(fs.basePath, username, name) + targetPath := filepath.Join(fs.basePath, username, target) + err = os.Rename(sourcePath, targetPath) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + fs.sendAPIResponse(w, r, nil, "rename OK", http.StatusOK) +} + +func (fs *httpFsImpl) remove(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + 
fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + err = os.Remove(fsPath) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + fs.sendAPIResponse(w, r, nil, "remove OK", http.StatusOK) +} + +func (fs *httpFsImpl) mkdir(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + err = os.Mkdir(fsPath, os.ModePerm) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + fs.sendAPIResponse(w, r, nil, "mkdir OK", http.StatusOK) +} + +func (fs *httpFsImpl) chmod(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + mode, err := strconv.ParseUint(r.URL.Query().Get("mode"), 10, 32) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + err = os.Chmod(fsPath, os.FileMode(mode)) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + fs.sendAPIResponse(w, r, nil, "chmod OK", http.StatusOK) +} + +func (fs *httpFsImpl) chtimes(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + atime, err := time.Parse(time.RFC3339, r.URL.Query().Get("access_time")) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + mtime, err := time.Parse(time.RFC3339, r.URL.Query().Get("modification_time")) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := 
getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + err = os.Chtimes(fsPath, atime, mtime) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + fs.sendAPIResponse(w, r, nil, "chtimes OK", http.StatusOK) +} + +func (fs *httpFsImpl) truncate(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + size, err := strconv.ParseInt(r.URL.Query().Get("size"), 10, 64) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + err = os.Truncate(fsPath, size) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + fs.sendAPIResponse(w, r, nil, "chmod OK", http.StatusOK) +} + +func (fs *httpFsImpl) readdir(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + f, err := os.Open(fsPath) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + list, err := f.Readdir(-1) + f.Close() + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + result := make([]map[string]any, 0, len(list)) + for _, fi := range list { + result = append(result, getStatFromInfo(fi)) + } + if fs.callbacks != nil && fs.callbacks.Readdir != nil { + for _, fi := range fs.callbacks.Readdir(name) { + result = append(result, getStatFromInfo(fi)) + } + } + render.JSON(w, r, result) +} + +func (fs *httpFsImpl) dirsize(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := getNameURLParam(r) + fsPath := 
filepath.Join(fs.basePath, username, name) + info, err := os.Stat(fsPath) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + numFiles := 0 + size := int64(0) + if info.IsDir() { + err = filepath.Walk(fsPath, func(_ string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info != nil && info.Mode().IsRegular() { + size += info.Size() + numFiles++ + } + return err + }) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + } + render.JSON(w, r, map[string]any{ + "files": numFiles, + "size": size, + }) +} + +func (fs *httpFsImpl) mimetype(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + f, err := os.OpenFile(fsPath, os.O_RDONLY, 0) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + defer f.Close() + var buf [512]byte + n, err := io.ReadFull(f, buf[:]) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + ctype := http.DetectContentType(buf[:n]) + render.JSON(w, r, map[string]any{ + "mime": ctype, + }) +} + +func (fs *httpFsImpl) statvfs(w http.ResponseWriter, r *http.Request) { + username, err := fs.getUsername(r) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + name := getNameURLParam(r) + fsPath := filepath.Join(fs.basePath, username, name) + usage, err := disk.Usage(fsPath) + if err != nil { + fs.sendAPIResponse(w, r, err, "", fs.getRespStatus(err)) + return + } + // we assume block size = 4096 + bsize := uint64(4096) + blocks := usage.Total / bsize + bfree := usage.Free / bsize + files := usage.InodesTotal + ffree := usage.InodesFree + if files == 0 { + // these assumptions are wrong but 
still better than returning 0 + files = blocks / 4 + ffree = bfree / 4 + } + render.JSON(w, r, map[string]any{ + "bsize": bsize, + "frsize": bsize, + "blocks": blocks, + "bfree": bfree, + "bavail": bfree, + "files": files, + "ffree": ffree, + "favail": ffree, + "namemax": 255, + }) +} + +func (fs *httpFsImpl) configureRouter() { + fs.router = chi.NewRouter() + fs.router.Use(middleware.Recoverer) + + fs.router.Get(statPath+"/{name}", fs.stat) //nolint:goconst + fs.router.Get(openPath+"/{name}", fs.open) + fs.router.Post(createPath+"/{name}", fs.create) + fs.router.Patch(renamePath+"/{name}", fs.rename) + fs.router.Delete(removePath+"/{name}", fs.remove) + fs.router.Post(mkdirPath+"/{name}", fs.mkdir) + fs.router.Patch(chmodPath+"/{name}", fs.chmod) + fs.router.Patch(chtimesPath+"/{name}", fs.chtimes) + fs.router.Patch(truncatePath+"/{name}", fs.truncate) + fs.router.Get(readdirPath+"/{name}", fs.readdir) + fs.router.Get(dirsizePath+"/{name}", fs.dirsize) + fs.router.Get(mimetypePath+"/{name}", fs.mimetype) + fs.router.Get(statvfsPath+"/{name}", fs.statvfs) +} + +func (fs *httpFsImpl) Run() error { + fs.basePath = filepath.Join(os.TempDir(), "httpfs") + if err := os.RemoveAll(fs.basePath); err != nil { + return err + } + if err := os.MkdirAll(fs.basePath, os.ModePerm); err != nil { + return err + } + fs.configureRouter() + + httpServer := http.Server{ + Addr: fmt.Sprintf(":%d", fs.port), + Handler: fs.router, + ReadTimeout: 60 * time.Second, + WriteTimeout: 60 * time.Second, + IdleTimeout: 120 * time.Second, + MaxHeaderBytes: 1 << 16, // 64KB + } + + if fs.unixSocketPath == "" { + return httpServer.ListenAndServe() + } + err := os.Remove(fs.unixSocketPath) + if err != nil && !os.IsNotExist(err) { + return err + } + listener, err := net.Listen("unix", fs.unixSocketPath) + if err != nil { + return err + } + return httpServer.Serve(listener) +} + +func getStatFromInfo(info os.FileInfo) map[string]any { + return map[string]any{ + "name": info.Name(), + "size": 
info.Size(), + "mode": info.Mode(), + "last_modified": info.ModTime(), + } +} + +func getNameURLParam(r *http.Request) string { + v := chi.URLParam(r, "name") + unescaped, err := url.PathUnescape(v) + if err != nil { + return util.CleanPath(v) + } + return util.CleanPath(unescaped) +} diff --git a/internal/jwt/jwt.go b/internal/jwt/jwt.go new file mode 100644 index 00000000..01493bc4 --- /dev/null +++ b/internal/jwt/jwt.go @@ -0,0 +1,268 @@ +// Copyright (C) 2025 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package jwt provides functionality for creating, parsing, and validating +// JSON Web Tokens (JWT) used in authentication and authorization workflows. +package jwt + +import ( + "context" + "errors" + "fmt" + "net/http" + "slices" + "strings" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/rs/xid" +) + +const ( + CookieKey = "jwt" +) + +var ( + TokenCtxKey = &contextKey{"Token"} + ErrorCtxKey = &contextKey{"Error"} +) + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. This technique +// for defining context keys was copied from Go 1.7's new use of context in net/http. 
+type contextKey struct { + name string +} + +func (k *contextKey) String() string { + return "jwt context value " + k.name +} + +func NewClaims(audience, ip string, duration time.Duration) *Claims { + now := time.Now() + claims := &Claims{} + claims.IssuedAt = jwt.NewNumericDate(now) + claims.NotBefore = jwt.NewNumericDate(now.Add(-10 * time.Second)) + claims.Expiry = jwt.NewNumericDate(now.Add(duration)) + claims.Audience = []string{audience, ip} + return claims +} + +type Claims struct { + jwt.Claims + Username string `json:"username,omitempty"` + Permissions []string `json:"permissions,omitempty"` + Role string `json:"role,omitempty"` + APIKeyID string `json:"api_key,omitempty"` + NodeID string `json:"node_id,omitempty"` + MustSetTwoFactorAuth bool `json:"2fa_required,omitempty"` + MustChangePassword bool `json:"chpwd,omitempty"` + RequiredTwoFactorProtocols []string `json:"2fa_protos,omitempty"` + HideUserPageSections int `json:"hus,omitempty"` + Ref string `json:"ref,omitempty"` +} + +func (c *Claims) SetIssuedAt(t time.Time) { + c.IssuedAt = jwt.NewNumericDate(t) +} + +func (c *Claims) SetNotBefore(t time.Time) { + c.NotBefore = jwt.NewNumericDate(t) +} + +func (c *Claims) SetExpiry(t time.Time) { + c.Expiry = jwt.NewNumericDate(t) +} + +func (c *Claims) HasPerm(perm string) bool { + for _, p := range c.Permissions { + if p == "*" || p == perm { + return true + } + } + return false +} + +func (c *Claims) HasAnyAudience(audiences []string) bool { + for _, a := range c.Audience { + if slices.Contains(audiences, a) { + return true + } + } + return false +} + +func (c *Claims) GenerateTokenResponse(signer *Signer) (TokenResponse, error) { + token, err := signer.Sign(c) + if err != nil { + return TokenResponse{}, err + } + return c.BuildTokenResponse(token), nil +} + +func (c *Claims) BuildTokenResponse(token string) TokenResponse { + return TokenResponse{Token: token, Expiry: c.Expiry.Time().UTC().Format(time.RFC3339)} +} + +type TokenResponse struct { + Token 
string `json:"access_token"` + Expiry string `json:"expires_at"` +} + +func NewSigner(algo jose.SignatureAlgorithm, key any) (*Signer, error) { + opts := (&jose.SignerOptions{}).WithType("JWT") + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: algo, Key: key}, opts) + if err != nil { + return nil, err + } + return &Signer{ + signer: signer, + algo: []jose.SignatureAlgorithm{algo}, + key: key, + }, nil +} + +type Signer struct { + algo []jose.SignatureAlgorithm + signer jose.Signer + key any +} + +func (s *Signer) Sign(claims *Claims) (string, error) { + if claims.ID == "" { + claims.ID = xid.New().String() + } + if claims.IssuedAt == nil { + claims.IssuedAt = jwt.NewNumericDate(time.Now()) + } + if claims.NotBefore == nil { + claims.NotBefore = jwt.NewNumericDate(time.Now().Add(-10 * time.Second)) + } + if claims.Expiry == nil { + return "", errors.New("expiration must be set") + } + if len(claims.Audience) == 0 { + return "", errors.New("audience must be set") + } + + return jwt.Signed(s.signer).Claims(claims).Serialize() +} + +func (s *Signer) Signer() jose.Signer { + return s.signer +} + +func (s *Signer) SetSigner(signer jose.Signer) { + s.signer = signer +} + +func (s *Signer) SignWithParams(claims *Claims, audience, ip string, duration time.Duration) (string, error) { + claims.Expiry = jwt.NewNumericDate(time.Now().Add(duration)) + claims.Audience = []string{audience, ip} + return s.Sign(claims) +} + +func NewContext(ctx context.Context, claims *Claims, err error) context.Context { + ctx = context.WithValue(ctx, TokenCtxKey, claims) + ctx = context.WithValue(ctx, ErrorCtxKey, err) + return ctx +} + +func FromContext(ctx context.Context) (*Claims, error) { + val := ctx.Value(TokenCtxKey) + token, ok := val.(*Claims) + if !ok && val != nil { + return nil, fmt.Errorf("invalid type for TokenCtxKey: %T", val) + } + + valErr := ctx.Value(ErrorCtxKey) + err, ok := valErr.(error) + if !ok && valErr != nil { + return nil, fmt.Errorf("invalid type for 
ErrorCtxKey: %T", valErr) + } + if token == nil { + return nil, errors.New("no token found") + } + + return token, err +} + +func Verify(s *Signer, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + hfn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + token, err := VerifyRequest(s, r, findTokenFns...) + ctx = NewContext(ctx, token, err) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(hfn) + } +} + +func VerifyRequest(s *Signer, r *http.Request, findTokenFns ...func(r *http.Request) string) (*Claims, error) { + var tokenString string + for _, fn := range findTokenFns { + tokenString = fn(r) + if tokenString != "" { + break + } + } + if tokenString == "" { + return nil, errors.New("no token found") + } + return VerifyToken(s, tokenString) +} + +func VerifyToken(s *Signer, payload string) (*Claims, error) { + return VerifyTokenWithKey(payload, s.algo, s.key) +} + +func VerifyTokenWithKey(payload string, algo []jose.SignatureAlgorithm, key any) (*Claims, error) { + token, err := jwt.ParseSigned(payload, algo) + if err != nil { + return nil, err + } + var claims Claims + err = token.Claims(key, &claims) + if err != nil { + return nil, err + } + if err := claims.ValidateWithLeeway(jwt.Expected{Time: time.Now()}, 30*time.Second); err != nil { + return nil, err + } + return &claims, nil +} + +// TokenFromCookie tries to retrieve the token string from a cookie named +// "jwt". +func TokenFromCookie(r *http.Request) string { + cookie, err := r.Cookie(CookieKey) + if err != nil { + return "" + } + return cookie.Value +} + +// TokenFromHeader tries to retrieve the token string from the +// "Authorization" request header: "Authorization: BEARER T". +func TokenFromHeader(r *http.Request) string { + // Get token from authorization header. 
+ bearer := r.Header.Get("Authorization") + const prefix = "Bearer " + if len(bearer) >= len(prefix) && strings.EqualFold(bearer[:len(prefix)], prefix) { + return bearer[len(prefix):] + } + return "" +} diff --git a/internal/jwt/jwt_test.go b/internal/jwt/jwt_test.go new file mode 100644 index 00000000..cc7fcddb --- /dev/null +++ b/internal/jwt/jwt_test.go @@ -0,0 +1,255 @@ +package jwt + +import ( + "context" + "errors" + "fmt" + "io/fs" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/go-jose/go-jose/v4/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +type failingJoseSigner struct{} + +func (s *failingJoseSigner) Sign(payload []byte) (*jose.JSONWebSignature, error) { + return nil, errors.New("sign test error") +} + +func (s *failingJoseSigner) Options() jose.SignerOptions { + return jose.SignerOptions{} +} + +func TestJWTToken(t *testing.T) { + s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + username := util.GenerateUniqueID() + claims := Claims{ + Username: username, + Claims: jwt.Claims{ + Audience: jwt.Audience{"test"}, + Expiry: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + NotBefore: jwt.NewNumericDate(time.Now()), + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + } + token, err := s.Sign(&claims) + require.NoError(t, err) + require.NotEmpty(t, token) + + parsed, err := VerifyToken(s, token) + require.NoError(t, err) + require.Equal(t, username, parsed.Username) + + ja1, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + + token, err = ja1.Sign(&claims) + require.NoError(t, err) + require.NotEmpty(t, token) + _, err = VerifyToken(s, token) + require.Error(t, err) + _, err = VerifyToken(ja1, token) + require.NoError(t, err) +} + +func TestClaims(t *testing.T) { + claims := NewClaims(util.GenerateUniqueID(), "", 10*time.Minute) + 
s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + token, err := s.Sign(claims) + require.NoError(t, err) + assert.NotEmpty(t, token) + assert.NotNil(t, claims.Expiry) + assert.NotNil(t, claims.IssuedAt) + assert.NotNil(t, claims.NotBefore) + + claims = &Claims{ + Permissions: []string{"myperm"}, + } + claims.SetExpiry(time.Now().Add(1 * time.Minute)) + claims.Audience = []string{"testaudience"} + _, err = s.Sign(claims) + assert.NoError(t, err) + assert.NotNil(t, claims.IssuedAt) + assert.NotNil(t, claims.NotBefore) + assert.True(t, claims.HasAnyAudience([]string{util.GenerateUniqueID(), util.GenerateUniqueID(), "testaudience"})) + assert.False(t, claims.HasAnyAudience([]string{util.GenerateUniqueID()})) + assert.True(t, claims.HasPerm("myperm")) + assert.False(t, claims.HasPerm(util.GenerateUniqueID())) + resp, err := claims.GenerateTokenResponse(s) + require.NoError(t, err) + assert.NotEmpty(t, resp.Token) + assert.Equal(t, claims.Expiry.Time().UTC().Format(time.RFC3339), resp.Expiry) + claims.SetIssuedAt(time.Now()) + claims.SetNotBefore(time.Now().Add(10 * time.Minute)) + token, err = s.SignWithParams(claims, util.GenerateUniqueID(), "127.0.0.1", time.Minute) + assert.NoError(t, err) + _, err = VerifyToken(s, token) + assert.ErrorContains(t, err, "nbf") + claims = &Claims{} + _, err = s.Sign(claims) + assert.ErrorContains(t, err, "expiration must be set") + claims.SetExpiry(time.Now()) + _, err = s.Sign(claims) + assert.ErrorContains(t, err, "audience must be set") + claims = &Claims{} + _, err = s.SignWithParams(claims, util.GenerateUniqueID(), "", time.Minute) + assert.NoError(t, err) +} + +func TestClaimsPermissions(t *testing.T) { + c := Claims{ + Permissions: []string{"*"}, + } + assert.True(t, c.HasPerm(util.GenerateUniqueID())) + c.Permissions = []string{"list"} + assert.False(t, c.HasPerm(util.GenerateUniqueID())) + assert.True(t, c.HasPerm("list")) +} + +func TestErrors(t *testing.T) { + s, err := 
NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + _, err = VerifyToken(s, util.GenerateUniqueID()) + assert.Error(t, err) + claims := &Claims{} + claims.SetExpiry(time.Now().Add(-1 * time.Minute)) + token, err := jwt.Signed(s.Signer()).Claims(claims).Serialize() + assert.NoError(t, err) + _, err = VerifyToken(s, token) + assert.ErrorContains(t, err, "exp") + claims.SetExpiry(time.Now().Add(2 * time.Minute)) + claims.SetIssuedAt(time.Now().Add(1 * time.Minute)) + token, err = jwt.Signed(s.Signer()).Claims(claims).Serialize() + assert.NoError(t, err) + _, err = VerifyToken(s, token) + assert.ErrorContains(t, err, "iat") + claims.SetIssuedAt(time.Now()) + claims.SetNotBefore(time.Now().Add(1 * time.Minute)) + token, err = jwt.Signed(s.Signer()).Claims(claims).Serialize() + assert.NoError(t, err) + _, err = VerifyToken(s, token) + assert.ErrorContains(t, err, "nbf") + + s.SetSigner(&failingJoseSigner{}) + claims = NewClaims(util.GenerateUniqueID(), "", time.Minute) + _, err = s.Sign(claims) + assert.Error(t, err) + _, err = claims.GenerateTokenResponse(s) + assert.Error(t, err) + // Wrong algorithm + _, err = NewSigner("PS256", util.GenerateRandomBytes(32)) + assert.Error(t, err) +} + +func TestTokenFromRequest(t *testing.T) { + claims := NewClaims(util.GenerateUniqueID(), "", 10*time.Minute) + s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + token, err := s.Sign(claims) + require.NoError(t, err) + assert.NotEmpty(t, token) + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + req.Header.Set("Cookie", fmt.Sprintf("jwt=%s", token)) + cookie := TokenFromCookie(req) + assert.Equal(t, token, cookie) + req, err = http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + _, err = VerifyRequest(s, req, TokenFromHeader) + assert.NoError(t, err) + req.Header.Set("Authorization", token) + 
assert.Empty(t, TokenFromHeader(req)) + assert.Empty(t, TokenFromCookie(req)) + _, err = VerifyRequest(s, req, TokenFromCookie) + assert.ErrorContains(t, err, "no token found") +} + +func TestContext(t *testing.T) { + claims := &Claims{ + Username: util.GenerateUniqueID(), + } + s, err := NewSigner(jose.HS256, util.GenerateRandomBytes(32)) + require.NoError(t, err) + token, err := s.SignWithParams(claims, util.GenerateUniqueID(), "", time.Minute) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + h := Verify(s, TokenFromHeader) + wrapped := h(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, err := FromContext(r.Context()) + assert.Nil(t, err) + assert.Equal(t, claims.Username, token.Username) + w.WriteHeader(http.StatusOK) + })) + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) + + _, err = FromContext(context.Background()) + assert.ErrorContains(t, err, "no token found") + + ctx := NewContext(context.Background(), &Claims{}, fs.ErrClosed) + _, err = FromContext(ctx) + assert.Equal(t, fs.ErrClosed, err) + + ctx = context.WithValue(context.Background(), TokenCtxKey, "1") + _, err = FromContext(ctx) + assert.ErrorContains(t, err, "invalid type for TokenCtxKey") + + ctx = context.WithValue(context.Background(), ErrorCtxKey, 2) + _, err = FromContext(ctx) + assert.ErrorContains(t, err, "invalid type for ErrorCtxKey") + claims = NewClaims(util.GenerateUniqueID(), "127.1.1.1", time.Minute) + _, err = s.Sign(claims) + require.NoError(t, err) + ctx = context.WithValue(context.Background(), TokenCtxKey, claims) + claimsFromContext, err := FromContext(ctx) + assert.NoError(t, err) + assert.Equal(t, claims, claimsFromContext) + + assert.Equal(t, "jwt context value Token", TokenCtxKey.String()) +} + +func TestValidationLeeway(t *testing.T) { + s, err := NewSigner(jose.HS256, 
util.GenerateRandomBytes(32)) + require.NoError(t, err) + claims := &Claims{} + claims.Audience = []string{util.GenerateUniqueID()} + claims.SetIssuedAt(time.Now().Add(10 * time.Second)) // issued at in the future + claims.SetExpiry(time.Now().Add(10 * time.Second)) + token, err := s.Sign(claims) + require.NoError(t, err) + _, err = VerifyToken(s, token) + assert.NoError(t, err) + + claims = &Claims{} + claims.Audience = []string{util.GenerateUniqueID()} + claims.SetExpiry(time.Now().Add(-10 * time.Second)) // expired + token, err = s.Sign(claims) + require.NoError(t, err) + _, err = VerifyToken(s, token) + assert.NoError(t, err) + + claims = &Claims{} + claims.Audience = []string{util.GenerateUniqueID()} + claims.SetExpiry(time.Now().Add(30 * time.Second)) + claims.SetNotBefore(time.Now().Add(10 * time.Second)) // not before in the future + token, err = s.Sign(claims) + require.NoError(t, err) + _, err = VerifyToken(s, token) + assert.NoError(t, err) +} diff --git a/internal/kms/basesecret.go b/internal/kms/basesecret.go new file mode 100644 index 00000000..be2e98ed --- /dev/null +++ b/internal/kms/basesecret.go @@ -0,0 +1,85 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package kms + +import ( + sdkkms "github.com/sftpgo/sdk/kms" +) + +// BaseSecret defines the base struct shared among all the secret providers +type BaseSecret struct { + Status sdkkms.SecretStatus `json:"status,omitempty"` + Payload string `json:"payload,omitempty"` + Key string `json:"key,omitempty"` + AdditionalData string `json:"additional_data,omitempty"` + // 1 means encrypted using a master key + Mode int `json:"mode,omitempty"` +} + +// GetStatus returns the secret's status +func (s *BaseSecret) GetStatus() sdkkms.SecretStatus { + return s.Status +} + +// GetPayload returns the secret's payload +func (s *BaseSecret) GetPayload() string { + return s.Payload +} + +// GetKey returns the secret's key +func (s *BaseSecret) GetKey() string { + return s.Key +} + +// GetMode returns the encryption mode +func (s *BaseSecret) GetMode() int { + return s.Mode +} + +// GetAdditionalData returns the secret's additional data +func (s *BaseSecret) GetAdditionalData() string { + return s.AdditionalData +} + +// SetKey sets the secret's key +func (s *BaseSecret) SetKey(value string) { + s.Key = value +} + +// SetAdditionalData sets the secret's additional data +func (s *BaseSecret) SetAdditionalData(value string) { + s.AdditionalData = value +} + +// SetStatus sets the secret's status +func (s *BaseSecret) SetStatus(value sdkkms.SecretStatus) { + s.Status = value +} + +func (s *BaseSecret) isEmpty() bool { + if s.Status != "" { + return false + } + if s.Payload != "" { + return false + } + if s.Key != "" { + return false + } + if s.AdditionalData != "" { + return false + } + return true +} diff --git a/internal/kms/builtin.go b/internal/kms/builtin.go new file mode 100644 index 00000000..5072bdd3 --- /dev/null +++ b/internal/kms/builtin.go @@ -0,0 +1,155 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software 
Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package kms + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "errors" + "io" + + sdkkms "github.com/sftpgo/sdk/kms" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +var ( + errMalformedCiphertext = errors.New("malformed ciphertext") +) + +type builtinSecret struct { + BaseSecret +} + +func init() { + RegisterSecretProvider(sdkkms.SchemeBuiltin, sdkkms.SecretStatusAES256GCM, newBuiltinSecret) +} + +func newBuiltinSecret(base BaseSecret, _, _ string) SecretProvider { + return &builtinSecret{ + BaseSecret: base, + } +} + +func (s *builtinSecret) Name() string { + return "Builtin" +} + +func (s *builtinSecret) IsEncrypted() bool { + return s.Status == sdkkms.SecretStatusAES256GCM +} + +func (s *builtinSecret) deriveKey(key []byte) []byte { + var combined []byte + combined = append(combined, key...) + if s.AdditionalData != "" { + combined = append(combined, []byte(s.AdditionalData)...) + } + combined = append(combined, key...) 
+ hash := sha256.Sum256(combined) + return hash[:] +} + +func (s *builtinSecret) Encrypt() error { + if s.Payload == "" { + return ErrInvalidSecret + } + switch s.Status { + case sdkkms.SecretStatusPlain: + key := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return err + } + block, err := aes.NewCipher(s.deriveKey(key)) + if err != nil { + return err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return err + } + var aad []byte + if s.AdditionalData != "" { + aad = []byte(s.AdditionalData) + } + ciphertext := gcm.Seal(nonce, nonce, []byte(s.Payload), aad) + s.Key = hex.EncodeToString(key) + s.Payload = hex.EncodeToString(ciphertext) + s.Status = sdkkms.SecretStatusAES256GCM + return nil + default: + return ErrWrongSecretStatus + } +} + +func (s *builtinSecret) Decrypt() error { + switch s.Status { + case sdkkms.SecretStatusAES256GCM: + encrypted, err := hex.DecodeString(s.Payload) + if err != nil { + return err + } + key, err := hex.DecodeString(s.Key) + if err != nil { + return err + } + block, err := aes.NewCipher(s.deriveKey(key)) + if err != nil { + return err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return err + } + nonceSize := gcm.NonceSize() + if len(encrypted) < nonceSize { + return errMalformedCiphertext + } + nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:] + var aad []byte + if s.AdditionalData != "" { + aad = []byte(s.AdditionalData) + } + plaintext, err := gcm.Open(nil, nonce, ciphertext, aad) + if err != nil { + return err + } + s.Status = sdkkms.SecretStatusPlain + s.Payload = util.BytesToString(plaintext) + s.Key = "" + s.AdditionalData = "" + return nil + default: + return ErrWrongSecretStatus + } +} + +func (s *builtinSecret) Clone() SecretProvider { + baseSecret := BaseSecret{ + Status: s.Status, + Payload: s.Payload, + Key: s.Key, + AdditionalData: 
s.AdditionalData, + Mode: s.Mode, + } + return newBuiltinSecret(baseSecret, "", "") +} diff --git a/internal/kms/kms.go b/internal/kms/kms.go new file mode 100644 index 00000000..914af4ea --- /dev/null +++ b/internal/kms/kms.go @@ -0,0 +1,432 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package kms provides Key Management Services support +package kms + +import ( + "encoding/json" + "errors" + "strings" + "sync" + + sdkkms "github.com/sftpgo/sdk/kms" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +// SecretProvider defines the interface for a KMS secrets provider +type SecretProvider interface { + Name() string + Encrypt() error + Decrypt() error + IsEncrypted() bool + GetStatus() sdkkms.SecretStatus + GetPayload() string + GetKey() string + GetAdditionalData() string + GetMode() int + SetKey(string) + SetAdditionalData(string) + SetStatus(sdkkms.SecretStatus) + Clone() SecretProvider +} + +const ( + logSender = "kms" +) + +// Configuration defines the KMS configuration +type Configuration struct { + Secrets Secrets `json:"secrets" mapstructure:"secrets"` +} + +// Secrets define the KMS configuration for encryption/decryption +type Secrets struct { + URL string `json:"url" mapstructure:"url"` + MasterKeyPath string `json:"master_key_path" mapstructure:"master_key_path"` + MasterKeyString string `json:"master_key" 
mapstructure:"master_key"` + masterKey string +} + +type registeredSecretProvider struct { + encryptedStatus sdkkms.SecretStatus + newFn func(base BaseSecret, url, masterKey string) SecretProvider +} + +var ( + // ErrWrongSecretStatus defines the error to return if the secret status is not appropriate + // for the request operation + ErrWrongSecretStatus = errors.New("wrong secret status") + // ErrInvalidSecret defines the error to return if a secret is not valid + ErrInvalidSecret = errors.New("invalid secret") + validSecretStatuses = []string{sdkkms.SecretStatusPlain, sdkkms.SecretStatusAES256GCM, sdkkms.SecretStatusSecretBox, + sdkkms.SecretStatusVaultTransit, sdkkms.SecretStatusAWS, sdkkms.SecretStatusGCP, sdkkms.SecretStatusAzureKeyVault, + sdkkms.SecretStatusOracleKeyVault, sdkkms.SecretStatusRedacted} + config Configuration + secretProviders = make(map[string]registeredSecretProvider) +) + +// RegisterSecretProvider register a new secret provider +func RegisterSecretProvider(scheme string, encryptedStatus sdkkms.SecretStatus, + fn func(base BaseSecret, url, masterKey string) SecretProvider, +) { + secretProviders[scheme] = registeredSecretProvider{ + encryptedStatus: encryptedStatus, + newFn: fn, + } +} + +// NewSecret builds a new Secret using the provided arguments +func NewSecret(status sdkkms.SecretStatus, payload, key, data string) *Secret { + return config.newSecret(status, payload, key, data) +} + +// NewEmptySecret returns an empty secret +func NewEmptySecret() *Secret { + return NewSecret("", "", "", "") +} + +// NewPlainSecret stores the give payload in a plain text secret +func NewPlainSecret(payload string) *Secret { + return NewSecret(sdkkms.SecretStatusPlain, strings.TrimSpace(payload), "", "") +} + +// Initialize configures the KMS support +func (c *Configuration) Initialize() error { + if c.Secrets.MasterKeyPath != "" { + mKey, err := util.ReadConfigFromFile(c.Secrets.MasterKeyPath, "") + if err != nil { + return err + } + c.Secrets.masterKey 
= mKey + } else if c.Secrets.MasterKeyString != "" { + c.Secrets.masterKey = c.Secrets.MasterKeyString + } + config = *c + if config.Secrets.URL == "" { + config.Secrets.URL = sdkkms.SchemeLocal + "://" + } + for k, v := range secretProviders { + logger.Info(logSender, "", "secret provider registered for scheme: %q, encrypted status: %q", + k, v.encryptedStatus) + } + return nil +} + +func (c *Configuration) newSecret(status sdkkms.SecretStatus, payload, key, data string) *Secret { + base := BaseSecret{ + Status: status, + Key: key, + Payload: payload, + AdditionalData: data, + } + return &Secret{ + provider: c.getSecretProvider(base), + } +} + +func (c *Configuration) getSecretProvider(base BaseSecret) SecretProvider { + for k, v := range secretProviders { + if strings.HasPrefix(c.Secrets.URL, k) { + return v.newFn(base, c.Secrets.URL, c.Secrets.masterKey) + } + } + logger.Warn(logSender, "", "no secret provider registered for URL %v, fallback to local provider", c.Secrets.URL) + return NewLocalSecret(base, c.Secrets.URL, c.Secrets.masterKey) +} + +// Secret defines the struct used to store confidential data +type Secret struct { + sync.RWMutex + provider SecretProvider +} + +// MarshalJSON return the JSON encoding of the Secret object +func (s *Secret) MarshalJSON() ([]byte, error) { + s.RLock() + defer s.RUnlock() + + return json.Marshal(&BaseSecret{ + Status: s.provider.GetStatus(), + Payload: s.provider.GetPayload(), + Key: s.provider.GetKey(), + AdditionalData: s.provider.GetAdditionalData(), + Mode: s.provider.GetMode(), + }) +} + +// UnmarshalJSON parses the JSON-encoded data and stores the result +// in the Secret object +func (s *Secret) UnmarshalJSON(data []byte) error { + s.Lock() + defer s.Unlock() + + baseSecret := BaseSecret{} + err := json.Unmarshal(data, &baseSecret) + if err != nil { + return err + } + if baseSecret.isEmpty() { + s.provider = config.getSecretProvider(baseSecret) + return nil + } + + if baseSecret.Status == sdkkms.SecretStatusPlain 
|| baseSecret.Status == sdkkms.SecretStatusRedacted { + s.provider = config.getSecretProvider(baseSecret) + return nil + } + + for _, v := range secretProviders { + if v.encryptedStatus == baseSecret.Status { + s.provider = v.newFn(baseSecret, config.Secrets.URL, config.Secrets.masterKey) + return nil + } + } + logger.Error(logSender, "", "no provider registered for status %q", baseSecret.Status) + return ErrInvalidSecret +} + +// IsEqual returns true if all the secrets fields are equal +func (s *Secret) IsEqual(other *Secret) bool { + if s.GetStatus() != other.GetStatus() { + return false + } + if s.GetPayload() != other.GetPayload() { + return false + } + if s.GetKey() != other.GetKey() { + return false + } + if s.GetAdditionalData() != other.GetAdditionalData() { + return false + } + if s.GetMode() != other.GetMode() { + return false + } + return true +} + +// Clone returns a copy of the secret object +func (s *Secret) Clone() *Secret { + s.RLock() + defer s.RUnlock() + + return &Secret{ + provider: s.provider.Clone(), + } +} + +// IsEncrypted returns true if the secret is encrypted +// This isn't a pointer receiver because we don't want to pass +// a pointer to html template +func (s *Secret) IsEncrypted() bool { + s.RLock() + defer s.RUnlock() + + return s.provider.IsEncrypted() +} + +// IsPlain returns true if the secret is in plain text +func (s *Secret) IsPlain() bool { + s.RLock() + defer s.RUnlock() + + return s.provider.GetStatus() == sdkkms.SecretStatusPlain +} + +// IsNotPlainAndNotEmpty returns true if the secret is not plain and not empty. 
+// This is an utility method, we update the secret for an existing user +// if it is empty or plain +func (s *Secret) IsNotPlainAndNotEmpty() bool { + s.RLock() + defer s.RUnlock() + + return !s.IsPlain() && !s.IsEmpty() +} + +// IsRedacted returns true if the secret is redacted +func (s *Secret) IsRedacted() bool { + s.RLock() + defer s.RUnlock() + + return s.provider.GetStatus() == sdkkms.SecretStatusRedacted +} + +// GetPayload returns the secret payload +func (s *Secret) GetPayload() string { + s.RLock() + defer s.RUnlock() + + return s.provider.GetPayload() +} + +// GetAdditionalData returns the secret additional data +func (s *Secret) GetAdditionalData() string { + s.RLock() + defer s.RUnlock() + + return s.provider.GetAdditionalData() +} + +// GetStatus returns the secret status +func (s *Secret) GetStatus() sdkkms.SecretStatus { + s.RLock() + defer s.RUnlock() + + return s.provider.GetStatus() +} + +// GetKey returns the secret key +func (s *Secret) GetKey() string { + s.RLock() + defer s.RUnlock() + + return s.provider.GetKey() +} + +// GetMode returns the secret mode +func (s *Secret) GetMode() int { + s.RLock() + defer s.RUnlock() + + return s.provider.GetMode() +} + +// SetAdditionalData sets the given additional data +func (s *Secret) SetAdditionalData(value string) { + s.Lock() + defer s.Unlock() + + s.provider.SetAdditionalData(value) +} + +// SetStatus sets the status for this secret +func (s *Secret) SetStatus(value sdkkms.SecretStatus) { + s.Lock() + defer s.Unlock() + + s.provider.SetStatus(value) +} + +// SetKey sets the key for this secret +func (s *Secret) SetKey(value string) { + s.Lock() + defer s.Unlock() + + s.provider.SetKey(value) +} + +// IsEmpty returns true if all fields are empty +func (s *Secret) IsEmpty() bool { + s.RLock() + defer s.RUnlock() + + if s.provider.GetStatus() != "" { + return false + } + if s.provider.GetPayload() != "" { + return false + } + if s.provider.GetKey() != "" { + return false + } + if 
s.provider.GetAdditionalData() != "" { + return false + } + return true +} + +// IsValid returns true if the secret is not empty and valid +func (s *Secret) IsValid() bool { + s.RLock() + defer s.RUnlock() + + if !s.IsValidInput() { + return false + } + switch s.provider.GetStatus() { + case sdkkms.SecretStatusAES256GCM, sdkkms.SecretStatusSecretBox: + if len(s.provider.GetKey()) != 64 { + return false + } + case sdkkms.SecretStatusAWS, sdkkms.SecretStatusGCP, sdkkms.SecretStatusVaultTransit: + key := s.provider.GetKey() + if key != "" && len(key) != 64 { + return false + } + } + return true +} + +// IsValidInput returns true if the secret is a valid user input +func (s *Secret) IsValidInput() bool { + s.RLock() + defer s.RUnlock() + + if !isSecretStatusValid(s.provider.GetStatus()) { + return false + } + if s.provider.GetPayload() == "" { + return false + } + return true +} + +// Hide hides info to decrypt data +func (s *Secret) Hide() { + s.Lock() + defer s.Unlock() + + s.provider.SetKey("") + s.provider.SetAdditionalData("") +} + +// Encrypt encrypts a plain text Secret object +func (s *Secret) Encrypt() error { + s.Lock() + defer s.Unlock() + + return s.provider.Encrypt() +} + +// Decrypt decrypts a Secret object +func (s *Secret) Decrypt() error { + s.Lock() + defer s.Unlock() + + return s.provider.Decrypt() +} + +// TryDecrypt decrypts a Secret object if encrypted. 
+// It returns a nil error if the object is not encrypted +func (s *Secret) TryDecrypt() error { + s.Lock() + defer s.Unlock() + + if s.provider.IsEncrypted() { + return s.provider.Decrypt() + } + return nil +} + +func isSecretStatusValid(status string) bool { + for idx := range validSecretStatuses { + if validSecretStatuses[idx] == status { + return true + } + } + return false +} diff --git a/internal/kms/local.go b/internal/kms/local.go new file mode 100644 index 00000000..fe92511c --- /dev/null +++ b/internal/kms/local.go @@ -0,0 +1,158 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package kms + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "io" + + sdkkms "github.com/sftpgo/sdk/kms" + "gocloud.dev/secrets/localsecrets" + "golang.org/x/crypto/hkdf" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +func init() { + RegisterSecretProvider(sdkkms.SchemeLocal, sdkkms.SecretStatusSecretBox, NewLocalSecret) +} + +type localSecret struct { + BaseSecret + masterKey string +} + +// NewLocalSecret returns a SecretProvider that use a locally provided symmetric key +func NewLocalSecret(base BaseSecret, _, masterKey string) SecretProvider { + return &localSecret{ + BaseSecret: base, + masterKey: masterKey, + } +} + +func (s *localSecret) Name() string { + return "Local" +} + +func (s *localSecret) IsEncrypted() bool { + return s.Status == sdkkms.SecretStatusSecretBox +} + +func (s *localSecret) Encrypt() error { + if s.Status != sdkkms.SecretStatusPlain { + return ErrWrongSecretStatus + } + if s.Payload == "" { + return ErrInvalidSecret + } + secretKey, err := localsecrets.NewRandomKey() + if err != nil { + return err + } + key, err := s.deriveKey(secretKey[:], false) + if err != nil { + return err + } + keeper := localsecrets.NewKeeper(key) + defer keeper.Close() + + ciphertext, err := keeper.Encrypt(context.Background(), []byte(s.Payload)) + if err != nil { + return err + } + s.Key = hex.EncodeToString(secretKey[:]) + s.Payload = base64.StdEncoding.EncodeToString(ciphertext) + s.Status = sdkkms.SecretStatusSecretBox + s.Mode = s.getEncryptionMode() + return nil +} + +func (s *localSecret) Decrypt() error { + if !s.IsEncrypted() { + return ErrWrongSecretStatus + } + encrypted, err := base64.StdEncoding.DecodeString(s.Payload) + if err != nil { + return err + } + secretKey, err := hex.DecodeString(s.Key) + if err != nil { + return err + } + key, err := s.deriveKey(secretKey[:], true) + if err != nil { + return err + } + keeper := localsecrets.NewKeeper(key) + defer keeper.Close() + + plaintext, err := 
keeper.Decrypt(context.Background(), encrypted) + if err != nil { + return err + } + s.Status = sdkkms.SecretStatusPlain + s.Payload = util.BytesToString(plaintext) + s.Key = "" + s.AdditionalData = "" + s.Mode = 0 + return nil +} + +func (s *localSecret) deriveKey(key []byte, isForDecryption bool) ([32]byte, error) { + var masterKey []byte + if s.masterKey == "" || (isForDecryption && s.Mode == 0) { + var combined []byte + combined = append(combined, key...) + if s.AdditionalData != "" { + combined = append(combined, []byte(s.AdditionalData)...) + } + combined = append(combined, key...) + hash := sha256.Sum256(combined) + masterKey = hash[:] + } else { + masterKey = []byte(s.masterKey) + } + var derivedKey [32]byte + var info []byte + if s.AdditionalData != "" { + info = []byte(s.AdditionalData) + } + kdf := hkdf.New(sha256.New, masterKey, key, info) + if _, err := io.ReadFull(kdf, derivedKey[:]); err != nil { + return derivedKey, err + } + return derivedKey, nil +} + +func (s *localSecret) getEncryptionMode() int { + if s.masterKey == "" { + return 0 + } + return 1 +} + +func (s *localSecret) Clone() SecretProvider { + baseSecret := BaseSecret{ + Status: s.Status, + Payload: s.Payload, + Key: s.Key, + AdditionalData: s.AdditionalData, + Mode: s.Mode, + } + return NewLocalSecret(baseSecret, "", s.masterKey) +} diff --git a/internal/logger/hclog.go b/internal/logger/hclog.go new file mode 100644 index 00000000..b05afe15 --- /dev/null +++ b/internal/logger/hclog.go @@ -0,0 +1,96 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package logger + +import ( + "io" + "log" + + "github.com/hashicorp/go-hclog" + "github.com/rs/zerolog" +) + +// HCLogAdapter is an adapter for hclog.Logger +type HCLogAdapter struct { + hclog.Logger +} + +// Log emits a message and key/value pairs at a provided log level +func (l *HCLogAdapter) Log(level hclog.Level, msg string, args ...any) { + // Workaround to avoid logging plugin arguments that may contain sensitive data. + // Check everytime we update go-plugin library. + if msg == "starting plugin" { + return + } + var ev *zerolog.Event + switch level { + case hclog.Info: + ev = logger.Info() + case hclog.Warn: + ev = logger.Warn() + case hclog.Error: + ev = logger.Error() + default: + ev = logger.Debug() + } + ev.Timestamp().Str("sender", l.Name()) + addKeysAndValues(ev, args...) + ev.Msg(msg) +} + +// Trace emits a message and key/value pairs at the TRACE level +func (l *HCLogAdapter) Trace(msg string, args ...any) { + l.Log(hclog.Debug, msg, args...) +} + +// Debug emits a message and key/value pairs at the DEBUG level +func (l *HCLogAdapter) Debug(msg string, args ...any) { + l.Log(hclog.Debug, msg, args...) +} + +// Info emits a message and key/value pairs at the INFO level +func (l *HCLogAdapter) Info(msg string, args ...any) { + l.Log(hclog.Info, msg, args...) +} + +// Warn emits a message and key/value pairs at the WARN level +func (l *HCLogAdapter) Warn(msg string, args ...any) { + l.Log(hclog.Warn, msg, args...) +} + +// Error emits a message and key/value pairs at the ERROR level +func (l *HCLogAdapter) Error(msg string, args ...any) { + l.Log(hclog.Error, msg, args...) 
+} + +// With creates a sub-logger +func (l *HCLogAdapter) With(args ...any) hclog.Logger { + return &HCLogAdapter{Logger: l.Logger.With(args...)} +} + +// Named creates a logger that will prepend the name string on the front of all messages +func (l *HCLogAdapter) Named(name string) hclog.Logger { + return &HCLogAdapter{Logger: l.Logger.Named(name)} +} + +// StandardLogger returns a value that conforms to the stdlib log.Logger interface +func (l *HCLogAdapter) StandardLogger(_ *hclog.StandardLoggerOptions) *log.Logger { + return log.New(&StdLoggerWrapper{Sender: l.Name()}, "", 0) +} + +// StandardWriter returns a value that conforms to io.Writer, which can be passed into log.SetOutput() +func (l *HCLogAdapter) StandardWriter(_ *hclog.StandardLoggerOptions) io.Writer { + return &StdLoggerWrapper{Sender: l.Name()} +} diff --git a/internal/logger/lego.go b/internal/logger/lego.go new file mode 100644 index 00000000..0503e536 --- /dev/null +++ b/internal/logger/lego.go @@ -0,0 +1,72 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package logger + +import "fmt" + +const ( + legoLogSender = "lego" +) + +// LegoAdapter is an adapter for lego.StdLogger +type LegoAdapter struct { + LogToConsole bool +} + +// Fatal emits a log at Error level +func (l *LegoAdapter) Fatal(args ...any) { + if l.LogToConsole { + ErrorToConsole("%s", fmt.Sprint(args...)) + return + } + Log(LevelError, legoLogSender, "", "%s", fmt.Sprint(args...)) +} + +// Fatalln is the same as Fatal +func (l *LegoAdapter) Fatalln(args ...any) { + l.Fatal(args...) +} + +// Fatalf emits a log at Error level +func (l *LegoAdapter) Fatalf(format string, args ...any) { + if l.LogToConsole { + ErrorToConsole(format, args...) + return + } + Log(LevelError, legoLogSender, "", format, args...) +} + +// Print emits a log at Info level +func (l *LegoAdapter) Print(args ...any) { + if l.LogToConsole { + InfoToConsole("%s", fmt.Sprint(args...)) + return + } + Log(LevelInfo, legoLogSender, "", "%s", fmt.Sprint(args...)) +} + +// Println is the same as Print +func (l *LegoAdapter) Println(args ...any) { + l.Print(args...) +} + +// Printf emits a log at Info level +func (l *LegoAdapter) Printf(format string, args ...any) { + if l.LogToConsole { + InfoToConsole(format, args...) + return + } + Log(LevelInfo, legoLogSender, "", format, args...) +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 00000000..59ff863f --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,382 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. 
+// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package logger provides logging capabilities. +// It is a wrapper around zerolog for logging and lumberjack for log rotation. +// Logs are written to the specified log file. +// Logging on the console is provided to print initialization info, errors and warnings. +// The package provides a request logger to log the HTTP requests for REST API too. +// The request logger uses chi.middleware.RequestLogger, +// chi.middleware.LogFormatter and chi.middleware.LogEntry to build a structured +// logger using zerolog +package logger + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "time" + + "github.com/rs/zerolog" + lumberjack "gopkg.in/natefinch/lumberjack.v2" +) + +const ( + dateFormat = "2006-01-02T15:04:05.000" // YYYY-MM-DDTHH:MM:SS.ZZZ +) + +// LogLevel defines log levels. +type LogLevel uint8 + +// defines our own log levels, just in case we'll change logger in future +const ( + LevelDebug LogLevel = iota + LevelInfo + LevelWarn + LevelError +) + +var ( + logger zerolog.Logger + consoleLogger zerolog.Logger + rollingLogger *lumberjack.Logger +) + +func init() { + zerolog.TimeFieldFormat = dateFormat +} + +// GetLogger get the configured logger instance +func GetLogger() *zerolog.Logger { + return &logger +} + +// InitLogger configures the logger using the given parameters +func InitLogger(logFilePath string, logMaxSize int, logMaxBackups int, logMaxAge int, logCompress, logUTCTime bool, + level zerolog.Level, +) { + SetLogTime(logUTCTime) + if isLogFilePathValid(logFilePath) { + logDir := filepath.Dir(logFilePath) + if _, err := os.Stat(logDir); errors.Is(err, fs.ErrNotExist) { + err = os.MkdirAll(logDir, os.ModePerm) + if err != nil { + fmt.Printf("unable to create log dir %q: %v", logDir, err) + } + } + rollingLogger = &lumberjack.Logger{ + Filename: logFilePath, + MaxSize: logMaxSize, + MaxBackups: logMaxBackups, + 
MaxAge: logMaxAge, + Compress: logCompress, + LocalTime: !logUTCTime, + } + logger = zerolog.New(rollingLogger) + EnableConsoleLogger(level) + } else { + logger = zerolog.New(&logSyncWrapper{ + output: os.Stdout, + }) + consoleLogger = zerolog.Nop() + } + logger = logger.Level(level) +} + +// InitStdErrLogger configures the logger to write to stderr +func InitStdErrLogger(level zerolog.Level) { + logger = zerolog.New(&logSyncWrapper{ + output: os.Stderr, + }).Level(level) + consoleLogger = zerolog.Nop() +} + +// DisableLogger disable the main logger. +// ConsoleLogger will not be affected +func DisableLogger() { + logger = zerolog.Nop() + rollingLogger = nil +} + +// EnableConsoleLogger enables the console logger +func EnableConsoleLogger(level zerolog.Level) { + consoleOutput := zerolog.ConsoleWriter{ + Out: os.Stdout, + TimeFormat: dateFormat, + } + consoleLogger = zerolog.New(consoleOutput).With().Timestamp().Logger().Level(level) +} + +// RotateLogFile closes the existing log file and immediately create a new one +func RotateLogFile() error { + if rollingLogger != nil { + return rollingLogger.Rotate() + } + return errors.New("logging to file is disabled") +} + +// SetLogTime sets logging time related setting +func SetLogTime(utc bool) { + if utc { + zerolog.TimestampFunc = func() time.Time { + return time.Now().UTC() + } + } else { + zerolog.TimestampFunc = time.Now + } +} + +// Log logs at the specified level for the specified sender +func Log(level LogLevel, sender string, connectionID string, format string, v ...any) { + var ev *zerolog.Event + switch level { + case LevelDebug: + ev = logger.Debug() + case LevelInfo: + ev = logger.Info() + case LevelWarn: + ev = logger.Warn() + default: + ev = logger.Error() + } + ev.Timestamp().Str("sender", sender) + if connectionID != "" { + ev.Str("connection_id", connectionID) + } + ev.Msg(fmt.Sprintf(format, v...)) +} + +// Debug logs at debug level for the specified sender +func Debug(sender, connectionID, format 
string, v ...any) { + Log(LevelDebug, sender, connectionID, format, v...) +} + +// Info logs at info level for the specified sender +func Info(sender, connectionID, format string, v ...any) { + Log(LevelInfo, sender, connectionID, format, v...) +} + +// Warn logs at warn level for the specified sender +func Warn(sender, connectionID, format string, v ...any) { + Log(LevelWarn, sender, connectionID, format, v...) +} + +// Error logs at error level for the specified sender +func Error(sender, connectionID, format string, v ...any) { + Log(LevelError, sender, connectionID, format, v...) +} + +// DebugToConsole logs at debug level to stdout +func DebugToConsole(format string, v ...any) { + consoleLogger.Debug().Msg(fmt.Sprintf(format, v...)) +} + +// InfoToConsole logs at info level to stdout +func InfoToConsole(format string, v ...any) { + consoleLogger.Info().Msg(fmt.Sprintf(format, v...)) +} + +// WarnToConsole logs at info level to stdout +func WarnToConsole(format string, v ...any) { + consoleLogger.Warn().Msg(fmt.Sprintf(format, v...)) +} + +// ErrorToConsole logs at error level to stdout +func ErrorToConsole(format string, v ...any) { + consoleLogger.Error().Msg(fmt.Sprintf(format, v...)) +} + +// TransferLog logs uploads or downloads +func TransferLog(operation, path string, elapsed int64, size int64, user, connectionID, protocol, localAddr, + remoteAddr, ftpMode string, err error, +) { + var ev *zerolog.Event + if err != nil { + ev = logger.Error() + } else { + ev = logger.Info() + } + ev. + Timestamp(). + Str("sender", operation). + Str("local_addr", localAddr). + Str("remote_addr", remoteAddr). + Int64("elapsed_ms", elapsed). + Int64("size_bytes", size). + Str("username", user). + Str("file_path", path). + Str("connection_id", connectionID). 
+ Str("protocol", protocol) + if ftpMode != "" { + ev.Str("ftp_mode", ftpMode) + } + ev.AnErr("error", err).Send() +} + +// CommandLog logs an SFTP/SCP/SSH command +func CommandLog(command, path, target, user, fileMode, connectionID, protocol string, uid, gid int, atime, mtime, + sshCommand string, size int64, localAddr, remoteAddr string, elapsed int64) { + logger.Info(). + Timestamp(). + Str("sender", command). + Str("local_addr", localAddr). + Str("remote_addr", remoteAddr). + Str("username", user). + Str("file_path", path). + Str("target_path", target). + Str("filemode", fileMode). + Int("uid", uid). + Int("gid", gid). + Str("access_time", atime). + Str("modification_time", mtime). + Int64("size", size). + Int64("elapsed", elapsed). + Str("ssh_command", sshCommand). + Str("connection_id", connectionID). + Str("protocol", protocol). + Send() +} + +// ConnectionFailedLog logs failed attempts to initialize a connection. +// A connection can fail for an authentication error or other errors such as +// a client abort or a time out if the login does not happen in two minutes. +// These logs are useful for better integration with Fail2ban and similar tools. +func ConnectionFailedLog(user, ip, loginType, protocol, errorString string) { + logger.Debug(). + Timestamp(). + Str("sender", "connection_failed"). + Str("client_ip", ip). + Str("username", user). + Str("login_type", loginType). + Str("protocol", protocol). + Str("error", errorString). + Send() +} + +// LoginLog logs successful logins. +func LoginLog(user, ip, loginMethod, protocol, connectionID, clientVersion string, encrypted bool, info string) { + ev := logger.Info() + ev.Timestamp(). + Str("sender", "login"). + Str("ip", ip). + Str("username", user). + Str("method", loginMethod). + Str("protocol", protocol) + if connectionID != "" { + ev.Str("connection_id", connectionID) + } + ev.Str("client", clientVersion). 
+ Bool("encrypted", encrypted) + if info != "" { + ev.Str("info", info) + } + ev.Send() +} + +func isLogFilePathValid(logFilePath string) bool { + cleanInput := filepath.Clean(logFilePath) + if cleanInput == "." || cleanInput == ".." { + return false + } + return true +} + +// StdLoggerWrapper is a wrapper for standard logger compatibility +type StdLoggerWrapper struct { + Sender string +} + +// Write implements the io.Writer interface. This is useful to set as a writer +// for the standard library log. +func (l *StdLoggerWrapper) Write(p []byte) (n int, err error) { + n = len(p) + if n > 0 && p[n-1] == '\n' { + // Trim CR added by stdlog. + p = p[0 : n-1] + } + + Log(LevelError, l.Sender, "", "%s", p) + return +} + +// LeveledLogger is a logger that accepts a message string and a variadic number of key-value pairs +type LeveledLogger struct { + Sender string + additionalKeyVals []any +} + +func addKeysAndValues(ev *zerolog.Event, keysAndValues ...any) { + kvLen := len(keysAndValues) + if kvLen%2 != 0 { + extra := keysAndValues[kvLen-1] + keysAndValues = append(keysAndValues[:kvLen-1], "EXTRA_VALUE_AT_END", extra) + } + for i := 0; i < len(keysAndValues); i += 2 { + key, val := keysAndValues[i], keysAndValues[i+1] + if keyStr, ok := key.(string); ok && keyStr != "timestamp" { + ev.Str(keyStr, fmt.Sprintf("%v", val)) + } + } +} + +// Error logs at error level for the specified sender +func (l *LeveledLogger) Error(msg string, keysAndValues ...any) { + ev := logger.Error() + ev.Timestamp().Str("sender", l.Sender) + if len(l.additionalKeyVals) > 0 { + addKeysAndValues(ev, l.additionalKeyVals...) + } + addKeysAndValues(ev, keysAndValues...) + ev.Msg(msg) +} + +// Info logs at info level for the specified sender +func (l *LeveledLogger) Info(msg string, keysAndValues ...any) { + ev := logger.Info() + ev.Timestamp().Str("sender", l.Sender) + if len(l.additionalKeyVals) > 0 { + addKeysAndValues(ev, l.additionalKeyVals...) + } + addKeysAndValues(ev, keysAndValues...) 
+ ev.Msg(msg) +} + +// Debug logs at debug level for the specified sender +func (l *LeveledLogger) Debug(msg string, keysAndValues ...any) { + ev := logger.Debug() + ev.Timestamp().Str("sender", l.Sender) + if len(l.additionalKeyVals) > 0 { + addKeysAndValues(ev, l.additionalKeyVals...) + } + addKeysAndValues(ev, keysAndValues...) + ev.Msg(msg) +} + +// Warn logs at warn level for the specified sender +func (l *LeveledLogger) Warn(msg string, keysAndValues ...any) { + ev := logger.Warn() + ev.Timestamp().Str("sender", l.Sender) + if len(l.additionalKeyVals) > 0 { + addKeysAndValues(ev, l.additionalKeyVals...) + } + addKeysAndValues(ev, keysAndValues...) + ev.Msg(msg) +} + +// Panic logs the panic at error level for the specified sender +func (l *LeveledLogger) Panic(msg string, keysAndValues ...any) { + l.Error(msg, keysAndValues...) +} diff --git a/internal/logger/mail.go b/internal/logger/mail.go new file mode 100644 index 00000000..5f6ae713 --- /dev/null +++ b/internal/logger/mail.go @@ -0,0 +1,66 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package logger + +import ( + "fmt" + + "github.com/wneessen/go-mail/log" +) + +const ( + mailLogSender = "smtpclient" +) + +// MailAdapter is an adapter for mail.Logger +type MailAdapter struct { + ConnectionID string +} + +// Errorf emits a log at Error level +func (l *MailAdapter) Errorf(logMsg log.Log) { + format := l.getFormatString(&logMsg) + ErrorToConsole(format, logMsg.Messages...) + Log(LevelError, mailLogSender, l.ConnectionID, format, logMsg.Messages...) +} + +// Warnf emits a log at Warn level +func (l *MailAdapter) Warnf(logMsg log.Log) { + format := l.getFormatString(&logMsg) + WarnToConsole(format, logMsg.Messages...) + Log(LevelWarn, mailLogSender, l.ConnectionID, format, logMsg.Messages...) +} + +// Infof emits a log at Info level +func (l *MailAdapter) Infof(logMsg log.Log) { + format := l.getFormatString(&logMsg) + InfoToConsole(format, logMsg.Messages...) + Log(LevelInfo, mailLogSender, l.ConnectionID, format, logMsg.Messages...) +} + +// Debugf emits a log at Debug level +func (l *MailAdapter) Debugf(logMsg log.Log) { + format := l.getFormatString(&logMsg) + DebugToConsole(format, logMsg.Messages...) + Log(LevelDebug, mailLogSender, l.ConnectionID, format, logMsg.Messages...) +} + +func (*MailAdapter) getFormatString(logMsg *log.Log) string { + p := "C <-- S:" + if logMsg.Direction == log.DirClientToServer { + p = "C --> S:" + } + return fmt.Sprintf("%s %s", p, logMsg.Format) +} diff --git a/internal/logger/request_logger.go b/internal/logger/request_logger.go new file mode 100644 index 00000000..325f44b9 --- /dev/null +++ b/internal/logger/request_logger.go @@ -0,0 +1,119 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package logger + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "time" + + "github.com/go-chi/chi/v5/middleware" + "github.com/rs/zerolog" + + "github.com/drakkan/sftpgo/v2/internal/metric" +) + +// StructuredLogger defines a simple wrapper around zerolog logger. +// It implements chi.middleware.LogFormatter interface +type StructuredLogger struct { + Logger *zerolog.Logger +} + +// StructuredLoggerEntry defines a log entry. +// It implements chi.middleware.LogEntry interface +type StructuredLoggerEntry struct { + // The zerolog logger + Logger *zerolog.Logger + // fields to write in the log + fields map[string]any +} + +// NewStructuredLogger returns a chi.middleware.RequestLogger using our StructuredLogger. 
+// This structured logger is called by the chi.middleware.Logger handler to log each HTTP request +func NewStructuredLogger(logger *zerolog.Logger) func(next http.Handler) http.Handler { + return middleware.RequestLogger(&StructuredLogger{logger}) +} + +// NewLogEntry creates a new log entry for an HTTP request +func (l *StructuredLogger) NewLogEntry(r *http.Request) middleware.LogEntry { + scheme := "http" + cipherSuite := "" + if r.TLS != nil { + scheme = "https" + cipherSuite = tls.CipherSuiteName(r.TLS.CipherSuite) + } + + fields := map[string]any{ + "local_addr": getLocalAddress(r), + "remote_addr": r.RemoteAddr, + "proto": r.Proto, + "method": r.Method, + "user_agent": r.UserAgent(), + "uri": fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI), + "cipher_suite": cipherSuite, + } + + reqID := middleware.GetReqID(r.Context()) + if reqID != "" { + fields["request_id"] = reqID + } + + return &StructuredLoggerEntry{Logger: l.Logger, fields: fields} +} + +// Write logs a new entry at the end of the HTTP request +func (l *StructuredLoggerEntry) Write(status, bytes int, _ http.Header, elapsed time.Duration, _ any) { + metric.HTTPRequestServed(status) + var ev *zerolog.Event + if status >= http.StatusInternalServerError { + ev = l.Logger.Error() + } else if status >= http.StatusBadRequest { + ev = l.Logger.Warn() + } else { + ev = l.Logger.Debug() + } + ev. + Timestamp(). + Str("sender", "httpd"). + Fields(l.fields). + Int("resp_status", status). + Int("resp_size", bytes). + Int64("elapsed_ms", elapsed.Nanoseconds()/1000000). + Send() +} + +// Panic logs panics +func (l *StructuredLoggerEntry) Panic(v any, stack []byte) { + l.Logger.Error(). + Timestamp(). + Str("sender", "httpd"). + Fields(l.fields). + Str("stack", string(stack)). + Str("panic", fmt.Sprintf("%+v", v)). 
+ Send() +} + +func getLocalAddress(r *http.Request) string { + if r == nil { + return "" + } + localAddr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr) + if ok { + return localAddr.String() + } + return "" +} diff --git a/internal/logger/slog.go b/internal/logger/slog.go new file mode 100644 index 00000000..e9a544d0 --- /dev/null +++ b/internal/logger/slog.go @@ -0,0 +1,97 @@ +// Copyright (C) 2025 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package logger + +import ( + "context" + "log/slog" + "slices" + + "github.com/rs/zerolog" +) + +// slogAdapter is an adapter for slog.Handler +type slogAdapter struct { + sender string + attrs []slog.Attr +} + +// NewSlogAdapter creates a slog.Handler adapter +func NewSlogAdapter(sender string, attrs []slog.Attr) *slogAdapter { + return &slogAdapter{ + sender: sender, + attrs: attrs, + } +} + +func (l *slogAdapter) Enabled(ctx context.Context, level slog.Level) bool { + // Log level is handled by our implementation + return true +} + +func (l *slogAdapter) Handle(ctx context.Context, r slog.Record) error { + var ev *zerolog.Event + switch r.Level { + case slog.LevelDebug: + ev = logger.Debug() + case slog.LevelInfo: + ev = logger.Info() + case slog.LevelWarn: + ev = logger.Warn() + case slog.LevelError: + ev = logger.Error() + default: + ev = logger.Debug() + } + + ev.Timestamp() + if l.sender != "" { + ev.Str("sender", l.sender) + } + + addSlogAttr := func(a slog.Attr) { + if a.Key == "time" { + return + } + ev.Any(a.Key, a.Value.Any()) + } + + for _, a := range l.attrs { + addSlogAttr(a) + } + + r.Attrs(func(a slog.Attr) bool { + addSlogAttr(a) + return true + }) + + ev.Msg(r.Message) + + return nil +} + +func (l *slogAdapter) WithAttrs(attrs []slog.Attr) slog.Handler { + newHandler := *l + newHandler.attrs = slices.Concat(l.attrs, attrs) + return &newHandler +} + +func (l *slogAdapter) WithGroup(name string) slog.Handler { + newHandler := *l + if name != "" { + newHandler.sender = name + } + return &newHandler +} diff --git a/internal/logger/sync_wrapper.go b/internal/logger/sync_wrapper.go new file mode 100644 index 00000000..b87ac11c --- /dev/null +++ b/internal/logger/sync_wrapper.go @@ -0,0 +1,31 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
package logger

import (
	"os"
	"sync"
)

// logSyncWrapper serializes writes to an *os.File so that concurrent
// loggers cannot interleave partial lines in the output.
type logSyncWrapper struct {
	sync.Mutex
	// output is the destination file; writes go through Write below.
	output *os.File
}

// Write implements io.Writer, forwarding to the wrapped file while
// holding the mutex so each Write call is atomic with respect to others.
func (l *logSyncWrapper) Write(b []byte) (n int, err error) {
	l.Lock()
	defer l.Unlock()
	return l.output.Write(b)
}
// Login method labels used by the AddLoginAttempt/AddLoginResult helpers
// to select the per-method counters.
const (
	loginMethodPublicKey            = "publickey"
	loginMethodKeyboardInteractive  = "keyboard-interactive"
	loginMethodKeyAndPassword       = "publickey+password"
	loginMethodKeyAndKeyboardInt    = "publickey+keyboard-interactive"
	loginMethodTLSCertificate       = "TLSCertificate"
	loginMethodTLSCertificateAndPwd = "TLSCertificate+password"
	loginMethodIDP                  = "IDP"
)

// init advertises the metrics feature in the version info; the file is
// compiled only without the "nometrics" build tag.
func init() {
	version.AddFeature("+metrics")
}
Help: "The total number of upload errors", + }) + + // totalDownloadErrors is the metric that reports the total number of download errors + totalDownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_download_errors_total", + Help: "The total number of download errors", + }) + + // totalUploadSize is the metric that reports the total uploads size as bytes + totalUploadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_upload_size", + Help: "The total upload size as bytes, partial uploads are included", + }) + + // totalDownloadSize is the metric that reports the total downloads size as bytes + totalDownloadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_download_size", + Help: "The total download size as bytes, partial downloads are included", + }) + + // totalSSHCommands is the metric that reports the total number of executed SSH commands + totalSSHCommands = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_ssh_commands_total", + Help: "The total number of executed SSH commands", + }) + + // totalSSHCommandErrors is the metric that reports the total number of SSH command errors + totalSSHCommandErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_ssh_command_errors_total", + Help: "The total number of SSH command errors", + }) + + // totalLoginAttempts is the metric that reports the total number of login attempts + totalLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_login_attempts_total", + Help: "The total number of login attempts", + }) + + // totalNoAuthTried is te metric that reports the total number of clients disconnected + // for inactivity before trying to login + totalNoAuthTried = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_no_auth_total", + Help: "The total number of clients disconnected for inactivity before trying to login", + }) + + // totalLoginOK is the metric that reports the total number of successful logins + 
totalLoginOK = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_login_ok_total", + Help: "The total number of successful logins", + }) + + // totalLoginFailed is the metric that reports the total number of failed logins + totalLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_login_ko_total", + Help: "The total number of failed logins", + }) + + // totalPasswordLoginAttempts is the metric that reports the total number of login attempts + // using a password + totalPasswordLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_password_login_attempts_total", + Help: "The total number of login attempts using a password", + }) + + // totalPasswordLoginOK is the metric that reports the total number of successful logins + // using a password + totalPasswordLoginOK = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_password_login_ok_total", + Help: "The total number of successful logins using a password", + }) + + // totalPasswordLoginFailed is the metric that reports the total number of failed logins + // using a password + totalPasswordLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_password_login_ko_total", + Help: "The total number of failed logins using a password", + }) + + // totalKeyLoginAttempts is the metric that reports the total number of login attempts + // using a public key + totalKeyLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_public_key_login_attempts_total", + Help: "The total number of login attempts using a public key", + }) + + // totalKeyLoginOK is the metric that reports the total number of successful logins + // using a public key + totalKeyLoginOK = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_public_key_login_ok_total", + Help: "The total number of successful logins using a public key", + }) + + // totalKeyLoginFailed is the metric that reports the total number of failed logins + // using a public key + 
totalKeyLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_public_key_login_ko_total", + Help: "The total number of failed logins using a public key", + }) + + // totalTLSCertLoginAttempts is the metric that reports the total number of login attempts + // using a TLS certificate + totalTLSCertLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_tls_cert_login_attempts_total", + Help: "The total number of login attempts using a TLS certificate", + }) + + // totalTLSCertLoginOK is the metric that reports the total number of successful logins + // using a TLS certificate + totalTLSCertLoginOK = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_tls_cert_login_ok_total", + Help: "The total number of successful logins using a TLS certificate", + }) + + // totalTLSCertLoginFailed is the metric that reports the total number of failed logins + // using a TLS certificate + totalTLSCertLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_tls_cert_login_ko_total", + Help: "The total number of failed logins using a TLS certificate", + }) + + // totalTLSCertAndPwdLoginAttempts is the metric that reports the total number of login attempts + // using a TLS certificate+password + totalTLSCertAndPwdLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_tls_cert_and_pwd_login_attempts_total", + Help: "The total number of login attempts using a TLS certificate+password", + }) + + // totalTLSCertLoginOK is the metric that reports the total number of successful logins + // using a TLS certificate+password + totalTLSCertAndPwdLoginOK = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_tls_cert_and_pwd_login_ok_total", + Help: "The total number of successful logins using a TLS certificate+password", + }) + + // totalTLSCertAndPwdLoginFailed is the metric that reports the total number of failed logins + // using a TLS certificate+password + totalTLSCertAndPwdLoginFailed = 
promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_tls_cert_and_pwd_login_ko_total", + Help: "The total number of failed logins using a TLS certificate+password", + }) + + // totalInteractiveLoginAttempts is the metric that reports the total number of login attempts + // using keyboard interactive authentication + totalInteractiveLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_keyboard_interactive_login_attempts_total", + Help: "The total number of login attempts using keyboard interactive authentication", + }) + + // totalInteractiveLoginOK is the metric that reports the total number of successful logins + // using keyboard interactive authentication + totalInteractiveLoginOK = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_keyboard_interactive_login_ok_total", + Help: "The total number of successful logins using keyboard interactive authentication", + }) + + // totalInteractiveLoginFailed is the metric that reports the total number of failed logins + // using keyboard interactive authentication + totalInteractiveLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_keyboard_interactive_login_ko_total", + Help: "The total number of failed logins using keyboard interactive authentication", + }) + + // totalKeyAndPasswordLoginAttempts is the metric that reports the total number of + // login attempts using public key + password multi steps auth + totalKeyAndPasswordLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_key_and_password_login_attempts_total", + Help: "The total number of login attempts using public key + password", + }) + + // totalKeyAndPasswordLoginOK is the metric that reports the total number of + // successful logins using public key + password multi steps auth + totalKeyAndPasswordLoginOK = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_key_and_password_login_ok_total", + Help: "The total number of successful logins using public key + 
password", + }) + + // totalKeyAndPasswordLoginFailed is the metric that reports the total number of + // failed logins using public key + password multi steps auth + totalKeyAndPasswordLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_key_and_password_login_ko_total", + Help: "The total number of failed logins using public key + password", + }) + + // totalKeyAndKeyIntLoginAttempts is the metric that reports the total number of + // login attempts using public key + keyboard interactive multi steps auth + totalKeyAndKeyIntLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_key_and_keyboard_int_login_attempts_total", + Help: "The total number of login attempts using public key + keyboard interactive", + }) + + // totalKeyAndKeyIntLoginOK is the metric that reports the total number of + // successful logins using public key + keyboard interactive multi steps auth + totalKeyAndKeyIntLoginOK = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_key_and_keyboard_int_login_ok_total", + Help: "The total number of successful logins using public key + keyboard interactive", + }) + + // totalKeyAndKeyIntLoginFailed is the metric that reports the total number of + // failed logins using public key + keyboard interactive multi steps auth + totalKeyAndKeyIntLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_key_and_keyboard_int_login_ko_total", + Help: "The total number of failed logins using public key + keyboard interactive", + }) + + // totalIDPLoginAttempts is the metric that reports the total number of + // login attempts using identity providers + totalIDPLoginAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_idp_login_attempts_total", + Help: "The total number of login attempts using Identity Providers", + }) + + // totalIDPLoginOK is the metric that reports the total number of + // successful logins using identity providers + totalIDPLoginOK = 
promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_idp_login_ok_total", + Help: "The total number of successful logins using Identity Providers", + }) + + // totalIDPLoginFailed is the metric that reports the total number of + // failed logins using identity providers + totalIDPLoginFailed = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_idp_login_ko_total", + Help: "The total number of failed logins using Identity Providers", + }) + + totalHTTPRequests = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_http_req_total", + Help: "The total number of HTTP requests served", + }) + + totalHTTPOK = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_http_req_ok_total", + Help: "The total number of HTTP requests served with 2xx status code", + }) + + totalHTTPClientErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_http_client_errors_total", + Help: "The total number of HTTP requests served with 4xx status code", + }) + + totalHTTPServerErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_http_server_errors_total", + Help: "The total number of HTTP requests served with 5xx status code", + }) + + // totalS3Uploads is the metric that reports the total number of successful S3 uploads + totalS3Uploads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_uploads_total", + Help: "The total number of successful S3 uploads", + }) + + // totalS3Downloads is the metric that reports the total number of successful S3 downloads + totalS3Downloads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_downloads_total", + Help: "The total number of successful S3 downloads", + }) + + // totalS3UploadErrors is the metric that reports the total number of S3 upload errors + totalS3UploadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_upload_errors_total", + Help: "The total number of S3 upload errors", + }) + + // totalS3DownloadErrors is the metric that reports the 
total number of S3 download errors + totalS3DownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_download_errors_total", + Help: "The total number of S3 download errors", + }) + + // totalS3UploadSize is the metric that reports the total S3 uploads size as bytes + totalS3UploadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_upload_size", + Help: "The total S3 upload size as bytes, partial uploads are included", + }) + + // totalS3DownloadSize is the metric that reports the total S3 downloads size as bytes + totalS3DownloadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_download_size", + Help: "The total S3 download size as bytes, partial downloads are included", + }) + + // totalS3ListObjects is the metric that reports the total successful S3 list objects requests + totalS3ListObjects = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_list_objects", + Help: "The total number of successful S3 list objects requests", + }) + + // totalS3CopyObject is the metric that reports the total successful S3 copy object requests + totalS3CopyObject = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_copy_object", + Help: "The total number of successful S3 copy object requests", + }) + + // totalS3DeleteObject is the metric that reports the total successful S3 delete object requests + totalS3DeleteObject = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_delete_object", + Help: "The total number of successful S3 delete object requests", + }) + + // totalS3ListObjectsError is the metric that reports the total S3 list objects errors + totalS3ListObjectsErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_list_objects_errors", + Help: "The total number of S3 list objects errors", + }) + + // totalS3CopyObjectErrors is the metric that reports the total S3 copy object errors + totalS3CopyObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: 
"sftpgo_s3_copy_object_errors", + Help: "The total number of S3 copy object errors", + }) + + // totalS3DeleteObjectErrors is the metric that reports the total S3 delete object errors + totalS3DeleteObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_delete_object_errors", + Help: "The total number of S3 delete object errors", + }) + + // totalS3HeadObject is the metric that reports the total successful S3 head object requests + totalS3HeadObject = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_head_object", + Help: "The total number of successful S3 head object requests", + }) + + // totalS3HeadObjectErrors is the metric that reports the total S3 head object errors + totalS3HeadObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_s3_head_object_errors", + Help: "The total number of S3 head object errors", + }) + + // totalGCSUploads is the metric that reports the total number of successful GCS uploads + totalGCSUploads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_uploads_total", + Help: "The total number of successful GCS uploads", + }) + + // totalGCSDownloads is the metric that reports the total number of successful GCS downloads + totalGCSDownloads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_downloads_total", + Help: "The total number of successful GCS downloads", + }) + + // totalGCSUploadErrors is the metric that reports the total number of GCS upload errors + totalGCSUploadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_upload_errors_total", + Help: "The total number of GCS upload errors", + }) + + // totalGCSDownloadErrors is the metric that reports the total number of GCS download errors + totalGCSDownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_download_errors_total", + Help: "The total number of GCS download errors", + }) + + // totalGCSUploadSize is the metric that reports the total GCS 
uploads size as bytes + totalGCSUploadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_upload_size", + Help: "The total GCS upload size as bytes, partial uploads are included", + }) + + // totalGCSDownloadSize is the metric that reports the total GCS downloads size as bytes + totalGCSDownloadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_download_size", + Help: "The total GCS download size as bytes, partial downloads are included", + }) + + // totalGCSListObjects is the metric that reports the total successful GCS list objects requests + totalGCSListObjects = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_list_objects", + Help: "The total number of successful GCS list objects requests", + }) + + // totalGCSCopyObject is the metric that reports the total successful GCS copy object requests + totalGCSCopyObject = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_copy_object", + Help: "The total number of successful GCS copy object requests", + }) + + // totalGCSDeleteObject is the metric that reports the total successful GCS delete object requests + totalGCSDeleteObject = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_delete_object", + Help: "The total number of successful GCS delete object requests", + }) + + // totalGCSListObjectsError is the metric that reports the total GCS list objects errors + totalGCSListObjectsErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_list_objects_errors", + Help: "The total number of GCS list objects errors", + }) + + // totalGCSCopyObjectErrors is the metric that reports the total GCS copy object errors + totalGCSCopyObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_copy_object_errors", + Help: "The total number of GCS copy object errors", + }) + + // totalGCSDeleteObjectErrors is the metric that reports the total GCS delete object errors + totalGCSDeleteObjectErrors = 
promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_delete_object_errors", + Help: "The total number of GCS delete object errors", + }) + + // totalGCSHeadObject is the metric that reports the total successful GCS head object requests + totalGCSHeadObject = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_head_object", + Help: "The total number of successful GCS head object requests", + }) + + // totalGCSHeadObjectErrors is the metric that reports the total GCS head object errors + totalGCSHeadObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_gcs_head_object_errors", + Help: "The total number of GCS head object errors", + }) + + // totalAZUploads is the metric that reports the total number of successful Azure uploads + totalAZUploads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_uploads_total", + Help: "The total number of successful Azure uploads", + }) + + // totalAZDownloads is the metric that reports the total number of successful Azure downloads + totalAZDownloads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_downloads_total", + Help: "The total number of successful Azure downloads", + }) + + // totalAZUploadErrors is the metric that reports the total number of Azure upload errors + totalAZUploadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_upload_errors_total", + Help: "The total number of Azure upload errors", + }) + + // totalAZDownloadErrors is the metric that reports the total number of Azure download errors + totalAZDownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_download_errors_total", + Help: "The total number of Azure download errors", + }) + + // totalAZUploadSize is the metric that reports the total Azure uploads size as bytes + totalAZUploadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_upload_size", + Help: "The total Azure upload size as bytes, partial uploads are included", + 
}) + + // totalAZDownloadSize is the metric that reports the total Azure downloads size as bytes + totalAZDownloadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_download_size", + Help: "The total Azure download size as bytes, partial downloads are included", + }) + + // totalAZListObjects is the metric that reports the total successful Azure list objects requests + totalAZListObjects = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_list_objects", + Help: "The total number of successful Azure list objects requests", + }) + + // totalAZCopyObject is the metric that reports the total successful Azure copy object requests + totalAZCopyObject = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_copy_object", + Help: "The total number of successful Azure copy object requests", + }) + + // totalAZDeleteObject is the metric that reports the total successful Azure delete object requests + totalAZDeleteObject = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_delete_object", + Help: "The total number of successful Azure delete object requests", + }) + + // totalAZListObjectsError is the metric that reports the total Azure list objects errors + totalAZListObjectsErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_list_objects_errors", + Help: "The total number of Azure list objects errors", + }) + + // totalAZCopyObjectErrors is the metric that reports the total Azure copy object errors + totalAZCopyObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_copy_object_errors", + Help: "The total number of Azure copy object errors", + }) + + // totalAZDeleteObjectErrors is the metric that reports the total Azure delete object errors + totalAZDeleteObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_delete_object_errors", + Help: "The total number of Azure delete object errors", + }) + + // totalAZHeadObject is the metric that reports the total 
successful Azure head object requests + totalAZHeadObject = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_head_object", + Help: "The total number of successful Azure head object requests", + }) + + // totalAZHeadObjectErrors is the metric that reports the total Azure head object errors + totalAZHeadObjectErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_az_head_object_errors", + Help: "The total number of Azure head object errors", + }) + + // totalSFTPFsUploads is the metric that reports the total number of successful SFTPFs uploads + totalSFTPFsUploads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_sftpfs_uploads_total", + Help: "The total number of successful SFTPFs uploads", + }) + + // totalSFTPFsDownloads is the metric that reports the total number of successful SFTPFs downloads + totalSFTPFsDownloads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_sftpfs_downloads_total", + Help: "The total number of successful SFTPFs downloads", + }) + + // totalSFTPFsUploadErrors is the metric that reports the total number of SFTPFs upload errors + totalSFTPFsUploadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_sftpfs_upload_errors_total", + Help: "The total number of SFTPFs upload errors", + }) + + // totalSFTPFsDownloadErrors is the metric that reports the total number of SFTPFs download errors + totalSFTPFsDownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_sftpfs_download_errors_total", + Help: "The total number of SFTPFs download errors", + }) + + // totalSFTPFsUploadSize is the metric that reports the total SFTPFs uploads size as bytes + totalSFTPFsUploadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_sftpfs_upload_size", + Help: "The total SFTPFs upload size as bytes, partial uploads are included", + }) + + // totalSFTPFsDownloadSize is the metric that reports the total SFTPFs downloads size as bytes + totalSFTPFsDownloadSize = 
promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_sftpfs_download_size", + Help: "The total SFTPFs download size as bytes, partial downloads are included", + }) + + // totalHTTPFsUploads is the metric that reports the total number of successful HTTPFs uploads + totalHTTPFsUploads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_httpfs_uploads_total", + Help: "The total number of successful HTTPFs uploads", + }) + + // totalHTTPFsDownloads is the metric that reports the total number of successful HTTPFs downloads + totalHTTPFsDownloads = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_httpfs_downloads_total", + Help: "The total number of successful HTTPFs downloads", + }) + + // totalHTTPFsUploadErrors is the metric that reports the total number of HTTPFs upload errors + totalHTTPFsUploadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_httpfs_upload_errors_total", + Help: "The total number of HTTPFs upload errors", + }) + + // totalHTTPFsDownloadErrors is the metric that reports the total number of HTTPFs download errors + totalHTTPFsDownloadErrors = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_httpfs_download_errors_total", + Help: "The total number of HTTPFs download errors", + }) + + // totalHTTPFsUploadSize is the metric that reports the total HTTPFs uploads size as bytes + totalHTTPFsUploadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_httpfs_upload_size", + Help: "The total HTTPFs upload size as bytes, partial uploads are included", + }) + + // totalHTTPFsDownloadSize is the metric that reports the total HTTPFs downloads size as bytes + totalHTTPFsDownloadSize = promauto.NewCounter(prometheus.CounterOpts{ + Name: "sftpgo_httpfs_download_size", + Help: "The total HTTPFs download size as bytes, partial downloads are included", + }) +) + +// AddMetricsEndpoint publishes metrics to the specified endpoint +func AddMetricsEndpoint(metricsPath string, handler chi.Router) { + 
	handler.Handle(metricsPath, promhttp.Handler())
}

// TransferCompleted updates metrics after an upload or a download.
// transferKind is 0 for uploads, any other value for downloads.
// Sizes are added even when err is not nil: partial transfers are
// included, as stated in the metric help texts.
func TransferCompleted(bytesSent, bytesReceived int64, transferKind int, err error, isSFTPFs bool) {
	if transferKind == 0 {
		// upload
		if err == nil {
			totalUploads.Inc()
		} else {
			totalUploadErrors.Inc()
		}
	} else {
		// download
		if err == nil {
			totalDownloads.Inc()
		} else {
			totalDownloadErrors.Inc()
		}
	}
	if bytesReceived > 0 {
		totalUploadSize.Add(float64(bytesReceived))
	}
	if bytesSent > 0 {
		totalDownloadSize.Add(float64(bytesSent))
	}
	if isSFTPFs {
		// also update the SFTPFs specific counters
		sftpFsTransferCompleted(bytesSent, bytesReceived, transferKind, err)
	}
}

// S3TransferCompleted updates metrics after an S3 upload or a download.
// transferKind is 0 for uploads, any other value for downloads.
func S3TransferCompleted(bytes int64, transferKind int, err error) {
	if transferKind == 0 {
		// upload
		if err == nil {
			totalS3Uploads.Inc()
		} else {
			totalS3UploadErrors.Inc()
		}
		totalS3UploadSize.Add(float64(bytes))
	} else {
		// download
		if err == nil {
			totalS3Downloads.Inc()
		} else {
			totalS3DownloadErrors.Inc()
		}
		totalS3DownloadSize.Add(float64(bytes))
	}
}

// S3ListObjectsCompleted updates metrics after an S3 list objects request terminates
func S3ListObjectsCompleted(err error) {
	if err == nil {
		totalS3ListObjects.Inc()
	} else {
		totalS3ListObjectsErrors.Inc()
	}
}

// S3CopyObjectCompleted updates metrics after an S3 copy object request terminates
func S3CopyObjectCompleted(err error) {
	if err == nil {
		totalS3CopyObject.Inc()
	} else {
		totalS3CopyObjectErrors.Inc()
	}
}

// S3DeleteObjectCompleted updates metrics after an S3 delete object request terminates
func S3DeleteObjectCompleted(err error) {
	if err == nil {
		totalS3DeleteObject.Inc()
	} else {
		totalS3DeleteObjectErrors.Inc()
	}
}

// S3HeadObjectCompleted updates metrics after a S3 head object request terminates
func S3HeadObjectCompleted(err error) {
	if err == nil {
		totalS3HeadObject.Inc()
	} else {
		totalS3HeadObjectErrors.Inc()
	}
}

// GCSTransferCompleted updates metrics after a GCS upload or a download.
// transferKind is 0 for uploads, any other value for downloads.
func GCSTransferCompleted(bytes int64, transferKind int, err error) {
	if transferKind == 0 {
		// upload
		if err == nil {
			totalGCSUploads.Inc()
		} else {
			totalGCSUploadErrors.Inc()
		}
		totalGCSUploadSize.Add(float64(bytes))
	} else {
		// download
		if err == nil {
			totalGCSDownloads.Inc()
		} else {
			totalGCSDownloadErrors.Inc()
		}
		totalGCSDownloadSize.Add(float64(bytes))
	}
}

// GCSListObjectsCompleted updates metrics after a GCS list objects request terminates
func GCSListObjectsCompleted(err error) {
	if err == nil {
		totalGCSListObjects.Inc()
	} else {
		totalGCSListObjectsErrors.Inc()
	}
}

// GCSCopyObjectCompleted updates metrics after a GCS copy object request terminates
func GCSCopyObjectCompleted(err error) {
	if err == nil {
		totalGCSCopyObject.Inc()
	} else {
		totalGCSCopyObjectErrors.Inc()
	}
}

// GCSDeleteObjectCompleted updates metrics after a GCS delete object request terminates
func GCSDeleteObjectCompleted(err error) {
	if err == nil {
		totalGCSDeleteObject.Inc()
	} else {
		totalGCSDeleteObjectErrors.Inc()
	}
}

// GCSHeadObjectCompleted updates metrics after a GCS head object request terminates
func GCSHeadObjectCompleted(err error) {
	if err == nil {
		totalGCSHeadObject.Inc()
	} else {
		totalGCSHeadObjectErrors.Inc()
	}
}

// AZTransferCompleted updates metrics after a Azure upload or a download.
// transferKind is 0 for uploads, any other value for downloads.
func AZTransferCompleted(bytes int64, transferKind int, err error) {
	if transferKind == 0 {
		// upload
		if err == nil {
			totalAZUploads.Inc()
		} else {
			totalAZUploadErrors.Inc()
		}
		totalAZUploadSize.Add(float64(bytes))
	} else {
		// download
		if err == nil {
			totalAZDownloads.Inc()
		} else {
			totalAZDownloadErrors.Inc()
		}
		totalAZDownloadSize.Add(float64(bytes))
	}
}

// AZListObjectsCompleted updates metrics after a Azure list objects request terminates
func AZListObjectsCompleted(err error) {
	if err == nil {
		totalAZListObjects.Inc()
	} else {
		totalAZListObjectsErrors.Inc()
	}
}

// AZCopyObjectCompleted updates metrics after a Azure copy object request terminates
func AZCopyObjectCompleted(err error) {
	if err == nil {
		totalAZCopyObject.Inc()
	} else {
		totalAZCopyObjectErrors.Inc()
	}
}

// AZDeleteObjectCompleted updates metrics after a Azure delete object request terminates
func AZDeleteObjectCompleted(err error) {
	if err == nil {
		totalAZDeleteObject.Inc()
	} else {
		totalAZDeleteObjectErrors.Inc()
	}
}

// AZHeadObjectCompleted updates metrics after a Azure head object request terminates
func AZHeadObjectCompleted(err error) {
	if err == nil {
		totalAZHeadObject.Inc()
	} else {
		totalAZHeadObjectErrors.Inc()
	}
}

// sftpFsTransferCompleted updates metrics after an SFTPFs upload or a download.
// Called from TransferCompleted when isSFTPFs is true; same size semantics
// (partial transfers included).
func sftpFsTransferCompleted(bytesSent, bytesReceived int64, transferKind int, err error) {
	if transferKind == 0 {
		// upload
		if err == nil {
			totalSFTPFsUploads.Inc()
		} else {
			totalSFTPFsUploadErrors.Inc()
		}
	} else {
		// download
		if err == nil {
			totalSFTPFsDownloads.Inc()
		} else {
			totalSFTPFsDownloadErrors.Inc()
		}
	}
	if bytesReceived > 0 {
		totalSFTPFsUploadSize.Add(float64(bytesReceived))
	}
	if bytesSent > 0 {
		totalSFTPFsDownloadSize.Add(float64(bytesSent))
	}
}

// HTTPFsTransferCompleted updates metrics after an HTTPFs upload or a download.
// transferKind is 0 for uploads, any other value for downloads.
func HTTPFsTransferCompleted(bytes int64, transferKind int, err error) {
	if transferKind == 0 {
		// upload
		if err == nil {
			totalHTTPFsUploads.Inc()
		} else {
			totalHTTPFsUploadErrors.Inc()
		}
		totalHTTPFsUploadSize.Add(float64(bytes))
	} else {
		// download
		if err == nil {
			totalHTTPFsDownloads.Inc()
		} else {
			totalHTTPFsDownloadErrors.Inc()
		}
		totalHTTPFsDownloadSize.Add(float64(bytes))
	}
}

// SSHCommandCompleted updates metrics after an SSH command terminates
func SSHCommandCompleted(err error) {
	if err == nil {
		totalSSHCommands.Inc()
	} else {
		totalSSHCommandErrors.Inc()
	}
}

// UpdateDataProviderAvailability updates the metric for the data provider
// availability: 1 means available, 0 means unavailable.
func UpdateDataProviderAvailability(err error) {
	if err == nil {
		dataproviderAvailability.Set(1)
	} else {
		dataproviderAvailability.Set(0)
	}
}

// AddLoginAttempt increments the metrics for login attempts.
// Any authMethod not matching a known constant is counted as a
// password login attempt (the default case).
func AddLoginAttempt(authMethod string) {
	totalLoginAttempts.Inc()
	switch authMethod {
	case loginMethodPublicKey:
		totalKeyLoginAttempts.Inc()
	case loginMethodKeyboardInteractive:
		totalInteractiveLoginAttempts.Inc()
	case loginMethodKeyAndPassword:
		totalKeyAndPasswordLoginAttempts.Inc()
	case loginMethodKeyAndKeyboardInt:
		totalKeyAndKeyIntLoginAttempts.Inc()
	case loginMethodTLSCertificate:
		totalTLSCertLoginAttempts.Inc()
	case loginMethodTLSCertificateAndPwd:
		totalTLSCertAndPwdLoginAttempts.Inc()
	case loginMethodIDP:
		totalIDPLoginAttempts.Inc()
	default:
		totalPasswordLoginAttempts.Inc()
	}
}

// incLoginOK increments the metrics for successful logins,
// per auth method; unknown methods count as password logins.
func incLoginOK(authMethod string) {
	totalLoginOK.Inc()
	switch authMethod {
	case loginMethodPublicKey:
		totalKeyLoginOK.Inc()
	case loginMethodKeyboardInteractive:
		totalInteractiveLoginOK.Inc()
	case loginMethodKeyAndPassword:
		totalKeyAndPasswordLoginOK.Inc()
	case loginMethodKeyAndKeyboardInt:
		totalKeyAndKeyIntLoginOK.Inc()
	case loginMethodTLSCertificate:
		totalTLSCertLoginOK.Inc()
	case loginMethodTLSCertificateAndPwd:
		totalTLSCertAndPwdLoginOK.Inc()
	case loginMethodIDP:
		totalIDPLoginOK.Inc()
	default:
		totalPasswordLoginOK.Inc()
	}
}

// incLoginFailed increments the metrics for failed logins,
// per auth method; unknown methods count as password logins.
func incLoginFailed(authMethod string) {
	totalLoginFailed.Inc()
	switch authMethod {
	case loginMethodPublicKey:
		totalKeyLoginFailed.Inc()
	case loginMethodKeyboardInteractive:
		totalInteractiveLoginFailed.Inc()
	case loginMethodKeyAndPassword:
		totalKeyAndPasswordLoginFailed.Inc()
	case loginMethodKeyAndKeyboardInt:
		totalKeyAndKeyIntLoginFailed.Inc()
	case loginMethodTLSCertificate:
		totalTLSCertLoginFailed.Inc()
	case loginMethodTLSCertificateAndPwd:
		totalTLSCertAndPwdLoginFailed.Inc()
	case loginMethodIDP:
		totalIDPLoginFailed.Inc()
	default:
		totalPasswordLoginFailed.Inc()
	}
}

// AddLoginResult increments the metrics for login results
func AddLoginResult(authMethod string, err error) {
	if err == nil {
		incLoginOK(authMethod)
	} else {
		incLoginFailed(authMethod)
	}
}

// AddNoAuthTried increments the metric for clients disconnected
// for inactivity before trying to login
func AddNoAuthTried() {
	totalNoAuthTried.Inc()
}

// HTTPRequestServed increments the metrics for HTTP requests.
// Note: 3xx responses only increment the total counter, not a
// class-specific one.
func HTTPRequestServed(status int) {
	totalHTTPRequests.Inc()
	if status >= 200 && status < 300 {
		totalHTTPOK.Inc()
	} else if status >= 400 && status < 500 {
		totalHTTPClientErrors.Inc()
	} else if status >= 500 {
		totalHTTPServerErrors.Inc()
	}
}

// UpdateActiveConnectionsSize sets the metric for active connections
func UpdateActiveConnectionsSize(size int) {
	activeConnections.Set(float64(size))
}
diff --git a/internal/metric/metric_disabled.go b/internal/metric/metric_disabled.go
new file mode 100644
index 00000000..78bacfdd
--- /dev/null
+++ b/internal/metric/metric_disabled.go
@@ -0,0 +1,88 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
+ +//go:build nometrics + +package metric + +import ( + "github.com/go-chi/chi/v5" + + "github.com/drakkan/sftpgo/v2/internal/version" +) + +func init() { + version.AddFeature("-metrics") +} + +// AddMetricsEndpoint publishes metrics to the specified endpoint +func AddMetricsEndpoint(_ string, _ chi.Router) {} + +// TransferCompleted updates metrics after an upload or a download +func TransferCompleted(_, _ int64, _ int, _ error, _ bool) {} + +// S3TransferCompleted updates metrics after an S3 upload or a download +func S3TransferCompleted(_ int64, _ int, _ error) {} + +// S3ListObjectsCompleted updates metrics after an S3 list objects request terminates +func S3ListObjectsCompleted(_ error) {} + +// S3CopyObjectCompleted updates metrics after an S3 copy object request terminates +func S3CopyObjectCompleted(_ error) {} + +// S3DeleteObjectCompleted updates metrics after an S3 delete object request terminates +func S3DeleteObjectCompleted(_ error) {} + +// S3HeadBucketCompleted updates metrics after an S3 head bucket request terminates +func S3HeadBucketCompleted(_ error) {} + +// GCSTransferCompleted updates metrics after a GCS upload or a download +func GCSTransferCompleted(_ int64, _ int, _ error) {} + +// GCSListObjectsCompleted updates metrics after a GCS list objects request terminates +func GCSListObjectsCompleted(_ error) {} + +// GCSCopyObjectCompleted updates metrics after a GCS copy object request terminates +func GCSCopyObjectCompleted(_ error) {} + +// GCSDeleteObjectCompleted updates metrics after a GCS delete object request terminates +func GCSDeleteObjectCompleted(_ error) {} + +// GCSHeadBucketCompleted updates metrics after a GCS head bucket request terminates +func GCSHeadBucketCompleted(_ error) {} + +// HTTPFsTransferCompleted updates metrics after an HTTPFs upload or a download +func HTTPFsTransferCompleted(_ int64, _ int, _ error) {} + +// SSHCommandCompleted update metrics after an SSH command terminates +func SSHCommandCompleted(_ error) {} 
+ +// UpdateDataProviderAvailability updates the metric for the data provider availability +func UpdateDataProviderAvailability(_ error) {} + +// AddLoginAttempt increments the metrics for login attempts +func AddLoginAttempt(_ string) {} + +// AddLoginResult increments the metrics for login results +func AddLoginResult(_ string, _ error) {} + +// AddNoAuthTried increments the metric for clients disconnected +// for inactivity before trying to login +func AddNoAuthTried() {} + +// HTTPRequestServed increments the metrics for HTTP requests +func HTTPRequestServed(_ int) {} + +// UpdateActiveConnectionsSize sets the metric for active connections +func UpdateActiveConnectionsSize(_ int) {} diff --git a/internal/mfa/mfa.go b/internal/mfa/mfa.go new file mode 100644 index 00000000..0b58bf62 --- /dev/null +++ b/internal/mfa/mfa.go @@ -0,0 +1,151 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 

// Package mfa provides supports for Multi-Factor authentication modules
package mfa

import (
	"bytes"
	"fmt"
	"image/png"
	"time"

	"github.com/pquerna/otp"
)

var (
	// totpConfigs holds the validated TOTP configurations; rebuilt on each Initialize
	totpConfigs []*TOTPConfig
	// serviceStatus reports whether MFA is active and with which configs
	serviceStatus ServiceStatus
)

// ServiceStatus defines the service status
type ServiceStatus struct {
	IsActive    bool         `json:"is_active"`
	TOTPConfigs []TOTPConfig `json:"totp_configs"`
}

// GetStatus returns the service status
func GetStatus() ServiceStatus {
	return serviceStatus
}

// Config defines configuration parameters for Multi-Factor authentication modules
type Config struct {
	// Time-based one time passwords configurations
	TOTP []TOTPConfig `json:"totp" mapstructure:"totp"`
}

// Initialize configures the MFA support.
// It resets any previous state, validates each TOTP config, rejects
// duplicate names, and (re)starts the used-passcode cleanup ticker.
// On error the package-level config list is left empty.
func (c *Config) Initialize() error {
	totpConfigs = nil
	serviceStatus.IsActive = false
	serviceStatus.TOTPConfigs = nil
	totp := make(map[string]bool)
	for _, totpConfig := range c.TOTP {
		totpConfig := totpConfig //pin: copy the loop variable so the stored pointer is stable
		if err := totpConfig.validate(); err != nil {
			totpConfigs = nil
			return fmt.Errorf("invalid TOTP config %+v: %v", totpConfig, err)
		}
		if _, ok := totp[totpConfig.Name]; ok {
			totpConfigs = nil
			return fmt.Errorf("totp: duplicate configuration name %q", totpConfig.Name)
		}
		totp[totpConfig.Name] = true
		totpConfigs = append(totpConfigs, &totpConfig)
		serviceStatus.IsActive = true
		serviceStatus.TOTPConfigs = append(serviceStatus.TOTPConfigs, totpConfig)
	}
	startCleanupTicker(2 * time.Minute)
	return nil
}

// GetAvailableTOTPConfigs returns the available TOTP configs
func GetAvailableTOTPConfigs() []*TOTPConfig {
	return totpConfigs
}

// GetAvailableTOTPConfigNames returns the available TOTP config names
func GetAvailableTOTPConfigNames() []string {
	var result []string
	for _, c := range totpConfigs {
		result = append(result, c.Name)
	}
	return result
}

// ValidateTOTPPasscode validates a TOTP passcode using the given secret and configName.
// It returns an error if no configuration with configName exists.
func ValidateTOTPPasscode(configName, passcode, secret string) (bool, error) {
	for _, config := range totpConfigs {
		if config.Name == configName {
			return config.validatePasscode(passcode, secret)
		}
	}

	return false, fmt.Errorf("totp: no configuration %q", configName)
}

// GenerateTOTPSecret generates a new TOTP secret and QR code for the given username
// using the configuration with configName. The QR code is rendered at 200x200.
func GenerateTOTPSecret(configName, username string) (string, *otp.Key, []byte, error) {
	for _, config := range totpConfigs {
		if config.Name == configName {
			key, qrCode, err := config.generate(username, 200, 200)
			return configName, key, qrCode, err
		}
	}

	return "", nil, nil, fmt.Errorf("totp: no configuration %q", configName)
}

// GenerateQRCodeFromURL generates a QR code from a TOTP URL
func GenerateQRCodeFromURL(url string, width, height int) ([]byte, error) {
	key, err := otp.NewKeyFromURL(url)
	if err != nil {
		return nil, err
	}
	var buf bytes.Buffer
	img, err := key.Image(width, height)
	if err != nil {
		return nil, err
	}
	err = png.Encode(&buf, img)
	return buf.Bytes(), err
}

// startCleanupTicker starts the background goroutine that periodically
// expires used passcodes.
// The ticker cannot be started/stopped from multiple goroutines.
func startCleanupTicker(duration time.Duration) {
	stopCleanupTicker()
	cleanupTicker = time.NewTicker(duration)
	cleanupDone = make(chan bool)

	go func() {
		for {
			select {
			case <-cleanupDone:
				return
			case <-cleanupTicker.C:
				cleanupUsedPasscodes()
			}
		}
	}()
}

// stopCleanupTicker stops the cleanup goroutine, if running.
// Sending on cleanupDone blocks until the goroutine acknowledges.
func stopCleanupTicker() {
	if cleanupTicker != nil {
		cleanupTicker.Stop()
		cleanupDone <- true
		cleanupTicker = nil
	}
}
diff --git a/internal/mfa/mfa_test.go b/internal/mfa/mfa_test.go
new file mode 100644
index 00000000..52a7bd8e
--- /dev/null
+++ b/internal/mfa/mfa_test.go
@@ -0,0 +1,162 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free
 Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

package mfa

import (
	"testing"
	"time"

	"github.com/pquerna/otp"
	"github.com/pquerna/otp/totp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestMFAConfig exercises Config.Initialize validation (missing name/issuer/algo,
// duplicate names), the config accessors, and end-to-end secret generation and
// passcode validation, including replay rejection and algorithm mismatch.
func TestMFAConfig(t *testing.T) {
	config := Config{
		TOTP: []TOTPConfig{
			{},
		},
	}
	configName1 := "config1"
	configName2 := "config2"
	configName3 := "config3"
	err := config.Initialize()
	assert.Error(t, err)
	config.TOTP[0].Name = configName1
	err = config.Initialize()
	assert.Error(t, err)
	config.TOTP[0].Issuer = "issuer"
	err = config.Initialize()
	assert.Error(t, err)
	config.TOTP[0].Algo = TOTPAlgoSHA1
	err = config.Initialize()
	assert.NoError(t, err)
	config.TOTP = append(config.TOTP, TOTPConfig{
		Name:   configName1,
		Issuer: "SFTPGo",
		Algo:   TOTPAlgoSHA512,
	})
	// duplicate name must be rejected
	err = config.Initialize()
	assert.Error(t, err)
	config.TOTP[1].Name = configName2
	err = config.Initialize()
	assert.NoError(t, err)
	assert.Len(t, GetAvailableTOTPConfigs(), 2)
	assert.Len(t, GetAvailableTOTPConfigNames(), 2)
	config.TOTP = append(config.TOTP, TOTPConfig{
		Name:   configName3,
		Issuer: "SFTPGo",
		Algo:   TOTPAlgoSHA256,
	})
	err = config.Initialize()
	assert.NoError(t, err)
	assert.Len(t, GetAvailableTOTPConfigs(), 3)
	if assert.Len(t, GetAvailableTOTPConfigNames(), 3) {
		assert.Contains(t, GetAvailableTOTPConfigNames(), configName1)
		assert.Contains(t, GetAvailableTOTPConfigNames(), configName2)
		assert.Contains(t, GetAvailableTOTPConfigNames(), configName3)
	}
	status := GetStatus()
	assert.True(t, status.IsActive)
	if assert.Len(t, status.TOTPConfigs, 3) {
		assert.Equal(t, configName1, status.TOTPConfigs[0].Name)
		assert.Equal(t, configName2, status.TOTPConfigs[1].Name)
		assert.Equal(t, configName3, status.TOTPConfigs[2].Name)
	}
	// now generate some secrets and validate some passcodes
	_, _, _, err = GenerateTOTPSecret("", "") //nolint:dogsled
	assert.Error(t, err)
	match, err := ValidateTOTPPasscode("", "", "")
	assert.Error(t, err)
	assert.False(t, match)
	cfgName, key, _, err := GenerateTOTPSecret(configName1, "user1")
	assert.NoError(t, err)
	assert.NotEmpty(t, key.Secret())
	assert.Equal(t, configName1, cfgName)
	passcode, err := generatePasscode(key.Secret(), otp.AlgorithmSHA1)
	assert.NoError(t, err)
	match, err = ValidateTOTPPasscode(configName1, passcode, key.Secret())
	assert.NoError(t, err)
	assert.True(t, match)
	// replaying the same passcode must fail
	match, err = ValidateTOTPPasscode(configName1, passcode, key.Secret())
	assert.ErrorIs(t, err, errPasscodeUsed)
	assert.False(t, match)

	passcode, err = generatePasscode(key.Secret(), otp.AlgorithmSHA256)
	assert.NoError(t, err)
	// config1 uses sha1 algo
	match, err = ValidateTOTPPasscode(configName1, passcode, key.Secret())
	assert.NoError(t, err)
	assert.False(t, match)
	// config3 use the expected algo
	match, err = ValidateTOTPPasscode(configName3, passcode, key.Secret())
	assert.NoError(t, err)
	assert.True(t, match)

	stopCleanupTicker()
}

// TestGenerateQRCodeFromURL checks QR code generation from a TOTP URL,
// including invalid URLs and an image size too small to encode.
func TestGenerateQRCodeFromURL(t *testing.T) {
	_, err := GenerateQRCodeFromURL("http://foo\x7f.cloud", 200, 200)
	assert.Error(t, err)
	config := TOTPConfig{
		Name:   "config name",
		Issuer: "SFTPGo",
		Algo:   TOTPAlgoSHA256,
	}
	key, qrCode, err := config.generate("a", 150, 150)
	require.NoError(t, err)

	qrCode1, err := GenerateQRCodeFromURL(key.URL(), 150, 150)
	require.NoError(t, err)
	assert.Equal(t, qrCode, qrCode1)
	_, err = GenerateQRCodeFromURL(key.URL(), 10, 10)
	assert.Error(t, err)
}

// TestCleanupPasscodes verifies that expired entries in usedPasscodes
// are removed by the cleanup ticker.
func TestCleanupPasscodes(t *testing.T) {
	usedPasscodes.Store("key", time.Now().Add(-24*time.Hour).UTC())
	startCleanupTicker(30 * time.Millisecond)
	assert.Eventually(t, func() bool {
		_, ok := usedPasscodes.Load("key")
		return !ok
	}, 1000*time.Millisecond, 100*time.Millisecond)
	stopCleanupTicker()
}

// TestTOTPGenerateErrors covers generate() failure paths: empty issuer
// and a QR image smaller than the minimum encodable size.
func TestTOTPGenerateErrors(t *testing.T) {
	config := TOTPConfig{
		Name:   "name",
		Issuer: "",
		algo:   otp.AlgorithmSHA1,
	}
	// issuer cannot be empty
	_, _, err := config.generate("username", 200, 200) //nolint:dogsled
	assert.Error(t, err)
	config.Issuer = "issuer"
	// we cannot encode an image smaller than 45x45
	_, _, err = config.generate("username", 30, 30) //nolint:dogsled
	assert.Error(t, err)
}

// generatePasscode builds a valid passcode for the given secret/algo,
// matching the validation options used by validatePasscode.
func generatePasscode(secret string, algo otp.Algorithm) (string, error) {
	return totp.GenerateCodeCustom(secret, time.Now(), totp.ValidateOpts{
		Period:    30,
		Skew:      1,
		Digits:    otp.DigitsSix,
		Algorithm: algo,
	})
}
diff --git a/internal/mfa/totp.go b/internal/mfa/totp.go
new file mode 100644
index 00000000..35e17b01
--- /dev/null
+++ b/internal/mfa/totp.go
@@ -0,0 +1,120 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

package mfa

import (
	"bytes"
	"errors"
	"fmt"
	"image/png"
	"sync"
	"time"

	"github.com/pquerna/otp"
	"github.com/pquerna/otp/totp"
)

// TOTPHMacAlgo is the enumerable for the possible HMAC algorithms for Time-based one time passwords
type TOTPHMacAlgo = string

// supported TOTP HMAC algorithms
const (
	TOTPAlgoSHA1   TOTPHMacAlgo = "sha1"
	TOTPAlgoSHA256 TOTPHMacAlgo = "sha256"
	TOTPAlgoSHA512 TOTPHMacAlgo = "sha512"
)

var (
	// cleanupTicker drives periodic expiry of used passcodes
	cleanupTicker *time.Ticker
	// cleanupDone signals the cleanup goroutine to exit
	cleanupDone chan bool
	// usedPasscodes maps "secret_passcode" keys to their expiry time,
	// to reject passcode replay within the validity window
	usedPasscodes sync.Map
	// errPasscodeUsed is returned when a passcode is replayed
	errPasscodeUsed = errors.New("this passcode was already used")
)

// TOTPConfig defines the configuration for a Time-based one time password
type TOTPConfig struct {
	Name   string       `json:"name" mapstructure:"name"`
	Issuer string       `json:"issuer" mapstructure:"issuer"`
	Algo   TOTPHMacAlgo `json:"algo" mapstructure:"algo"`
	// algo is the parsed otp.Algorithm, set by validate()
	algo otp.Algorithm
}

// validate checks Name, Issuer and Algo and resolves the internal
// otp.Algorithm. It must be called before validatePasscode/generate.
func (c *TOTPConfig) validate() error {
	if c.Name == "" {
		return errors.New("totp: name is mandatory")
	}
	if c.Issuer == "" {
		return errors.New("totp: issuer is mandatory")
	}
	switch c.Algo {
	case TOTPAlgoSHA1:
		c.algo = otp.AlgorithmSHA1
	case TOTPAlgoSHA256:
		c.algo = otp.AlgorithmSHA256
	case TOTPAlgoSHA512:
		c.algo = otp.AlgorithmSHA512
	default:
		return fmt.Errorf("unsupported totp algo %q", c.Algo)
	}
	return nil
}

// validatePasscode validates a TOTP passcode.
// A successfully validated passcode is remembered for one minute so a
// replay returns errPasscodeUsed; the 30s period with skew 1 makes the
// one-minute retention sufficient to cover the acceptance window.
func (c *TOTPConfig) validatePasscode(passcode, secret string) (bool, error) {
	key := fmt.Sprintf("%v_%v", secret, passcode)
	if _, ok := usedPasscodes.Load(key); ok {
		return false, errPasscodeUsed
	}
	match, err := totp.ValidateCustom(passcode, secret, time.Now().UTC(), totp.ValidateOpts{
		Period:    30,
		Skew:      1,
		Digits:    otp.DigitsSix,
		Algorithm: c.algo,
	})
	if match && err == nil {
		// only remember passcodes that actually validated
		usedPasscodes.Store(key, time.Now().Add(1*time.Minute).UTC())
	}
	return match, err
}

// generate generates a new TOTP secret and QR code for the given username.
// It returns the otp.Key and the QR code rendered as a PNG of the
// requested size.
func (c *TOTPConfig) generate(username string, qrCodeWidth, qrCodeHeight int) (*otp.Key, []byte, error) {
	key, err := totp.Generate(totp.GenerateOpts{
		Issuer:      c.Issuer,
		AccountName: username,
		Digits:      otp.DigitsSix,
		Algorithm:   c.algo,
	})
	if err != nil {
		return nil, nil, err
	}
	var buf bytes.Buffer
	img, err := key.Image(qrCodeWidth, qrCodeHeight)
	if err != nil {
		return nil, nil, err
	}
	err = png.Encode(&buf, img)
	return key, buf.Bytes(), err
}

// cleanupUsedPasscodes drops expired (or malformed) entries from
// usedPasscodes; invoked periodically by the cleanup ticker.
func cleanupUsedPasscodes() {
	usedPasscodes.Range(func(key, value any) bool {
		exp, ok := value.(time.Time)
		if !ok || exp.Before(time.Now().UTC()) {
			usedPasscodes.Delete(key)
		}
		return true
	})
}
diff --git a/internal/plugin/auth.go b/internal/plugin/auth.go
new file mode 100644
index 00000000..57935cd8
--- /dev/null
+++ b/internal/plugin/auth.go
@@ -0,0 +1,192 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
+
+package plugin
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/go-plugin"
+	"github.com/sftpgo/sdk/plugin/auth"
+
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+)
+
+// Supported auth scopes
+const (
+	AuthScopePassword            = 1
+	AuthScopePublicKey           = 2
+	AuthScopeKeyboardInteractive = 4
+	AuthScopeTLSCertificate      = 8
+)
+
+// KeyboardAuthRequest defines the request for a keyboard interactive authentication step
+type KeyboardAuthRequest struct {
+	RequestID string   `json:"request_id"`
+	Step      int      `json:"step"`
+	Username  string   `json:"username,omitempty"`
+	IP        string   `json:"ip,omitempty"`
+	Password  string   `json:"password,omitempty"`
+	Answers   []string `json:"answers,omitempty"`
+	Questions []string `json:"questions,omitempty"`
+}
+
+// KeyboardAuthResponse defines the response for a keyboard interactive authentication step
+type KeyboardAuthResponse struct {
+	Instruction string   `json:"instruction"`
+	Questions   []string `json:"questions"`
+	Echos       []bool   `json:"echos"`
+	AuthResult  int      `json:"auth_result"`
+	CheckPwd    int      `json:"check_password"`
+}
+
+// Validate returns an error if the KeyboardAuthResponse is invalid
+func (r *KeyboardAuthResponse) Validate() error {
+	if len(r.Questions) == 0 {
+		err := errors.New("interactive auth error: response does not contain questions")
+		return err
+	}
+	// each question must have a matching echo flag
+	if len(r.Questions) != len(r.Echos) {
+		err := fmt.Errorf("interactive auth error: response questions don't match echos: %v %v",
+			len(r.Questions), len(r.Echos))
+		return err
+	}
+	return nil
+}
+
+// AuthConfig defines configuration parameters for auth plugins
+type AuthConfig struct {
+	// Scope defines the scope for the authentication plugin.
+	// - 1 means passwords only
+	// - 2 means public keys only
+	// - 4 means keyboard interactive only
+	// - 8 means TLS certificates only
+	// you can combine the scopes, for example 3 means password and public key, 5 password and keyboard
+	// interactive and so on
+	Scope int `json:"scope" mapstructure:"scope"`
+}
+
+// validate rejects an empty scope or a bitmask outside the supported range.
+func (c *AuthConfig) validate() error {
+	authScopeMax := AuthScopePassword + AuthScopePublicKey + AuthScopeKeyboardInteractive + AuthScopeTLSCertificate
+	if c.Scope == 0 || c.Scope > authScopeMax {
+		return fmt.Errorf("invalid auth scope: %v", c.Scope)
+	}
+	return nil
+}
+
+// authPlugin wraps the gRPC client for an external authentication plugin.
+type authPlugin struct {
+	config  Config
+	service auth.Authenticator
+	client  *plugin.Client
+}
+
+// newAuthPlugin creates and starts an auth plugin from the given config.
+func newAuthPlugin(config Config) (*authPlugin, error) {
+	p := &authPlugin{
+		config: config,
+	}
+	if err := p.initialize(); err != nil {
+		logger.Warn(logSender, "", "unable to create auth plugin: %v, config %+v", err, config)
+		return nil, err
+	}
+	return p, nil
+}
+
+// initialize validates the options, spawns the plugin process and
+// dispenses the Authenticator service over gRPC.
+func (p *authPlugin) initialize() error {
+	// kill any stale process left over from a previous run of this command
+	killProcess(p.config.Cmd)
+	logger.Debug(logSender, "", "create new auth plugin %q", p.config.Cmd)
+	if err := p.config.AuthOptions.validate(); err != nil {
+		return fmt.Errorf("invalid options for auth plugin %q: %v", p.config.Cmd, err)
+	}
+
+	secureConfig, err := p.config.getSecureConfig()
+	if err != nil {
+		return err
+	}
+	client := plugin.NewClient(&plugin.ClientConfig{
+		HandshakeConfig: auth.Handshake,
+		Plugins:         auth.PluginMap,
+		Cmd:             p.config.getCommand(),
+		SkipHostEnv:     true,
+		AllowedProtocols: []plugin.Protocol{
+			plugin.ProtocolGRPC,
+		},
+		AutoMTLS:     p.config.AutoMTLS,
+		SecureConfig: secureConfig,
+		Managed:      false,
+		Logger: &logger.HCLogAdapter{
+			Logger: hclog.New(&hclog.LoggerOptions{
+				Name:        fmt.Sprintf("%v.%v", logSender, auth.PluginName),
+				Level:       pluginsLogLevel,
+				DisableTime: true,
+			}),
+		},
+	})
+	rpcClient, err := client.Client()
+	if err != nil {
+		logger.Debug(logSender, "", "unable to get rpc client for auth plugin %q: %v", p.config.Cmd, err)
+		return err
+	}
+	raw, err := rpcClient.Dispense(auth.PluginName)
+	if err != nil {
+		logger.Debug(logSender, "", "unable to get plugin %v from rpc client for command %q: %v",
+			auth.PluginName, p.config.Cmd, err)
+		return err
+	}
+
+	p.service = raw.(auth.Authenticator)
+	p.client = client
+
+	return nil
+}
+
+// exited reports whether the plugin process has terminated.
+func (p *authPlugin) exited() bool {
+	return p.client.Exited()
+}
+
+// cleanup kills the plugin process.
+func (p *authPlugin) cleanup() {
+	p.client.Kill()
+}
+
+// checkUserAndPass delegates password authentication to the plugin.
+func (p *authPlugin) checkUserAndPass(username, password, ip, protocol string, userAsJSON []byte) ([]byte, error) {
+	return p.service.CheckUserAndPass(username, password, ip, protocol, userAsJSON)
+}
+
+// checkUserAndTLSCertificate delegates TLS certificate authentication to the plugin.
+func (p *authPlugin) checkUserAndTLSCertificate(username, tlsCert, ip, protocol string, userAsJSON []byte) ([]byte, error) {
+	return p.service.CheckUserAndTLSCert(username, tlsCert, ip, protocol, userAsJSON)
+}
+
+// checkUserAndPublicKey delegates public key authentication to the plugin.
+func (p *authPlugin) checkUserAndPublicKey(username, pubKey, ip, protocol string, userAsJSON []byte) ([]byte, error) {
+	return p.service.CheckUserAndPublicKey(username, pubKey, ip, protocol, userAsJSON)
+}
+
+// checkUserAndKeyboardInteractive starts keyboard interactive authentication via the plugin.
+func (p *authPlugin) checkUserAndKeyboardInteractive(username, ip, protocol string, userAsJSON []byte) ([]byte, error) {
+	return p.service.CheckUserAndKeyboardInteractive(username, ip, protocol, userAsJSON)
+}
+
+// sendKeyboardIteractiveRequest forwards one keyboard interactive step to the
+// plugin and wraps the answer in a KeyboardAuthResponse.
+func (p *authPlugin) sendKeyboardIteractiveRequest(req *KeyboardAuthRequest) (*KeyboardAuthResponse, error) {
+	instructions, questions, echos, authResult, checkPassword, err := p.service.SendKeyboardAuthRequest(
+		req.RequestID, req.Username, req.Password, req.IP, req.Answers, req.Questions, int32(req.Step))
+	if err != nil {
+		return nil, err
+	}
+	return &KeyboardAuthResponse{
+		Instruction: instructions,
+		Questions:   questions,
+		Echos:       echos,
+		AuthResult:  authResult,
+		CheckPwd:    checkPassword,
+	}, nil
+}
diff --git a/internal/plugin/ipfilter.go b/internal/plugin/ipfilter.go
new file mode 100644
index 00000000..496fb570
--- /dev/null
+++ b/internal/plugin/ipfilter.go
@@ -0,0 +1,94
@@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package plugin
+
+import (
+	"fmt"
+
+	"github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/go-plugin"
+	"github.com/sftpgo/sdk/plugin/ipfilter"
+
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+)
+
+// ipFilterPlugin wraps the gRPC client for an external IP filter plugin.
+type ipFilterPlugin struct {
+	config Config
+	filter ipfilter.Filter
+	client *plugin.Client
+}
+
+// newIPFilterPlugin creates and starts an IP filter plugin from the given config.
+func newIPFilterPlugin(config Config) (*ipFilterPlugin, error) {
+	p := &ipFilterPlugin{
+		config: config,
+	}
+	if err := p.initialize(); err != nil {
+		logger.Warn(logSender, "", "unable to create IP filter plugin: %v, config %+v", err, config)
+		return nil, err
+	}
+	return p, nil
+}
+
+// exited reports whether the plugin process has terminated.
+func (p *ipFilterPlugin) exited() bool {
+	return p.client.Exited()
+}
+
+// cleanup kills the plugin process.
+func (p *ipFilterPlugin) cleanup() {
+	p.client.Kill()
+}
+
+// initialize spawns the plugin process and dispenses the Filter service over gRPC.
+func (p *ipFilterPlugin) initialize() error {
+	logger.Debug(logSender, "", "create new IP filter plugin %q", p.config.Cmd)
+	// kill any stale process left over from a previous run of this command
+	killProcess(p.config.Cmd)
+	secureConfig, err := p.config.getSecureConfig()
+	if err != nil {
+		return err
+	}
+	client := plugin.NewClient(&plugin.ClientConfig{
+		HandshakeConfig: ipfilter.Handshake,
+		Plugins:         ipfilter.PluginMap,
+		Cmd:             p.config.getCommand(),
+		SkipHostEnv:     true,
+		AllowedProtocols: []plugin.Protocol{
+			plugin.ProtocolGRPC,
+		},
+		AutoMTLS:     p.config.AutoMTLS,
+		SecureConfig: secureConfig,
+		Managed:      false,
+		Logger: &logger.HCLogAdapter{
+			Logger: hclog.New(&hclog.LoggerOptions{
+				Name:        fmt.Sprintf("%v.%v", logSender, ipfilter.PluginName),
+				Level:       pluginsLogLevel,
+				DisableTime: true,
+			}),
+		},
+	})
+	rpcClient, err := client.Client()
+	if err != nil {
+		logger.Debug(logSender, "", "unable to get rpc client for plugin %q: %v", p.config.Cmd, err)
+		return err
+	}
+	raw, err := rpcClient.Dispense(ipfilter.PluginName)
+	if err != nil {
+		logger.Debug(logSender, "", "unable to get plugin %v from rpc client for command %q: %v",
+			ipfilter.PluginName, p.config.Cmd, err)
+		return err
+	}
+
+	p.client = client
+	p.filter = raw.(ipfilter.Filter)
+
+	return nil
+}
diff --git a/internal/plugin/kms.go b/internal/plugin/kms.go
new file mode 100644
index 00000000..b4ed5be1
--- /dev/null
+++ b/internal/plugin/kms.go
@@ -0,0 +1,195 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package plugin
+
+import (
+	"fmt"
+	"path/filepath"
+	"slices"
+
+	"github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/go-plugin"
+	sdkkms "github.com/sftpgo/sdk/kms"
+	kmsplugin "github.com/sftpgo/sdk/plugin/kms"
+
+	"github.com/drakkan/sftpgo/v2/internal/kms"
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+)
+
+var (
+	// validKMSSchemes lists the secret URI schemes a KMS plugin may serve
+	validKMSSchemes = []string{sdkkms.SchemeAWS, sdkkms.SchemeGCP, sdkkms.SchemeVaultTransit,
+		sdkkms.SchemeAzureKeyVault, sdkkms.SchemeOracleKeyVault}
+	// validKMSEncryptedStatuses lists the secret statuses a KMS plugin may set
+	validKMSEncryptedStatuses = []string{sdkkms.SecretStatusVaultTransit, sdkkms.SecretStatusAWS, sdkkms.SecretStatusGCP,
+		sdkkms.SecretStatusAzureKeyVault, sdkkms.SecretStatusOracleKeyVault}
+)
+
+// KMSConfig defines configuration parameters for kms plugins
+type KMSConfig struct {
+	Scheme          string `json:"scheme" mapstructure:"scheme"`
+	EncryptedStatus string `json:"encrypted_status" mapstructure:"encrypted_status"`
+}
+
+// validate rejects unsupported schemes and encrypted statuses.
+func (c *KMSConfig) validate() error {
+	if !slices.Contains(validKMSSchemes, c.Scheme) {
+		return fmt.Errorf("invalid kms scheme: %v", c.Scheme)
+	}
+	if !slices.Contains(validKMSEncryptedStatuses, c.EncryptedStatus) {
+		return fmt.Errorf("invalid kms encrypted status: %v", c.EncryptedStatus)
+	}
+	return nil
+}
+
+// kmsPlugin wraps the gRPC client for an external KMS plugin.
+type kmsPlugin struct {
+	config  Config
+	service kmsplugin.Service
+	client  *plugin.Client
+}
+
+// newKMSPlugin creates and starts a KMS plugin from the given config.
+func newKMSPlugin(config Config) (*kmsPlugin, error) {
+	p := &kmsPlugin{
+		config: config,
+	}
+	if err := p.initialize(); err != nil {
+		logger.Warn(logSender, "", "unable to create kms plugin: %v, config %+v", err, config)
+		return nil, err
+	}
+	return p, nil
+}
+
+// initialize validates the options, spawns the plugin process and dispenses
+// the KMS Service over gRPC.
+func (p *kmsPlugin) initialize() error {
+	// kill any stale process left over from a previous run of this command
+	killProcess(p.config.Cmd)
+	logger.Debug(logSender, "", "create new kms plugin %q", p.config.Cmd)
+	if err := p.config.KMSOptions.validate(); err != nil {
+		return fmt.Errorf("invalid options for kms plugin %q: %v", p.config.Cmd, err)
+	}
+	secureConfig, err := p.config.getSecureConfig()
+	if err != nil {
+		return err
+	}
+	client := plugin.NewClient(&plugin.ClientConfig{
+		HandshakeConfig: kmsplugin.Handshake,
+		Plugins:         kmsplugin.PluginMap,
+		Cmd:             p.config.getCommand(),
+		SkipHostEnv:     true,
+		AllowedProtocols: []plugin.Protocol{
+			plugin.ProtocolGRPC,
+		},
+		AutoMTLS:     p.config.AutoMTLS,
+		SecureConfig: secureConfig,
+		Managed:      false,
+		Logger: &logger.HCLogAdapter{
+			Logger: hclog.New(&hclog.LoggerOptions{
+				Name:        fmt.Sprintf("%v.%v", logSender, kmsplugin.PluginName),
+				Level:       pluginsLogLevel,
+				DisableTime: true,
+			}),
+		},
+	})
+	rpcClient, err := client.Client()
+	if err != nil {
+		logger.Debug(logSender, "", "unable to get rpc client for kms plugin %q: %v", p.config.Cmd, err)
+		return err
+	}
+	raw, err := rpcClient.Dispense(kmsplugin.PluginName)
+	if err != nil {
+		logger.Debug(logSender, "", "unable to get plugin %v from rpc client for command %q: %v",
+			kmsplugin.PluginName, p.config.Cmd, err)
+		return err
+	}
+
+	p.client = client
+	p.service = raw.(kmsplugin.Service)
+
+	return nil
+}
+
+// exited reports whether the plugin process has terminated.
+func (p *kmsPlugin) exited() bool {
+	return p.client.Exited()
+}
+
+// cleanup kills the plugin process.
+func (p *kmsPlugin) cleanup() {
+	p.client.Kill()
+}
+
+// Encrypt delegates secret encryption to the plugin.
+func (p *kmsPlugin) Encrypt(secret kms.BaseSecret, url string, masterKey string) (string, string, int32, error) {
+	return p.service.Encrypt(secret.Payload, secret.AdditionalData, url, masterKey)
+}
+
+// Decrypt delegates secret decryption to the plugin.
+func (p *kmsPlugin) Decrypt(secret kms.BaseSecret, url string, masterKey string) (string, error) {
+	return p.service.Decrypt(secret.Payload, secret.Key, secret.AdditionalData, secret.Mode, url, masterKey)
+}
+
+// kmsPluginSecretProvider implements kms.SecretProvider on top of a KMS plugin.
+type kmsPluginSecretProvider struct {
+	kms.BaseSecret
+	URL       string
+	MasterKey string
+	config    *Config
+}
+
+// Name returns a unique provider name derived from the plugin command, scheme and kms id.
+func (s *kmsPluginSecretProvider) Name() string {
+	return fmt.Sprintf("KMSPlugin_%v_%v_%v", filepath.Base(s.config.Cmd), s.config.KMSOptions.Scheme, s.config.kmsID)
+}
+
+// IsEncrypted reports whether the secret is in this provider's encrypted status.
+func (s *kmsPluginSecretProvider) IsEncrypted() bool {
+	return s.Status == s.config.KMSOptions.EncryptedStatus
+}
+
+// Encrypt encrypts a plain secret via the plugin and updates the secret in place.
+func (s *kmsPluginSecretProvider) Encrypt() error {
+	if s.Status != sdkkms.SecretStatusPlain {
+		return kms.ErrWrongSecretStatus
+	}
+	if s.Payload == "" {
+		return kms.ErrInvalidSecret
+	}
+
+	payload, key, mode, err := Handler.kmsEncrypt(s.BaseSecret, s.URL, s.MasterKey, s.config.kmsID)
+	if err != nil {
+		return err
+	}
+	s.Status = s.config.KMSOptions.EncryptedStatus
+	s.Payload = payload
+	s.Key = key
+	s.Mode = int(mode)
+
+	return nil
+}
+
+// Decrypt decrypts the secret via the plugin and resets it to plain status.
+func (s *kmsPluginSecretProvider) Decrypt() error {
+	if !s.IsEncrypted() {
+		return kms.ErrWrongSecretStatus
+	}
+	payload, err := Handler.kmsDecrypt(s.BaseSecret, s.URL, s.MasterKey, s.config.kmsID)
+	if err != nil {
+		return err
+	}
+	s.Status = sdkkms.SecretStatusPlain
+	s.Payload = payload
+	s.Key = ""
+	s.AdditionalData = ""
+	s.Mode = 0
+
+	return nil
+}
+
+// Clone returns a new provider with a copy of the current secret state.
+func (s *kmsPluginSecretProvider) Clone() kms.SecretProvider {
+	baseSecret := kms.BaseSecret{
+		Status:         s.Status,
+		Payload:        s.Payload,
+		Key:            s.Key,
+		AdditionalData: s.AdditionalData,
+		Mode:           s.Mode,
+	}
+	return s.config.newKMSPluginSecretProvider(baseSecret, s.URL, s.MasterKey)
+}
diff --git a/internal/plugin/notifier.go b/internal/plugin/notifier.go
new file mode 100644
index 00000000..2a37c462
--- /dev/null
+++ b/internal/plugin/notifier.go
@@ -0,0 +1,268 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package plugin
+
+import (
+	"fmt"
+	"slices"
+	"sync"
+	"time"
+
+	"github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/go-plugin"
+	"github.com/sftpgo/sdk/plugin/notifier"
+
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+)
+
+// NotifierConfig defines configuration parameters for notifiers plugins
+type NotifierConfig struct {
+	FsEvents          []string `json:"fs_events" mapstructure:"fs_events"`
+	ProviderEvents    []string `json:"provider_events" mapstructure:"provider_events"`
+	ProviderObjects   []string `json:"provider_objects" mapstructure:"provider_objects"`
+	LogEvents         []int    `json:"log_events" mapstructure:"log_events"`
+	RetryMaxTime      int      `json:"retry_max_time" mapstructure:"retry_max_time"`
+	RetryQueueMaxSize int      `json:"retry_queue_max_size" mapstructure:"retry_queue_max_size"`
+}
+
+// hasActions reports whether at least one event category is configured.
+// Provider events require both the event list and the object list.
+func (c *NotifierConfig) hasActions() bool {
+	if len(c.FsEvents) > 0 {
+		return true
+	}
+	if len(c.ProviderEvents) > 0 && len(c.ProviderObjects) > 0 {
+		return true
+	}
+	if len(c.LogEvents) > 0 {
+		return true
+	}
+	return false
+}
+
+// notifierPlugin wraps the gRPC client for an external notifier plugin and
+// keeps per-category retry queues for failed deliveries.
+type notifierPlugin struct {
+	config   Config
+	notifier notifier.Notifier
+	client   *plugin.Client
+	// mu protects the retry queues below
+	mu             sync.RWMutex
+	fsEvents       []*notifier.FsEvent
+	providerEvents []*notifier.ProviderEvent
+	logEvents      []*notifier.LogEvent
+}
+
+// newNotifierPlugin creates and starts a notifier plugin from the given config.
+func newNotifierPlugin(config Config) (*notifierPlugin, error) {
+	p := &notifierPlugin{
+		config: config,
+	}
+	if err := p.initialize(); err != nil {
+		logger.Warn(logSender, "", "unable to create notifier plugin: %v, config %+v", err, config)
+		return nil, err
+	}
+	return p, nil
+}
+
+// exited reports whether the plugin process has terminated.
+func (p *notifierPlugin) exited() bool {
+	return p.client.Exited()
+}
+
+// cleanup kills the plugin process.
+func (p *notifierPlugin) cleanup() {
+	p.client.Kill()
+}
+
+// initialize validates the options, spawns the plugin process and dispenses
+// the Notifier service over gRPC.
+func (p *notifierPlugin) initialize() error {
+	// kill any stale process left over from a previous run of this command
+	killProcess(p.config.Cmd)
+	logger.Debug(logSender, "", "create new notifier plugin %q", p.config.Cmd)
+	if !p.config.NotifierOptions.hasActions() {
+		return fmt.Errorf("no actions defined for the notifier plugin %q", p.config.Cmd)
+	}
+	secureConfig, err := p.config.getSecureConfig()
+	if err != nil {
+		return err
+	}
+	client := plugin.NewClient(&plugin.ClientConfig{
+		HandshakeConfig: notifier.Handshake,
+		Plugins:         notifier.PluginMap,
+		Cmd:             p.config.getCommand(),
+		SkipHostEnv:     true,
+		AllowedProtocols: []plugin.Protocol{
+			plugin.ProtocolGRPC,
+		},
+		AutoMTLS:     p.config.AutoMTLS,
+		SecureConfig: secureConfig,
+		Managed:      false,
+		Logger: &logger.HCLogAdapter{
+			Logger: hclog.New(&hclog.LoggerOptions{
+				Name:        fmt.Sprintf("%s.%s", logSender, notifier.PluginName),
+				Level:       pluginsLogLevel,
+				DisableTime: true,
+			}),
+		},
+	})
+	rpcClient, err := client.Client()
+	if err != nil {
+		logger.Debug(logSender, "", "unable to get rpc client for plugin %q: %v", p.config.Cmd, err)
+		return err
+	}
+	raw, err := rpcClient.Dispense(notifier.PluginName)
+	if err != nil {
+		logger.Debug(logSender, "", "unable to get plugin %v from rpc client for command %q: %v",
+			notifier.PluginName, p.config.Cmd, err)
+		return err
+	}
+
+	p.client = client
+	p.notifier = raw.(notifier.Notifier)
+
+	return nil
+}
+
+// queueSize returns the total number of queued events across all categories.
+func (p *notifierPlugin) queueSize() int {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+
+	return len(p.providerEvents) + len(p.fsEvents) + len(p.logEvents)
+}
+
+// queueFsEvent stores a failed fs event for a later retry.
+func (p *notifierPlugin) queueFsEvent(ev *notifier.FsEvent) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	p.fsEvents = append(p.fsEvents, ev)
+}
+
+// queueProviderEvent stores a failed provider event for a later retry.
+func (p *notifierPlugin) queueProviderEvent(ev *notifier.ProviderEvent) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	p.providerEvents = append(p.providerEvents, ev)
+}
+
+// queueLogEvent stores a failed log event for a later retry.
+func (p *notifierPlugin) queueLogEvent(ev *notifier.LogEvent) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	p.logEvents = append(p.logEvents, ev)
+}
+
+// canQueueEvent reports whether an event with the given UnixNano timestamp
+// may still be queued for retry, honoring RetryMaxTime and RetryQueueMaxSize.
+func (p *notifierPlugin) canQueueEvent(timestamp int64) bool {
+	if p.config.NotifierOptions.RetryMaxTime == 0 {
+		return false
+	}
+	if time.Now().After(time.Unix(0, timestamp).Add(time.Duration(p.config.NotifierOptions.RetryMaxTime) * time.Second)) {
+		logger.Warn(logSender, "", "dropping too late event for plugin %v, event timestamp: %v",
+			p.config.Cmd, time.Unix(0, timestamp))
+		return false
+	}
+	if p.config.NotifierOptions.RetryQueueMaxSize > 0 {
+		return p.queueSize() < p.config.NotifierOptions.RetryQueueMaxSize
+	}
+	return true
+}
+
+// notifyFsAction sends the event if its action is among the configured fs events.
+func (p *notifierPlugin) notifyFsAction(event *notifier.FsEvent) {
+	if !slices.Contains(p.config.NotifierOptions.FsEvents, event.Action) {
+		return
+	}
+	p.sendFsEvent(event)
+}
+
+// notifyProviderAction sends the event if both its action and object type are configured.
+func (p *notifierPlugin) notifyProviderAction(event *notifier.ProviderEvent, object Renderer) {
+	if !slices.Contains(p.config.NotifierOptions.ProviderEvents, event.Action) ||
+		!slices.Contains(p.config.NotifierOptions.ProviderObjects, event.ObjectType) {
+		return
+	}
+	p.sendProviderEvent(event, object)
+}
+
+// notifyLogEvent sends a log event; filtering happens in the caller.
+func (p *notifierPlugin) notifyLogEvent(event *notifier.LogEvent) {
+	p.sendLogEvent(event)
+}
+
+// sendFsEvent delivers the event asynchronously, queueing it on failure
+// when retries are allowed.
+func (p *notifierPlugin) sendFsEvent(ev *notifier.FsEvent) {
+	go func(event *notifier.FsEvent) {
+		Handler.addTask()
+		defer Handler.removeTask()
+
+		if err := p.notifier.NotifyFsEvent(event); err != nil {
+			logger.Warn(logSender, "", "unable to send fs action notification to plugin %v: %v", p.config.Cmd, err)
+			if p.canQueueEvent(event.Timestamp) {
+				p.queueFsEvent(event)
+			}
+		}
+	}(ev)
+}
+
+// sendProviderEvent delivers the event asynchronously; the object, if any,
+// is rendered to JSON and attached before delivery.
+func (p *notifierPlugin) sendProviderEvent(ev *notifier.ProviderEvent, object Renderer) {
+	go func(event *notifier.ProviderEvent) {
+		Handler.addTask()
+		defer Handler.removeTask()
+
+		if object != nil {
+			objectAsJSON, err := object.RenderAsJSON(event.Action != "delete")
+			if err != nil {
+				logger.Error(logSender, "", "unable to render user as json for action %q: %v", event.Action, err)
+			} else {
+				event.ObjectData = objectAsJSON
+			}
+		}
+
+		if err := p.notifier.NotifyProviderEvent(event); err != nil {
+			logger.Warn(logSender, "", "unable to send user action notification to plugin %v: %v", p.config.Cmd, err)
+			if p.canQueueEvent(event.Timestamp) {
+				p.queueProviderEvent(event)
+			}
+		}
+	}(ev)
+}
+
+// sendLogEvent delivers the event asynchronously, queueing it on failure
+// when retries are allowed.
+func (p *notifierPlugin) sendLogEvent(ev *notifier.LogEvent) {
+	go func(event *notifier.LogEvent) {
+		Handler.addTask()
+		defer Handler.removeTask()
+
+		if err := p.notifier.NotifyLogEvent(event); err != nil {
+			logger.Warn(logSender, "", "unable to send log event to plugin %v: %v", p.config.Cmd, err)
+			if p.canQueueEvent(event.Timestamp) {
+				p.queueLogEvent(event)
+			}
+		}
+	}(ev)
+}
+
+// sendQueuedEvents re-sends all queued events and clears the queues.
+// Delivery is asynchronous, so requeueing on failure does not deadlock
+// even though the queues are locked while dispatching.
+func (p *notifierPlugin) sendQueuedEvents() {
+	queueSize := p.queueSize()
+	if queueSize == 0 {
+		return
+	}
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	logger.Debug(logSender, "", "send queued events for notifier %q, events size: %v", p.config.Cmd, queueSize)
+
+	for _, ev := range p.fsEvents {
+		p.sendFsEvent(ev)
+	}
+	p.fsEvents = nil
+
+	for _, ev := range p.providerEvents {
+		p.sendProviderEvent(ev, nil)
+	}
+	p.providerEvents = nil
+
+	for _, ev := range p.logEvents {
+		p.sendLogEvent(ev)
+	}
+	p.logEvents = nil
+
+	logger.Debug(logSender, "", "%d queued events sent for notifier %q", queueSize, p.config.Cmd)
+}
diff --git a/internal/plugin/plugin.go b/internal/plugin/plugin.go
new file mode 100644
index 00000000..94479c81
--- /dev/null
+++ b/internal/plugin/plugin.go
@@ -0,0 +1,798 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +// Package plugin provides support for the SFTPGo plugin system +package plugin + +import ( + "crypto/sha256" + "crypto/x509" + "encoding/hex" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin" + "github.com/sftpgo/sdk/plugin/auth" + "github.com/sftpgo/sdk/plugin/eventsearcher" + "github.com/sftpgo/sdk/plugin/ipfilter" + kmsplugin "github.com/sftpgo/sdk/plugin/kms" + "github.com/sftpgo/sdk/plugin/notifier" + + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + logSender = "plugins" +) + +var ( + // Handler defines the plugins manager + Handler Manager + pluginsLogLevel = hclog.Debug + // ErrNoSearcher defines the error to return for events searches if no plugin is configured + ErrNoSearcher = errors.New("no events searcher plugin defined") +) + +// Renderer defines the interface for generic objects rendering +type Renderer interface { + RenderAsJSON(reload bool) ([]byte, error) +} + +// Config defines a plugin configuration +type Config struct { + // Plugin type + Type string `json:"type" mapstructure:"type"` + // NotifierOptions defines options for notifiers plugins + NotifierOptions NotifierConfig `json:"notifier_options" mapstructure:"notifier_options"` + // KMSOptions defines options for a KMS plugin + KMSOptions KMSConfig `json:"kms_options" mapstructure:"kms_options"` + // AuthOptions defines options for authentication plugins + AuthOptions AuthConfig `json:"auth_options" mapstructure:"auth_options"` + // Path to the plugin executable + Cmd string `json:"cmd" mapstructure:"cmd"` + // Args to pass to the plugin executable + Args []string `json:"args" mapstructure:"args"` + // SHA256 checksum for the plugin executable. 
+ // If not empty it will be used to verify the integrity of the executable + SHA256Sum string `json:"sha256sum" mapstructure:"sha256sum"` + // If enabled the client and the server automatically negotiate mTLS for + // transport authentication. This ensures that only the original client will + // be allowed to connect to the server, and all other connections will be + // rejected. The client will also refuse to connect to any server that isn't + // the original instance started by the client. + AutoMTLS bool `json:"auto_mtls" mapstructure:"auto_mtls"` + // EnvPrefix defines the prefix for env vars to pass from the SFTPGo process + // environment to the plugin. Set to "none" to not pass any environment + // variable, set to "*" to pass all environment variables. If empty, the + // prefix is returned as the plugin name in uppercase with "-" replaced with + // "_" and a trailing "_". For example if the plugin name is + // sftpgo-plugin-eventsearch the prefix will be SFTPGO_PLUGIN_EVENTSEARCH_ + EnvPrefix string `json:"env_prefix" mapstructure:"env_prefix"` + // Additional environment variable names to pass from the SFTPGo process + // environment to the plugin. 
+ EnvVars []string `json:"env_vars" mapstructure:"env_vars"` + // unique identifier for kms plugins + kmsID int +} + +func (c *Config) getSecureConfig() (*plugin.SecureConfig, error) { + if c.SHA256Sum != "" { + checksum, err := hex.DecodeString(c.SHA256Sum) + if err != nil { + return nil, fmt.Errorf("invalid sha256 hash %q: %w", c.SHA256Sum, err) + } + return &plugin.SecureConfig{ + Checksum: checksum, + Hash: sha256.New(), + }, nil + } + return nil, nil +} + +func (c *Config) getEnvVarPrefix() string { + if c.EnvPrefix == "none" { + return "" + } + if c.EnvPrefix != "" { + return c.EnvPrefix + } + + baseName := filepath.Base(c.Cmd) + name := strings.TrimSuffix(baseName, filepath.Ext(baseName)) + prefix := strings.ToUpper(name) + "_" + return strings.ReplaceAll(prefix, "-", "_") +} + +func (c *Config) getCommand() *exec.Cmd { + cmd := exec.Command(c.Cmd, c.Args...) + cmd.Env = []string{} + + if envVarPrefix := c.getEnvVarPrefix(); envVarPrefix != "" { + if envVarPrefix == "*" { + logger.Debug(logSender, "", "sharing all the environment variables with plugin %q", c.Cmd) + cmd.Env = append(cmd.Env, os.Environ()...) 
+ return cmd + } + logger.Debug(logSender, "", "adding env vars with prefix %q for plugin %q", envVarPrefix, c.Cmd) + for _, val := range os.Environ() { + if strings.HasPrefix(val, envVarPrefix) { + cmd.Env = append(cmd.Env, val) + } + } + } + logger.Debug(logSender, "", "additional env vars for plugin %q: %+v", c.Cmd, c.EnvVars) + for _, key := range c.EnvVars { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, os.Getenv(key))) + } + return cmd +} + +func (c *Config) newKMSPluginSecretProvider(base kms.BaseSecret, url, masterKey string) kms.SecretProvider { + return &kmsPluginSecretProvider{ + BaseSecret: base, + URL: url, + MasterKey: masterKey, + config: c, + } +} + +// Manager handles enabled plugins +type Manager struct { + closed atomic.Bool + done chan bool + // List of configured plugins + Configs []Config `json:"plugins" mapstructure:"plugins"` + notifLock sync.RWMutex + notifiers []*notifierPlugin + kmsLock sync.RWMutex + kms []*kmsPlugin + authLock sync.RWMutex + auths []*authPlugin + searcherLock sync.RWMutex + searcher *searcherPlugin + ipFilterLock sync.RWMutex + filter *ipFilterPlugin + authScopes int + hasSearcher bool + hasNotifiers bool + hasAuths bool + hasIPFilter bool + concurrencyGuard chan struct{} +} + +// Initialize initializes the configured plugins +func Initialize(configs []Config, logLevel string) error { + logger.Debug(logSender, "", "initialize") + Handler = Manager{ + Configs: configs, + done: make(chan bool), + authScopes: -1, + concurrencyGuard: make(chan struct{}, 250), + } + Handler.closed.Store(false) + setLogLevel(logLevel) + if len(configs) == 0 { + return nil + } + + if err := Handler.validateConfigs(); err != nil { + return err + } + if err := initializePlugins(); err != nil { + return err + } + + startCheckTicker() + return nil +} + +func initializePlugins() error { + kmsID := 0 + for idx, config := range Handler.Configs { + switch config.Type { + case notifier.PluginName: + plugin, err := newNotifierPlugin(config) + if 
err != nil { + return err + } + Handler.notifiers = append(Handler.notifiers, plugin) + case kmsplugin.PluginName: + plugin, err := newKMSPlugin(config) + if err != nil { + return err + } + Handler.kms = append(Handler.kms, plugin) + Handler.Configs[idx].kmsID = kmsID + kmsID++ + kms.RegisterSecretProvider(config.KMSOptions.Scheme, config.KMSOptions.EncryptedStatus, + Handler.Configs[idx].newKMSPluginSecretProvider) + logger.Info(logSender, "", "registered secret provider for scheme %q, encrypted status %q", + config.KMSOptions.Scheme, config.KMSOptions.EncryptedStatus) + case auth.PluginName: + plugin, err := newAuthPlugin(config) + if err != nil { + return err + } + Handler.auths = append(Handler.auths, plugin) + if Handler.authScopes == -1 { + Handler.authScopes = config.AuthOptions.Scope + } else { + Handler.authScopes |= config.AuthOptions.Scope + } + case eventsearcher.PluginName: + plugin, err := newSearcherPlugin(config) + if err != nil { + return err + } + Handler.searcher = plugin + case ipfilter.PluginName: + plugin, err := newIPFilterPlugin(config) + if err != nil { + return err + } + Handler.filter = plugin + default: + return fmt.Errorf("unsupported plugin type: %v", config.Type) + } + } + + return nil +} + +func (m *Manager) validateConfigs() error { + kmsSchemes := make(map[string]bool) + kmsEncryptions := make(map[string]bool) + m.hasSearcher = false + m.hasNotifiers = false + m.hasAuths = false + m.hasIPFilter = false + + for _, config := range m.Configs { + switch config.Type { + case kmsplugin.PluginName: + if _, ok := kmsSchemes[config.KMSOptions.Scheme]; ok { + return fmt.Errorf("invalid KMS configuration, duplicated scheme %q", config.KMSOptions.Scheme) + } + if _, ok := kmsEncryptions[config.KMSOptions.EncryptedStatus]; ok { + return fmt.Errorf("invalid KMS configuration, duplicated encrypted status %q", config.KMSOptions.EncryptedStatus) + } + kmsSchemes[config.KMSOptions.Scheme] = true + kmsEncryptions[config.KMSOptions.EncryptedStatus] = 
true + case eventsearcher.PluginName: + if m.hasSearcher { + return errors.New("only one eventsearcher plugin can be defined") + } + m.hasSearcher = true + case notifier.PluginName: + m.hasNotifiers = true + case auth.PluginName: + m.hasAuths = true + case ipfilter.PluginName: + m.hasIPFilter = true + } + } + return nil +} + +// HasAuthenticators returns true if there is at least an auth plugin +func (m *Manager) HasAuthenticators() bool { + return m.hasAuths +} + +// HasNotifiers returns true if there is at least a notifier plugin +func (m *Manager) HasNotifiers() bool { + return m.hasNotifiers +} + +// NotifyFsEvent sends the fs event notifications using any defined notifier plugins +func (m *Manager) NotifyFsEvent(event *notifier.FsEvent) { + m.notifLock.RLock() + defer m.notifLock.RUnlock() + + for _, n := range m.notifiers { + n.notifyFsAction(event) + } +} + +// NotifyProviderEvent sends the provider event notifications using any defined notifier plugins +func (m *Manager) NotifyProviderEvent(event *notifier.ProviderEvent, object Renderer) { + m.notifLock.RLock() + defer m.notifLock.RUnlock() + + for _, n := range m.notifiers { + n.notifyProviderAction(event, object) + } +} + +// NotifyLogEvent sends the log event notifications using any defined notifier plugins +func (m *Manager) NotifyLogEvent(event notifier.LogEventType, protocol, username, ip, role string, err error) { + if !m.hasNotifiers { + return + } + m.notifLock.RLock() + defer m.notifLock.RUnlock() + + var e *notifier.LogEvent + + for _, n := range m.notifiers { + if slices.Contains(n.config.NotifierOptions.LogEvents, int(event)) { + if e == nil { + message := "" + if err != nil { + message = strings.Trim(err.Error(), "\x00") + } + + e = ¬ifier.LogEvent{ + Timestamp: time.Now().UnixNano(), + Event: event, + Protocol: protocol, + Username: username, + IP: ip, + Message: message, + Role: role, + } + } + n.notifyLogEvent(e) + } + } +} + +// HasSearcher returns true if an event searcher plugin is 
defined +func (m *Manager) HasSearcher() bool { + return m.hasSearcher +} + +// SearchFsEvents returns the filesystem events matching the specified filters +func (m *Manager) SearchFsEvents(searchFilters *eventsearcher.FsEventSearch) ([]byte, error) { + if !m.hasSearcher { + return nil, ErrNoSearcher + } + m.searcherLock.RLock() + plugin := m.searcher + m.searcherLock.RUnlock() + + return plugin.searchear.SearchFsEvents(searchFilters) +} + +// SearchProviderEvents returns the provider events matching the specified filters +func (m *Manager) SearchProviderEvents(searchFilters *eventsearcher.ProviderEventSearch) ([]byte, error) { + if !m.hasSearcher { + return nil, ErrNoSearcher + } + m.searcherLock.RLock() + plugin := m.searcher + m.searcherLock.RUnlock() + + return plugin.searchear.SearchProviderEvents(searchFilters) +} + +// SearchLogEvents returns the log events matching the specified filters +func (m *Manager) SearchLogEvents(searchFilters *eventsearcher.LogEventSearch) ([]byte, error) { + if !m.hasSearcher { + return nil, ErrNoSearcher + } + m.searcherLock.RLock() + plugin := m.searcher + m.searcherLock.RUnlock() + + return plugin.searchear.SearchLogEvents(searchFilters) +} + +// IsIPBanned returns true if the IP filter plugin does not allow the specified ip. 
+// If no IP filter plugin is defined this method returns false +func (m *Manager) IsIPBanned(ip, protocol string) bool { + if !m.hasIPFilter { + return false + } + + m.ipFilterLock.RLock() + plugin := m.filter + m.ipFilterLock.RUnlock() + + if plugin.exited() { + logger.Warn(logSender, "", "ip filter plugin is not active, cannot check ip %q", ip) + return false + } + + return plugin.filter.CheckIP(ip, protocol) != nil +} + +// ReloadFilter sends a reload request to the IP filter plugin +func (m *Manager) ReloadFilter() { + if !m.hasIPFilter { + return + } + m.ipFilterLock.RLock() + plugin := m.filter + m.ipFilterLock.RUnlock() + + if err := plugin.filter.Reload(); err != nil { + logger.Error(logSender, "", "unable to reload IP filter plugin: %v", err) + } +} + +func (m *Manager) kmsEncrypt(secret kms.BaseSecret, url string, masterKey string, kmsID int) (string, string, int32, error) { + m.kmsLock.RLock() + plugin := m.kms[kmsID] + m.kmsLock.RUnlock() + + return plugin.Encrypt(secret, url, masterKey) +} + +func (m *Manager) kmsDecrypt(secret kms.BaseSecret, url string, masterKey string, kmsID int) (string, error) { + m.kmsLock.RLock() + plugin := m.kms[kmsID] + m.kmsLock.RUnlock() + + return plugin.Decrypt(secret, url, masterKey) +} + +// HasAuthScope returns true if there is an auth plugin that support the specified scope +func (m *Manager) HasAuthScope(scope int) bool { + if m.authScopes == -1 { + return false + } + return m.authScopes&scope != 0 +} + +// Authenticate tries to authenticate the specified user using an external plugin +func (m *Manager) Authenticate(username, password, ip, protocol string, pkey string, + tlsCert *x509.Certificate, authScope int, userAsJSON []byte, +) ([]byte, error) { + switch authScope { + case AuthScopePassword: + return m.checkUserAndPass(username, password, ip, protocol, userAsJSON) + case AuthScopePublicKey: + return m.checkUserAndPublicKey(username, pkey, ip, protocol, userAsJSON) + case AuthScopeKeyboardInteractive: + return 
m.checkUserAndKeyboardInteractive(username, ip, protocol, userAsJSON) + case AuthScopeTLSCertificate: + cert, err := util.EncodeTLSCertToPem(tlsCert) + if err != nil { + logger.Error(logSender, "", "unable to encode tls certificate to pem: %v", err) + return nil, fmt.Errorf("unable to encode tls cert to pem: %w", err) + } + return m.checkUserAndTLSCert(username, cert, ip, protocol, userAsJSON) + default: + return nil, fmt.Errorf("unsupported auth scope: %v", authScope) + } +} + +// ExecuteKeyboardInteractiveStep executes a keyboard interactive step +func (m *Manager) ExecuteKeyboardInteractiveStep(req *KeyboardAuthRequest) (*KeyboardAuthResponse, error) { + var plugin *authPlugin + + m.authLock.Lock() + for _, p := range m.auths { + if p.config.AuthOptions.Scope&AuthScopePassword != 0 { + plugin = p + break + } + } + m.authLock.Unlock() + + if plugin == nil { + return nil, errors.New("no auth plugin configured for keyboard interactive authentication step") + } + + return plugin.sendKeyboardIteractiveRequest(req) +} + +func (m *Manager) checkUserAndPass(username, password, ip, protocol string, userAsJSON []byte) ([]byte, error) { + var plugin *authPlugin + + m.authLock.Lock() + for _, p := range m.auths { + if p.config.AuthOptions.Scope&AuthScopePassword != 0 { + plugin = p + break + } + } + m.authLock.Unlock() + + if plugin == nil { + return nil, errors.New("no auth plugin configured for password checking") + } + + return plugin.checkUserAndPass(username, password, ip, protocol, userAsJSON) +} + +func (m *Manager) checkUserAndPublicKey(username, pubKey, ip, protocol string, userAsJSON []byte) ([]byte, error) { + var plugin *authPlugin + + m.authLock.Lock() + for _, p := range m.auths { + if p.config.AuthOptions.Scope&AuthScopePublicKey != 0 { + plugin = p + break + } + } + m.authLock.Unlock() + + if plugin == nil { + return nil, errors.New("no auth plugin configured for public key checking") + } + + return plugin.checkUserAndPublicKey(username, pubKey, ip, 
protocol, userAsJSON) +} + +func (m *Manager) checkUserAndTLSCert(username, tlsCert, ip, protocol string, userAsJSON []byte) ([]byte, error) { + var plugin *authPlugin + + m.authLock.Lock() + for _, p := range m.auths { + if p.config.AuthOptions.Scope&AuthScopeTLSCertificate != 0 { + plugin = p + break + } + } + m.authLock.Unlock() + + if plugin == nil { + return nil, errors.New("no auth plugin configured for TLS certificate checking") + } + + return plugin.checkUserAndTLSCertificate(username, tlsCert, ip, protocol, userAsJSON) +} + +func (m *Manager) checkUserAndKeyboardInteractive(username, ip, protocol string, userAsJSON []byte) ([]byte, error) { + var plugin *authPlugin + + m.authLock.Lock() + for _, p := range m.auths { + if p.config.AuthOptions.Scope&AuthScopeKeyboardInteractive != 0 { + plugin = p + break + } + } + m.authLock.Unlock() + + if plugin == nil { + return nil, errors.New("no auth plugin configured for keyboard interactive checking") + } + + return plugin.checkUserAndKeyboardInteractive(username, ip, protocol, userAsJSON) +} + +func (m *Manager) checkCrashedPlugins() { + m.notifLock.RLock() + for idx, n := range m.notifiers { + if n.exited() { + defer func(cfg Config, index int) { + Handler.restartNotifierPlugin(cfg, index) + }(n.config, idx) + } else { + n.sendQueuedEvents() + } + } + m.notifLock.RUnlock() + + m.kmsLock.RLock() + for idx, k := range m.kms { + if k.exited() { + defer func(cfg Config, index int) { + Handler.restartKMSPlugin(cfg, index) + }(k.config, idx) + } + } + m.kmsLock.RUnlock() + + m.authLock.RLock() + for idx, a := range m.auths { + if a.exited() { + defer func(cfg Config, index int) { + Handler.restartAuthPlugin(cfg, index) + }(a.config, idx) + } + } + m.authLock.RUnlock() + + if m.hasSearcher { + m.searcherLock.RLock() + if m.searcher.exited() { + defer func(cfg Config) { + Handler.restartSearcherPlugin(cfg) + }(m.searcher.config) + } + m.searcherLock.RUnlock() + } + + if m.hasIPFilter { + m.ipFilterLock.RLock() + if 
m.filter.exited() { + defer func(cfg Config) { + Handler.restartIPFilterPlugin(cfg) + }(m.filter.config) + } + m.ipFilterLock.RUnlock() + } +} + +func (m *Manager) restartNotifierPlugin(config Config, idx int) { + if m.closed.Load() { + return + } + logger.Info(logSender, "", "try to restart crashed notifier plugin %q, idx: %v", config.Cmd, idx) + plugin, err := newNotifierPlugin(config) + if err != nil { + logger.Error(logSender, "", "unable to restart notifier plugin %q, err: %v", config.Cmd, err) + return + } + + m.notifLock.Lock() + plugin.fsEvents = m.notifiers[idx].fsEvents + plugin.providerEvents = m.notifiers[idx].providerEvents + plugin.logEvents = m.notifiers[idx].logEvents + m.notifiers[idx] = plugin + m.notifLock.Unlock() + plugin.sendQueuedEvents() +} + +func (m *Manager) restartKMSPlugin(config Config, idx int) { + if m.closed.Load() { + return + } + logger.Info(logSender, "", "try to restart crashed kms plugin %q, idx: %v", config.Cmd, idx) + plugin, err := newKMSPlugin(config) + if err != nil { + logger.Error(logSender, "", "unable to restart kms plugin %q, err: %v", config.Cmd, err) + return + } + + m.kmsLock.Lock() + m.kms[idx] = plugin + m.kmsLock.Unlock() +} + +func (m *Manager) restartAuthPlugin(config Config, idx int) { + if m.closed.Load() { + return + } + logger.Info(logSender, "", "try to restart crashed auth plugin %q, idx: %v", config.Cmd, idx) + plugin, err := newAuthPlugin(config) + if err != nil { + logger.Error(logSender, "", "unable to restart auth plugin %q, err: %v", config.Cmd, err) + return + } + + m.authLock.Lock() + m.auths[idx] = plugin + m.authLock.Unlock() +} + +func (m *Manager) restartSearcherPlugin(config Config) { + if m.closed.Load() { + return + } + logger.Info(logSender, "", "try to restart crashed searcher plugin %q", config.Cmd) + plugin, err := newSearcherPlugin(config) + if err != nil { + logger.Error(logSender, "", "unable to restart searcher plugin %q, err: %v", config.Cmd, err) + return + } + + 
m.searcherLock.Lock() + m.searcher = plugin + m.searcherLock.Unlock() +} + +func (m *Manager) restartIPFilterPlugin(config Config) { + if m.closed.Load() { + return + } + logger.Info(logSender, "", "try to restart crashed IP filter plugin %q", config.Cmd) + plugin, err := newIPFilterPlugin(config) + if err != nil { + logger.Error(logSender, "", "unable to restart IP filter plugin %q, err: %v", config.Cmd, err) + return + } + + m.ipFilterLock.Lock() + m.filter = plugin + m.ipFilterLock.Unlock() +} + +func (m *Manager) addTask() { + m.concurrencyGuard <- struct{}{} +} + +func (m *Manager) removeTask() { + <-m.concurrencyGuard +} + +// Cleanup releases all the active plugins +func (m *Manager) Cleanup() { + if m.closed.Swap(true) { + return + } + logger.Debug(logSender, "", "cleanup") + close(m.done) + m.notifLock.Lock() + for _, n := range m.notifiers { + logger.Debug(logSender, "", "cleanup notifier plugin %v", n.config.Cmd) + n.cleanup() + } + m.notifLock.Unlock() + + m.kmsLock.Lock() + for _, k := range m.kms { + logger.Debug(logSender, "", "cleanup kms plugin %v", k.config.Cmd) + k.cleanup() + } + m.kmsLock.Unlock() + + m.authLock.Lock() + for _, a := range m.auths { + logger.Debug(logSender, "", "cleanup auth plugin %v", a.config.Cmd) + a.cleanup() + } + m.authLock.Unlock() + + if m.hasSearcher { + m.searcherLock.Lock() + logger.Debug(logSender, "", "cleanup searcher plugin %v", m.searcher.config.Cmd) + m.searcher.cleanup() + m.searcherLock.Unlock() + } + + if m.hasIPFilter { + m.ipFilterLock.Lock() + logger.Debug(logSender, "", "cleanup IP filter plugin %v", m.filter.config.Cmd) + m.filter.cleanup() + m.ipFilterLock.Unlock() + } +} + +func setLogLevel(logLevel string) { + switch logLevel { + case "info": + pluginsLogLevel = hclog.Info + case "warn": + pluginsLogLevel = hclog.Warn + case "error": + pluginsLogLevel = hclog.Error + default: + pluginsLogLevel = hclog.Debug + } +} + +func startCheckTicker() { + logger.Debug(logSender, "", "start plugins checker") + 
+ go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-Handler.done: + logger.Debug(logSender, "", "handler done, stop plugins checker") + return + case <-ticker.C: + Handler.checkCrashedPlugins() + } + } + }() +} diff --git a/internal/plugin/searcher.go b/internal/plugin/searcher.go new file mode 100644 index 00000000..013c12bd --- /dev/null +++ b/internal/plugin/searcher.go @@ -0,0 +1,94 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package plugin + +import ( + "fmt" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/go-plugin" + "github.com/sftpgo/sdk/plugin/eventsearcher" + + "github.com/drakkan/sftpgo/v2/internal/logger" +) + +type searcherPlugin struct { + config Config + searchear eventsearcher.Searcher + client *plugin.Client +} + +func newSearcherPlugin(config Config) (*searcherPlugin, error) { + p := &searcherPlugin{ + config: config, + } + if err := p.initialize(); err != nil { + logger.Warn(logSender, "", "unable to create events searcher plugin: %v, config %+v", err, config) + return nil, err + } + return p, nil +} + +func (p *searcherPlugin) exited() bool { + return p.client.Exited() +} + +func (p *searcherPlugin) cleanup() { + p.client.Kill() +} + +func (p *searcherPlugin) initialize() error { + killProcess(p.config.Cmd) + logger.Debug(logSender, "", "create new searcher plugin %q", p.config.Cmd) + secureConfig, err := p.config.getSecureConfig() + if err != nil { + return err + } + client := plugin.NewClient(&plugin.ClientConfig{ + HandshakeConfig: eventsearcher.Handshake, + Plugins: eventsearcher.PluginMap, + Cmd: p.config.getCommand(), + SkipHostEnv: true, + AllowedProtocols: []plugin.Protocol{ + plugin.ProtocolGRPC, + }, + AutoMTLS: p.config.AutoMTLS, + SecureConfig: secureConfig, + Managed: false, + Logger: &logger.HCLogAdapter{ + Logger: hclog.New(&hclog.LoggerOptions{ + Name: fmt.Sprintf("%v.%v", logSender, eventsearcher.PluginName), + Level: pluginsLogLevel, + DisableTime: true, + }), + }, + }) + rpcClient, err := client.Client() + if err != nil { + logger.Debug(logSender, "", "unable to get rpc client for plugin %q: %v", p.config.Cmd, err) + return err + } + raw, err := rpcClient.Dispense(eventsearcher.PluginName) + if err != nil { + logger.Debug(logSender, "", "unable to get plugin %v from rpc client for command %q: %v", + eventsearcher.PluginName, p.config.Cmd, err) + return err + } + + p.client = client + p.searchear = raw.(eventsearcher.Searcher) + + return 
nil +} diff --git a/internal/plugin/util.go b/internal/plugin/util.go new file mode 100644 index 00000000..8063e88b --- /dev/null +++ b/internal/plugin/util.go @@ -0,0 +1,39 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package plugin + +import ( + "github.com/shirou/gopsutil/v3/process" + + "github.com/drakkan/sftpgo/v2/internal/logger" +) + +func killProcess(processPath string) { + procs, err := process.Processes() + if err != nil { + return + } + for _, p := range procs { + cmdLine, err := p.Exe() + if err == nil { + if cmdLine == processPath { + err = p.Kill() + logger.Debug(logSender, "", "killed process %v, pid %v, err %v", cmdLine, p.Pid, err) + return + } + } + } + logger.Debug(logSender, "", "no match for plugin process %v", processPath) +} diff --git a/internal/service/service.go b/internal/service/service.go new file mode 100644 index 00000000..6cb9db8b --- /dev/null +++ b/internal/service/service.go @@ -0,0 +1,400 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package service allows to start and stop the SFTPGo service +package service + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/rs/zerolog" + + "github.com/drakkan/sftpgo/v2/internal/acme" + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpd" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" +) + +const ( + logSender = "service" +) + +var ( + graceTime int +) + +// Service defines the SFTPGo service +type Service struct { + ConfigDir string + ConfigFile string + LogFilePath string + LogMaxSize int + LogMaxBackups int + LogMaxAge int + PortableMode int + PortableUser dataprovider.User + LogCompress bool + LogLevel string + LogUTCTime bool + LoadDataClean bool + LoadDataFrom string + LoadDataMode int + LoadDataQuotaScan int + Shutdown chan bool + Error error +} + +func (s *Service) initLogger() { + var logLevel zerolog.Level + switch s.LogLevel { + case "info": + logLevel = zerolog.InfoLevel + case "warn": + logLevel = zerolog.WarnLevel + case "error": + logLevel = zerolog.ErrorLevel + default: + logLevel = zerolog.DebugLevel + } + if !filepath.IsAbs(s.LogFilePath) && util.IsFileInputValid(s.LogFilePath) { + s.LogFilePath = filepath.Join(s.ConfigDir, s.LogFilePath) + } + logger.InitLogger(s.LogFilePath, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogCompress, s.LogUTCTime, logLevel) + if s.PortableMode == 1 { + logger.EnableConsoleLogger(logLevel) + if s.LogFilePath == "" { + logger.DisableLogger() + } + } +} + +// Start initializes and starts the service +func (s 
*Service) Start() error { + s.initLogger() + logger.Info(logSender, "", "starting SFTPGo %s, config dir: %s, config file: %s, log max size: %d log max backups: %d "+ + "log max age: %d log level: %s, log compress: %t, log utc time: %t, load data from: %q, grace time: %d secs", + version.GetAsString(), s.ConfigDir, s.ConfigFile, s.LogMaxSize, s.LogMaxBackups, s.LogMaxAge, s.LogLevel, + s.LogCompress, s.LogUTCTime, s.LoadDataFrom, graceTime) + // in portable mode we don't read configuration from file + if s.PortableMode != 1 { + err := config.LoadConfig(s.ConfigDir, s.ConfigFile) + if err != nil { + logger.Error(logSender, "", "error loading configuration: %v", err) + return err + } + } + if !config.HasServicesToStart() { + const infoString = "no service configured, nothing to do" + logger.Info(logSender, "", infoString) + logger.InfoToConsole(infoString) + return errors.New(infoString) + } + + if err := s.initializeServices(); err != nil { + return err + } + + s.startServices() + go common.Config.ExecuteStartupHook() //nolint:errcheck + + return nil +} + +func (s *Service) initializeServices() error { + providerConf := config.GetProviderConf() + kmsConfig := config.GetKMSConfig() + err := kmsConfig.Initialize() + if err != nil { + logger.Error(logSender, "", "unable to initialize KMS: %v", err) + logger.ErrorToConsole("unable to initialize KMS: %v", err) + return err + } + // We may have KMS plugins and their schema needs to be registered before + // initializing the data provider which may contain KMS secrets. 
+ if err := plugin.Initialize(config.GetPluginsConfig(), s.LogLevel); err != nil { + logger.Error(logSender, "", "unable to initialize plugin system: %v", err) + logger.ErrorToConsole("unable to initialize plugin system: %v", err) + return err + } + mfaConfig := config.GetMFAConfig() + err = mfaConfig.Initialize() + if err != nil { + logger.Error(logSender, "", "unable to initialize MFA: %v", err) + logger.ErrorToConsole("unable to initialize MFA: %v", err) + return err + } + err = dataprovider.Initialize(providerConf, s.ConfigDir, s.PortableMode == 0) + if err != nil { + logger.Error(logSender, "", "error initializing data provider: %v", err) + logger.ErrorToConsole("error initializing data provider: %v", err) + return err + } + smtpConfig := config.GetSMTPConfig() + err = smtpConfig.Initialize(s.ConfigDir, s.PortableMode != 1) + if err != nil { + logger.Error(logSender, "", "unable to initialize SMTP configuration: %v", err) + logger.ErrorToConsole("unable to initialize SMTP configuration: %v", err) + return err + } + err = common.Initialize(config.GetCommonConfig(), providerConf.GetShared()) + if err != nil { + logger.Error(logSender, "", "%v", err) + logger.ErrorToConsole("%v", err) + return err + } + + if s.PortableMode == 1 { + // create the user for portable mode + err = dataprovider.AddUser(&s.PortableUser, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + logger.ErrorToConsole("error adding portable user: %v", err) + return err + } + } else { + acmeConfig := config.GetACMEConfig() + err = acme.Initialize(acmeConfig, s.ConfigDir, true) + if err != nil { + logger.Error(logSender, "", "error initializing ACME configuration: %v", err) + logger.ErrorToConsole("error initializing ACME configuration: %v", err) + return err + } + } + + httpConfig := config.GetHTTPConfig() + err = httpConfig.Initialize(s.ConfigDir) + if err != nil { + logger.Error(logSender, "", "error initializing http client: %v", err) + logger.ErrorToConsole("error initializing http 
client: %v", err) + return err + } + commandConfig := config.GetCommandConfig() + if err := commandConfig.Initialize(); err != nil { + logger.Error(logSender, "", "error initializing commands configuration: %v", err) + logger.ErrorToConsole("error initializing commands configuration: %v", err) + return err + } + + return nil +} + +func (s *Service) startServices() { + err := s.LoadInitialData() + if err != nil { + logger.Error(logSender, "", "unable to load initial data: %v", err) + logger.ErrorToConsole("unable to load initial data: %v", err) + } + + sftpdConf := config.GetSFTPDConfig() + ftpdConf := config.GetFTPDConfig() + httpdConf := config.GetHTTPDConfig() + webDavDConf := config.GetWebDAVDConfig() + telemetryConf := config.GetTelemetryConfig() + + if sftpdConf.ShouldBind() { + go func() { + redactedConf := sftpdConf + redactedConf.KeyboardInteractiveHook = util.GetRedactedURL(sftpdConf.KeyboardInteractiveHook) + logger.Info(logSender, "", "initializing SFTP server with config %+v", redactedConf) + if err := sftpdConf.Initialize(s.ConfigDir); err != nil { + logger.Error(logSender, "", "could not start SFTP server: %v", err) + logger.ErrorToConsole("could not start SFTP server: %v", err) + s.Error = err + } + s.Shutdown <- true + }() + } else { + logger.Info(logSender, "", "SFTP server not started, disabled in config file") + } + + if httpdConf.ShouldBind() { + go func() { + providerConf := config.GetProviderConf() + if err := httpdConf.Initialize(s.ConfigDir, providerConf.GetShared()); err != nil { + logger.Error(logSender, "", "could not start HTTP server: %v", err) + logger.ErrorToConsole("could not start HTTP server: %v", err) + s.Error = err + } + s.Shutdown <- true + }() + } else { + logger.Info(logSender, "", "HTTP server not started, disabled in config file") + if s.PortableMode != 1 { + logger.InfoToConsole("HTTP server not started, disabled in config file") + } + } + if ftpdConf.ShouldBind() { + go func() { + if err := 
ftpdConf.Initialize(s.ConfigDir); err != nil { + logger.Error(logSender, "", "could not start FTP server: %v", err) + logger.ErrorToConsole("could not start FTP server: %v", err) + s.Error = err + } + s.Shutdown <- true + }() + } else { + logger.Info(logSender, "", "FTP server not started, disabled in config file") + } + if webDavDConf.ShouldBind() { + go func() { + if err := webDavDConf.Initialize(s.ConfigDir); err != nil { + logger.Error(logSender, "", "could not start WebDAV server: %v", err) + logger.ErrorToConsole("could not start WebDAV server: %v", err) + s.Error = err + } + s.Shutdown <- true + }() + } else { + logger.Info(logSender, "", "WebDAV server not started, disabled in config file") + } + if telemetryConf.ShouldBind() { + go func() { + if err := telemetryConf.Initialize(s.ConfigDir); err != nil { + logger.Error(logSender, "", "could not start telemetry server: %v", err) + logger.ErrorToConsole("could not start telemetry server: %v", err) + s.Error = err + } + s.Shutdown <- true + }() + } else { + logger.Info(logSender, "", "telemetry server not started, disabled in config file") + if s.PortableMode != 1 { + logger.InfoToConsole("telemetry server not started, disabled in config file") + } + } +} + +// Wait blocks until the service exits +func (s *Service) Wait() { + if s.PortableMode != 1 { + registerSignals() + } + <-s.Shutdown +} + +// Stop terminates the service unblocking the Wait method +func (s *Service) Stop() { + close(s.Shutdown) + logger.Debug(logSender, "", "Service stopped") +} + +// LoadInitialData if a data file is set +func (s *Service) LoadInitialData() error { + if s.LoadDataFrom == "" { + return nil + } + if !filepath.IsAbs(s.LoadDataFrom) { + return fmt.Errorf("invalid input_file %q, it must be an absolute path", s.LoadDataFrom) + } + if s.LoadDataMode < 0 || s.LoadDataMode > 1 { + return fmt.Errorf("invalid loaddata-mode %v", s.LoadDataMode) + } + if s.LoadDataQuotaScan < 0 || s.LoadDataQuotaScan > 2 { + return fmt.Errorf("invalid 
loaddata-scan %v", s.LoadDataQuotaScan) + } + info, err := os.Stat(s.LoadDataFrom) + if err != nil { + return fmt.Errorf("unable to stat file %q: %w", s.LoadDataFrom, err) + } + if info.Size() > httpd.MaxRestoreSize { + return fmt.Errorf("unable to restore input file %q size too big: %d/%d bytes", + s.LoadDataFrom, info.Size(), httpd.MaxRestoreSize) + } + content, err := os.ReadFile(s.LoadDataFrom) + if err != nil { + return fmt.Errorf("unable to read input file %q: %w", s.LoadDataFrom, err) + } + dump, err := dataprovider.ParseDumpData(content) + if err != nil { + return fmt.Errorf("unable to parse file to restore %q: %w", s.LoadDataFrom, err) + } + err = s.restoreDump(&dump) + if err != nil { + return err + } + logger.Info(logSender, "", "data loaded from file %q mode: %v", s.LoadDataFrom, s.LoadDataMode) + logger.InfoToConsole("data loaded from file %q mode: %v", s.LoadDataFrom, s.LoadDataMode) + if s.LoadDataClean { + err = os.Remove(s.LoadDataFrom) + if err == nil { + logger.Info(logSender, "", "file %q deleted after successful load", s.LoadDataFrom) + logger.InfoToConsole("file %q deleted after successful load", s.LoadDataFrom) + } else { + logger.Warn(logSender, "", "unable to delete file %q after successful load: %v", s.LoadDataFrom, err) + logger.WarnToConsole("unable to delete file %q after successful load: %v", s.LoadDataFrom, err) + } + } + return nil +} + +func (s *Service) restoreDump(dump *dataprovider.BackupData) error { + err := httpd.RestoreConfigs(dump.Configs, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore configs from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreIPListEntries(dump.IPLists, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore IP list entries from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreRoles(dump.Roles, s.LoadDataFrom, s.LoadDataMode, 
dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore roles from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreFolders(dump.Folders, s.LoadDataFrom, s.LoadDataMode, s.LoadDataQuotaScan, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore folders from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreGroups(dump.Groups, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore groups from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreUsers(dump.Users, s.LoadDataFrom, s.LoadDataMode, s.LoadDataQuotaScan, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore users from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreAdmins(dump.Admins, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore admins from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreAPIKeys(dump.APIKeys, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore API keys from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreShares(dump.Shares, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore shares from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreEventActions(dump.EventActions, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, "", "") + if err != nil { + return fmt.Errorf("unable to restore event actions from file %q: %v", s.LoadDataFrom, err) + } + err = httpd.RestoreEventRules(dump.EventRules, s.LoadDataFrom, s.LoadDataMode, dataprovider.ActionExecutorSystem, + "", "", dump.Version) + if err != nil { + return fmt.Errorf("unable to restore event rules from 
file %q: %v", s.LoadDataFrom, err) + } + return nil +} + +// SetGraceTime sets the grace time +func SetGraceTime(val int) { + graceTime = val +} diff --git a/internal/service/service_portable.go b/internal/service/service_portable.go new file mode 100644 index 00000000..91d5fc79 --- /dev/null +++ b/internal/service/service_portable.go @@ -0,0 +1,280 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build !noportable + +package service + +import ( + "fmt" + "math/rand" + "slices" + "strings" + + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/ftpd" + "github.com/drakkan/sftpgo/v2/internal/httpd" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +// StartPortableMode starts the service in portable mode +func (s *Service) StartPortableMode(sftpdPort, ftpPort, webdavPort, httpPort int, enabledSSHCommands []string, + ftpsCert, ftpsKey, webDavCert, webDavKey, httpsCert, httpsKey string) error { + if s.PortableMode != 1 { + return fmt.Errorf("service is not configured for portable mode") + } + err := config.LoadConfig(s.ConfigDir, s.ConfigFile) + if err != nil { + fmt.Printf("error 
loading configuration file: %v using defaults\n", err) + } + kmsConfig := config.GetKMSConfig() + err = kmsConfig.Initialize() + if err != nil { + return err + } + printablePassword := s.configurePortableUser() + dataProviderConf := config.GetProviderConf() + dataProviderConf.Driver = dataprovider.MemoryDataProviderName + dataProviderConf.Name = "" + config.SetProviderConf(dataProviderConf) + httpdConf := config.GetHTTPDConfig() + for idx := range httpdConf.Bindings { + httpdConf.Bindings[idx].Port = 0 + } + config.SetHTTPDConfig(httpdConf) + telemetryConf := config.GetTelemetryConfig() + telemetryConf.BindPort = 0 + config.SetTelemetryConfig(telemetryConf) + + configurePortableSFTPService(sftpdPort, enabledSSHCommands) + configurePortableFTPService(ftpPort, ftpsCert, ftpsKey) + configurePortableWebDAVService(webdavPort, webDavCert, webDavKey) + configurePortableHTTPService(httpPort, httpsCert, httpsKey) + + err = s.Start() + if err != nil { + return err + } + if httpPort >= 0 { + admin := &dataprovider.Admin{ + Username: util.GenerateUniqueID(), + Password: util.GenerateUniqueID(), + Status: 0, + Permissions: []string{dataprovider.PermAdminAny}, + } + if err := dataprovider.AddAdmin(admin, dataprovider.ActionExecutorSystem, "", ""); err != nil { + return err + } + } + + logger.InfoToConsole("Portable mode ready, user: %q, password: %q, public keys: %v, directory: %q, "+ + "permissions: %+v, file patterns filters: %+v %v", s.PortableUser.Username, + printablePassword, s.PortableUser.PublicKeys, s.getPortableDirToServe(), s.PortableUser.Permissions, + s.PortableUser.Filters.FilePatterns, s.getServiceOptionalInfoString()) + return nil +} + +func (s *Service) getServiceOptionalInfoString() string { + var info strings.Builder + if config.GetSFTPDConfig().Bindings[0].IsValid() { + fmt.Fprintf(&info, "SFTP port: %d ", config.GetSFTPDConfig().Bindings[0].Port) + } + if config.GetFTPDConfig().Bindings[0].IsValid() { + fmt.Fprintf(&info, "FTP port: %d ", 
config.GetFTPDConfig().Bindings[0].Port) + } + if config.GetWebDAVDConfig().Bindings[0].IsValid() { + scheme := "http" + if config.GetWebDAVDConfig().CertificateFile != "" && config.GetWebDAVDConfig().CertificateKeyFile != "" { + scheme = "https" + } + fmt.Fprintf(&info, "WebDAV URL: %v://:%v/ ", scheme, config.GetWebDAVDConfig().Bindings[0].Port) + } + if config.GetHTTPDConfig().Bindings[0].IsValid() { + scheme := "http" + if config.GetHTTPDConfig().CertificateFile != "" && config.GetHTTPDConfig().CertificateKeyFile != "" { + scheme = "https" + } + fmt.Fprintf(&info, "WebClient URL: %s://:%d/ ", scheme, config.GetHTTPDConfig().Bindings[0].Port) + } + return info.String() +} + +func (s *Service) getPortableDirToServe() string { + switch s.PortableUser.FsConfig.Provider { + case sdk.S3FilesystemProvider: + return s.PortableUser.FsConfig.S3Config.KeyPrefix + case sdk.GCSFilesystemProvider: + return s.PortableUser.FsConfig.GCSConfig.KeyPrefix + case sdk.AzureBlobFilesystemProvider: + return s.PortableUser.FsConfig.AzBlobConfig.KeyPrefix + case sdk.SFTPFilesystemProvider: + return s.PortableUser.FsConfig.SFTPConfig.Prefix + case sdk.HTTPFilesystemProvider: + return "/" + default: + return s.PortableUser.HomeDir + } +} + +// configures the portable user and return the printable password if any +func (s *Service) configurePortableUser() string { + if s.PortableUser.Username == "" { + s.PortableUser.Username = "user" + } + printablePassword := "" + if s.PortableUser.Password != "" { + printablePassword = "[redacted]" + } + if len(s.PortableUser.PublicKeys) == 0 && s.PortableUser.Password == "" { + s.PortableUser.Password = util.GenerateUniqueID() + printablePassword = s.PortableUser.Password + } + s.PortableUser.Filters.WebClient = []string{sdk.WebClientSharesDisabled, sdk.WebClientInfoChangeDisabled, + sdk.WebClientPubKeyChangeDisabled, sdk.WebClientPasswordChangeDisabled, sdk.WebClientAPIKeyAuthChangeDisabled, + sdk.WebClientMFADisabled, 
sdk.WebClientPasswordResetDisabled, sdk.WebClientTLSCertChangeDisabled, + } + if !s.PortableUser.HasAnyPerm([]string{dataprovider.PermUpload, dataprovider.PermOverwrite}, "/") { + s.PortableUser.Filters.WebClient = append(s.PortableUser.Filters.WebClient, sdk.WebClientWriteDisabled) + } + s.configurePortableSecrets() + return printablePassword +} + +func (s *Service) configurePortableSecrets() { + // we created the user before to initialize the KMS so we need to create the secret here + switch s.PortableUser.FsConfig.Provider { + case sdk.S3FilesystemProvider: + payload := s.PortableUser.FsConfig.S3Config.AccessSecret.GetPayload() + s.PortableUser.FsConfig.S3Config.AccessSecret = getSecretFromString(payload) + case sdk.GCSFilesystemProvider: + payload := s.PortableUser.FsConfig.GCSConfig.Credentials.GetPayload() + s.PortableUser.FsConfig.GCSConfig.Credentials = getSecretFromString(payload) + case sdk.AzureBlobFilesystemProvider: + payload := s.PortableUser.FsConfig.AzBlobConfig.AccountKey.GetPayload() + s.PortableUser.FsConfig.AzBlobConfig.AccountKey = getSecretFromString(payload) + payload = s.PortableUser.FsConfig.AzBlobConfig.SASURL.GetPayload() + s.PortableUser.FsConfig.AzBlobConfig.SASURL = getSecretFromString(payload) + case sdk.CryptedFilesystemProvider: + payload := s.PortableUser.FsConfig.CryptConfig.Passphrase.GetPayload() + s.PortableUser.FsConfig.CryptConfig.Passphrase = getSecretFromString(payload) + case sdk.SFTPFilesystemProvider: + payload := s.PortableUser.FsConfig.SFTPConfig.Password.GetPayload() + s.PortableUser.FsConfig.SFTPConfig.Password = getSecretFromString(payload) + payload = s.PortableUser.FsConfig.SFTPConfig.PrivateKey.GetPayload() + s.PortableUser.FsConfig.SFTPConfig.PrivateKey = getSecretFromString(payload) + payload = s.PortableUser.FsConfig.SFTPConfig.KeyPassphrase.GetPayload() + s.PortableUser.FsConfig.SFTPConfig.KeyPassphrase = getSecretFromString(payload) + case sdk.HTTPFilesystemProvider: + payload := 
s.PortableUser.FsConfig.HTTPConfig.Password.GetPayload() + s.PortableUser.FsConfig.HTTPConfig.Password = getSecretFromString(payload) + payload = s.PortableUser.FsConfig.HTTPConfig.APIKey.GetPayload() + s.PortableUser.FsConfig.HTTPConfig.APIKey = getSecretFromString(payload) + } +} + +func getSecretFromString(payload string) *kms.Secret { + if payload != "" { + return kms.NewPlainSecret(payload) + } + return kms.NewEmptySecret() +} + +func configurePortableSFTPService(port int, enabledSSHCommands []string) { + sftpdConf := config.GetSFTPDConfig() + if len(sftpdConf.Bindings) == 0 { + sftpdConf.Bindings = append(sftpdConf.Bindings, sftpd.Binding{}) + } + if port > 0 { + sftpdConf.Bindings[0].Port = port + } else if port == 0 { + // dynamic ports starts from 49152 + sftpdConf.Bindings[0].Port = 49152 + rand.Intn(15000) + } else { + sftpdConf.Bindings[0].Port = 0 + } + if slices.Contains(enabledSSHCommands, "*") { + sftpdConf.EnabledSSHCommands = sftpd.GetSupportedSSHCommands() + } else { + sftpdConf.EnabledSSHCommands = enabledSSHCommands + } + config.SetSFTPDConfig(sftpdConf) +} + +func configurePortableFTPService(port int, cert, key string) { + ftpConf := config.GetFTPDConfig() + if len(ftpConf.Bindings) == 0 { + ftpConf.Bindings = append(ftpConf.Bindings, ftpd.Binding{}) + } + if port > 0 { + ftpConf.Bindings[0].Port = port + } else if port == 0 { + ftpConf.Bindings[0].Port = 49152 + rand.Intn(15000) + } else { + ftpConf.Bindings[0].Port = 0 + } + ftpConf.Bindings[0].CertificateFile = cert + ftpConf.Bindings[0].CertificateKeyFile = key + config.SetFTPDConfig(ftpConf) +} + +func configurePortableWebDAVService(port int, cert, key string) { + webDavConf := config.GetWebDAVDConfig() + if len(webDavConf.Bindings) == 0 { + webDavConf.Bindings = append(webDavConf.Bindings, webdavd.Binding{}) + } + if port > 0 { + webDavConf.Bindings[0].Port = port + } else if port == 0 { + webDavConf.Bindings[0].Port = 49152 + rand.Intn(15000) + } else { + webDavConf.Bindings[0].Port = 0 
+ } + webDavConf.Bindings[0].CertificateFile = cert + webDavConf.Bindings[0].CertificateKeyFile = key + if cert != "" && key != "" { + webDavConf.Bindings[0].EnableHTTPS = true + } + config.SetWebDAVDConfig(webDavConf) +} + +func configurePortableHTTPService(port int, cert, key string) { + httpdConf := config.GetHTTPDConfig() + if len(httpdConf.Bindings) == 0 { + httpdConf.Bindings = append(httpdConf.Bindings, httpd.Binding{}) + } + if port > 0 { + httpdConf.Bindings[0].Port = port + } else if port == 0 { + httpdConf.Bindings[0].Port = 49152 + rand.Intn(15000) + } else { + httpdConf.Bindings[0].Port = 0 + } + httpdConf.Bindings[0].CertificateFile = cert + httpdConf.Bindings[0].CertificateKeyFile = key + if cert != "" && key != "" { + httpdConf.Bindings[0].EnableHTTPS = true + } + httpdConf.Bindings[0].EnableWebAdmin = false + httpdConf.Bindings[0].EnableWebClient = true + httpdConf.Bindings[0].EnableRESTAPI = false + httpdConf.Bindings[0].RenderOpenAPI = false + config.SetHTTPDConfig(httpdConf) +} diff --git a/internal/service/service_windows.go b/internal/service/service_windows.go new file mode 100644 index 00000000..16be97ae --- /dev/null +++ b/internal/service/service_windows.go @@ -0,0 +1,403 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package service + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/eventlog" + "golang.org/x/sys/windows/svc/mgr" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/ftpd" + "github.com/drakkan/sftpgo/v2/internal/httpd" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/telemetry" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +const ( + serviceName = "SFTPGo" + serviceDesc = "Full-featured and highly configurable file transfer server" + rotateLogCmd = svc.Cmd(128) + acceptRotateLog = svc.Accepted(rotateLogCmd) +) + +// Status defines service status +type Status uint8 + +// Supported values for service status +const ( + StatusUnknown Status = iota + StatusRunning + StatusStopped + StatusPaused + StatusStartPending + StatusPausePending + StatusContinuePending + StatusStopPending +) + +type WindowsService struct { + Service Service + isInteractive bool +} + +func (s Status) String() string { + switch s { + case StatusRunning: + return "running" + case StatusStopped: + return "stopped" + case StatusStartPending: + return "start pending" + case StatusPausePending: + return "pause pending" + case StatusPaused: + return "paused" + case StatusContinuePending: + return "continue pending" + case StatusStopPending: + return "stop pending" + default: + return "unknown" + } +} + +func (s *WindowsService) handleExit(wasStopped chan bool) { + s.Service.Wait() + + select { + case <-wasStopped: + // the service was stopped nothing to do + logger.Info(logSender, "", "Windows Service was stopped") + return + default: + // the server failed while running, we must be sure to exit the process. + // The defined recovery action will be executed. 
// Execute implements svc.Handler and runs the Windows service control loop:
// it starts the wrapped Service asynchronously, reports status transitions
// to the SCM via changes, and reacts to Interrogate/Stop/Shutdown/
// ParamChange (reload) and the custom rotateLogCmd controls.
// Returns (false, 0): no service-specific exit code.
func (s *WindowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
	changes <- svc.Status{State: svc.StartPending}

	go func() {
		if err := s.Service.Start(); err != nil {
			logger.Error(logSender, "", "Windows service failed to start, error: %v", err)
			s.Service.Error = err
			s.Service.Shutdown <- true
			return
		}
		logger.Info(logSender, "", "Windows service started")
		// once fully started, also accept reload and log-rotation controls
		cmdsAccepted := svc.AcceptStop | svc.AcceptShutdown | svc.AcceptParamChange | acceptRotateLog
		changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
	}()

	wasStopped := make(chan bool, 1)

	go s.handleExit(wasStopped)

	// report Running immediately with the basic accepted controls; the
	// goroutine above widens them once startup completes
	changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown}
loop:
	for {
		c := <-r
		switch c.Cmd {
		case svc.Interrogate:
			logger.Debug(logSender, "", "Received service interrogate request, current status: %v", c.CurrentStatus)
			changes <- c.CurrentStatus
		case svc.Stop, svc.Shutdown:
			logger.Debug(logSender, "", "Received service stop request")
			changes <- svc.Status{State: svc.StopPending}
			// signal handleExit that this is a deliberate stop before stopping
			wasStopped <- true
			s.Service.Stop()
			plugin.Handler.Cleanup()
			common.WaitForTransfers(graceTime)
			break loop
		case svc.ParamChange:
			// reload: each step logs and continues on failure (best effort)
			logger.Debug(logSender, "", "Received reload request")
			err := dataprovider.ReloadConfig()
			if err != nil {
				logger.Warn(logSender, "", "error reloading dataprovider configuration: %v", err)
			}
			err = httpd.ReloadCertificateMgr()
			if err != nil {
				logger.Warn(logSender, "", "error reloading cert manager: %v", err)
			}
			err = ftpd.ReloadCertificateMgr()
			if err != nil {
				logger.Warn(logSender, "", "error reloading FTPD cert manager: %v", err)
			}
			err = webdavd.ReloadCertificateMgr()
			if err != nil {
				logger.Warn(logSender, "", "error reloading WebDAV cert manager: %v", err)
			}
			err = telemetry.ReloadCertificateMgr()
			if err != nil {
				logger.Warn(logSender, "", "error reloading telemetry cert manager: %v", err)
			}
			err = common.Reload()
			if err != nil {
				logger.Warn(logSender, "", "error reloading common configs: %v", err)
			}
			err = sftpd.Reload()
			if err != nil {
				logger.Warn(logSender, "", "error reloading sftpd revoked certificates: %v", err)
			}
		case rotateLogCmd:
			logger.Debug(logSender, "", "Received log file rotation request")
			err := logger.RotateLogFile()
			if err != nil {
				logger.Warn(logSender, "", "error rotating log file: %v", err)
			}
		default:
			continue loop
		}
	}

	return false, 0
}

// RunService runs the program either as a Windows service or, when launched
// interactively, by starting the installed service. The working directory is
// set to the executable's directory first.
func (s *WindowsService) RunService() error {
	exePath, err := s.getExePath()
	if err != nil {
		return err
	}

	isService, err := svc.IsWindowsService()
	if err != nil {
		return err
	}

	s.isInteractive = !isService
	dir := filepath.Dir(exePath)
	if err = os.Chdir(dir); err != nil {
		return err
	}
	if s.isInteractive {
		return s.Start()
	}
	return svc.Run(serviceName, s)
}

// Start asks the SCM to start the installed SFTPGo service.
func (s *WindowsService) Start() error {
	m, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err != nil {
		return fmt.Errorf("could not access service: %v", err)
	}
	defer service.Close()
	err = service.Start()
	if err != nil {
		return fmt.Errorf("could not start service: %v", err)
	}
	return nil
}

// Reload sends the ParamChange control to the running service, triggering
// the configuration reload handled in Execute.
func (s *WindowsService) Reload() error {
	m, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err != nil {
		return fmt.Errorf("could not access service: %v", err)
	}
	defer service.Close()
	_, err = service.Control(svc.ParamChange)
	if err != nil {
		return fmt.Errorf("could not send control=%d: %v", svc.ParamChange, err)
	}
	return nil
}

// RotateLogFile sends the custom rotateLogCmd control to the running service.
func (s *WindowsService) RotateLogFile() error {
	m, err := mgr.Connect()
	if err != nil {
		return err
	}
	defer m.Disconnect()
	service, err := m.OpenService(serviceName)
	if err != nil {
		return fmt.Errorf("could not access service: %v", err)
	}
	defer service.Close()
	_, err = service.Control(rotateLogCmd)
	if err != nil {
		return fmt.Errorf("could not send control=%d: %v", rotateLogCmd, err)
	}
	return nil
}
err + } + defer m.Disconnect() + service, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("could not access service: %v", err) + } + defer service.Close() + _, err = service.Control(rotateLogCmd) + if err != nil { + return fmt.Errorf("could not send control=%d: %v", rotateLogCmd, err) + } + return nil +} + +func (s *WindowsService) Install(args ...string) error { + exePath, err := s.getExePath() + if err != nil { + return err + } + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + service, err := m.OpenService(serviceName) + if err == nil { + service.Close() + return fmt.Errorf("service %s already exists", serviceName) + } + config := mgr.Config{ + DisplayName: serviceName, + Description: serviceDesc, + StartType: mgr.StartAutomatic} + service, err = m.CreateService(serviceName, exePath, config, args...) + if err != nil { + return err + } + defer service.Close() + err = eventlog.InstallAsEventCreate(serviceName, eventlog.Error|eventlog.Warning|eventlog.Info) + if err != nil { + if !strings.Contains(err.Error(), "exists") { + service.Delete() + return fmt.Errorf("SetupEventLogSource() failed: %s", err) + } + } + recoveryActions := []mgr.RecoveryAction{ + { + Type: mgr.ServiceRestart, + Delay: 5 * time.Second, + }, + { + Type: mgr.ServiceRestart, + Delay: 60 * time.Second, + }, + { + Type: mgr.ServiceRestart, + Delay: 90 * time.Second, + }, + } + err = service.SetRecoveryActions(recoveryActions, 300) + if err != nil { + service.Delete() + return fmt.Errorf("unable to set recovery actions: %v", err) + } + return nil +} + +func (s *WindowsService) Uninstall() error { + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + service, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("service %s is not installed", serviceName) + } + defer service.Close() + err = service.Delete() + if err != nil { + return err + } + err = eventlog.Remove(serviceName) + if err != nil 
{ + return fmt.Errorf("RemoveEventLogSource() failed: %s", err) + } + return nil +} + +func (s *WindowsService) Stop() error { + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + service, err := m.OpenService(serviceName) + if err != nil { + return fmt.Errorf("could not access service: %v", err) + } + defer service.Close() + status, err := service.Control(svc.Stop) + if err != nil { + return fmt.Errorf("could not send control=%d: %v", svc.Stop, err) + } + timeout := time.Now().Add(10 * time.Second) + for status.State != svc.Stopped { + if timeout.Before(time.Now()) { + return fmt.Errorf("timeout waiting for service to go to state=%d", svc.Stopped) + } + time.Sleep(300 * time.Millisecond) + status, err = service.Query() + if err != nil { + return fmt.Errorf("could not retrieve service status: %v", err) + } + } + return nil +} + +func (s *WindowsService) Status() (Status, error) { + m, err := mgr.Connect() + if err != nil { + return StatusUnknown, err + } + defer m.Disconnect() + service, err := m.OpenService(serviceName) + if err != nil { + return StatusUnknown, fmt.Errorf("could not access service: %v", err) + } + defer service.Close() + status, err := service.Query() + if err != nil { + return StatusUnknown, fmt.Errorf("could not query service status: %v", err) + } + switch status.State { + case svc.StartPending: + return StatusStartPending, nil + case svc.Running: + return StatusRunning, nil + case svc.PausePending: + return StatusPausePending, nil + case svc.Paused: + return StatusPaused, nil + case svc.ContinuePending: + return StatusContinuePending, nil + case svc.StopPending: + return StatusStopPending, nil + case svc.Stopped: + return StatusStopped, nil + default: + return StatusUnknown, fmt.Errorf("unknown status %v", status) + } +} + +func (s *WindowsService) getExePath() (string, error) { + return os.Executable() +} diff --git a/internal/service/signals_unix.go b/internal/service/signals_unix.go new file mode 100644 index 
00000000..cecbea9f --- /dev/null +++ b/internal/service/signals_unix.go @@ -0,0 +1,97 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build !windows + +package service + +import ( + "os" + "os/signal" + "syscall" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/ftpd" + "github.com/drakkan/sftpgo/v2/internal/httpd" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/telemetry" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +func registerSignals() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGUSR1) + go func() { + for sig := range c { + switch sig { + case syscall.SIGHUP: + handleSIGHUP() + case syscall.SIGUSR1: + handleSIGUSR1() + case syscall.SIGINT, syscall.SIGTERM: + handleInterrupt() + } + } + }() +} + +func handleSIGHUP() { + logger.Debug(logSender, "", "Received reload request") + err := dataprovider.ReloadConfig() + if err != nil { + logger.Warn(logSender, "", "error reloading dataprovider configuration: %v", err) + } + err = httpd.ReloadCertificateMgr() + if err != nil { + logger.Warn(logSender, "", "error reloading cert manager: %v", err) + } + err = ftpd.ReloadCertificateMgr() + 
if err != nil { + logger.Warn(logSender, "", "error reloading FTPD cert manager: %v", err) + } + err = webdavd.ReloadCertificateMgr() + if err != nil { + logger.Warn(logSender, "", "error reloading WebDAV cert manager: %v", err) + } + err = telemetry.ReloadCertificateMgr() + if err != nil { + logger.Warn(logSender, "", "error reloading telemetry cert manager: %v", err) + } + err = common.Reload() + if err != nil { + logger.Warn(logSender, "", "error reloading common configs: %v", err) + } + err = sftpd.Reload() + if err != nil { + logger.Warn(logSender, "", "error reloading sftpd revoked certificates: %v", err) + } +} + +func handleSIGUSR1() { + logger.Debug(logSender, "", "Received log file rotation request") + err := logger.RotateLogFile() + if err != nil { + logger.Warn(logSender, "", "error rotating log file: %v", err) + } +} + +func handleInterrupt() { + logger.Debug(logSender, "", "Received interrupt request") + plugin.Handler.Cleanup() + common.WaitForTransfers(graceTime) + os.Exit(0) +} diff --git a/internal/service/signals_windows.go b/internal/service/signals_windows.go new file mode 100644 index 00000000..3437f039 --- /dev/null +++ b/internal/service/signals_windows.go @@ -0,0 +1,37 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package service + +import ( + "os" + "os/signal" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/plugin" +) + +func registerSignals() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + go func() { + for range c { + logger.Debug(logSender, "", "Received interrupt request") + plugin.Handler.Cleanup() + common.WaitForTransfers(graceTime) + os.Exit(0) + } + }() +} diff --git a/internal/sftpd/cryptfs_test.go b/internal/sftpd/cryptfs_test.go new file mode 100644 index 00000000..0b946a75 --- /dev/null +++ b/internal/sftpd/cryptfs_test.go @@ -0,0 +1,508 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
const (
	// passphrase used for the crypto filesystem in these tests
	testPassphrase = "test passphrase"
)

// TestBasicSFTPCryptoHandling verifies upload/download round-trips on a
// crypto filesystem user: the plaintext hash is preserved, the on-disk file
// has the encrypted size, and quota accounting tracks the encrypted size.
func TestBasicSFTPCryptoHandling(t *testing.T) {
	usePubKey := false
	u := getTestUserWithCryptFs(usePubKey)
	u.QuotaSize = 6553600
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		encryptedFileSize, err := getEncryptedFileSize(testFileSize)
		assert.NoError(t, err)
		// quota is accounted in encrypted bytes, not plaintext bytes
		expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize
		expectedQuotaFiles := user.UsedQuotaFiles + 1
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		// upload into a missing directory must fail
		err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client)
		assert.Error(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
		assert.NoError(t, err)
		// the downloaded plaintext must match the uploaded plaintext
		initialHash, err := computeHashForFile(sha256.New(), testFilePath)
		assert.NoError(t, err)
		downloadedFileHash, err := computeHashForFile(sha256.New(), localDownloadPath)
		assert.NoError(t, err)
		assert.Equal(t, initialHash, downloadedFileHash)
		// the file stored on disk is the encrypted one
		info, err := os.Stat(filepath.Join(user.HomeDir, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, encryptedFileSize, info.Size())
		}
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		// over SFTP the reported sizes are plaintext sizes
		result, err := client.ReadDir(".")
		assert.NoError(t, err)
		if assert.Len(t, result, 1) {
			assert.Equal(t, testFileSize, result[0].Size())
		}
		info, err = client.Stat(testFileName)
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, info.Size())
		}
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		_, err = client.Lstat(testFileName)
		assert.Error(t, err)
		// removing the file must release the encrypted-size quota
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize-encryptedFileSize, user.UsedQuotaSize)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestOpenReadWriteCryptoFs verifies that a file opened read-write can be
// written but not read back on a crypto filesystem.
func TestOpenReadWriteCryptoFs(t *testing.T) {
	// read and write is not supported on crypto fs
	usePubKey := false
	u := getTestUserWithCryptFs(usePubKey)
	u.QuotaSize = 6553600
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
		if assert.NoError(t, err) {
			testData := []byte("sample test data")
			n, err := sftpFile.Write(testData)
			assert.NoError(t, err)
			assert.Equal(t, len(testData), n)
			// reading from a file opened for writing must be rejected
			buffer := make([]byte, 128)
			_, err = sftpFile.ReadAt(buffer, 1)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
			}
			err = sftpFile.Close()
			assert.NoError(t, err)
		}
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestEmptyFile(t *testing.T) { + usePubKey := true + u := getTestUserWithCryptFs(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) + if assert.NoError(t, err) { + testData := []byte("") + n, err := sftpFile.Write(testData) + assert.NoError(t, err) + assert.Equal(t, len(testData), n) + err = sftpFile.Close() + assert.NoError(t, err) + } + info, err := client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, int64(0), info.Size()) + } + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, 0, client) + assert.NoError(t, err) + encryptedFileSize, err := getEncryptedFileSize(0) + assert.NoError(t, err) + info, err = os.Stat(filepath.Join(user.HomeDir, testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, encryptedFileSize, info.Size()) + } + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUploadResumeCryptFs(t *testing.T) { + // resuming uploads is not supported + usePubKey := true + u := getTestUserWithCryptFs(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + appendDataSize := int64(65535) + err = 
// TestQuotaFileReplaceCryptFs verifies quota accounting on a crypto fs when
// a file is replaced (quota unchanged), when a symlink is replaced by a file
// (counts as a new upload), and that a quota size limit rejects an upload.
func TestQuotaFileReplaceCryptFs(t *testing.T) {
	usePubKey := false
	u := getTestUserWithCryptFs(usePubKey)
	u.QuotaFiles = 1000
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	testFileSize := int64(65535)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	encryptedFileSize, err := getEncryptedFileSize(testFileSize)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) { //nolint:dupl
		defer conn.Close()
		defer client.Close()
		expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize
		expectedQuotaFiles := user.UsedQuotaFiles + 1
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		// now replace the same file, the quota must not change
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		// now create a symlink, replace it with a file and check the quota
		// replacing a symlink is like uploading a new file
		err = client.Symlink(testFileName, testFileName+".link") //nolint:goconst
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		expectedQuotaFiles = expectedQuotaFiles + 1
		expectedQuotaSize = expectedQuotaSize + encryptedFileSize
		err = sftpUploadFile(testFilePath, testFileName+".link", testFileSize, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
	}
	// now set a quota size restriction and upload the same file, upload should fail for space limit exceeded
	user.QuotaSize = encryptedFileSize*2 - 1
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.Error(t, err, "quota size exceeded, file upload must fail")
		err = client.Remove(testFileName)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestQuotaScanCryptFs verifies that a quota scan on a crypto fs home dir
// accounts files at their encrypted size, including files left on disk by a
// previously removed user.
func TestQuotaScanCryptFs(t *testing.T) {
	usePubKey := false
	user, _, err := httpdtest.AddUser(getTestUserWithCryptFs(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	testFileSize := int64(65535)
	encryptedFileSize, err := getEncryptedFileSize(testFileSize)
	assert.NoError(t, err)
	expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize
	expectedQuotaFiles := user.UsedQuotaFiles + 1
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	// create user with the same home dir, so there is at least an untracked file
	user, _, err = httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	_, err = httpdtest.StartQuotaScan(user, http.StatusAccepted)
	assert.NoError(t, err)
	// wait for the asynchronous scan to finish
	assert.Eventually(t, func() bool {
		scans, _, err := httpdtest.GetQuotaScans(http.StatusOK)
		if err == nil {
			return len(scans) == 0
		}
		return false
	}, 1*time.Second, 50*time.Millisecond)
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
	assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + // create user with the same home dir, so there is at least an untracked file + user, _, err = httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + _, err = httpdtest.StartQuotaScan(user, http.StatusAccepted) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + scans, _, err := httpdtest.GetQuotaScans(http.StatusOK) + if err == nil { + return len(scans) == 0 + } + return false + }, 1*time.Second, 50*time.Millisecond) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestGetMimeTypeCryptFs(t *testing.T) { + usePubKey := true + user, _, err := httpdtest.AddUser(getTestUserWithCryptFs(usePubKey), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) + if assert.NoError(t, err) { + testData := []byte("some UTF-8 text so we should get a text/plain mime type") + n, err := sftpFile.Write(testData) + assert.NoError(t, err) + assert.Equal(t, len(testData), n) + err = sftpFile.Close() + assert.NoError(t, err) + } + } + + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(testPassphrase) + fs, err := 
user.GetFilesystem("connID") + if assert.NoError(t, err) { + assert.True(t, vfs.IsCryptOsFs(fs)) + mime, err := fs.GetMimeType(filepath.Join(user.GetHomeDir(), testFileName)) + assert.NoError(t, err) + assert.Equal(t, "text/plain; charset=utf-8", mime) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestTruncate(t *testing.T) { + // truncate is not supported + usePubKey := true + user, _, err := httpdtest.AddUser(getTestUserWithCryptFs(usePubKey), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + f, err := client.OpenFile(testFileName, os.O_WRONLY|os.O_CREATE) + if assert.NoError(t, err) { + err = f.Truncate(0) + assert.NoError(t, err) + err = f.Truncate(1) + assert.Error(t, err) + } + err = f.Close() + assert.NoError(t, err) + err = client.Truncate(testFileName, 0) + assert.Error(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSCPBasicHandlingCryptoFs(t *testing.T) { + if scpPath == "" { + t.Skip("scp command not found, unable to execute this test") + } + usePubKey := true + u := getTestUserWithCryptFs(usePubKey) + u.QuotaSize = 6553600 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(131074) + encryptedFileSize, err := getEncryptedFileSize(testFileSize) + assert.NoError(t, err) + expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize + expectedQuotaFiles := user.UsedQuotaFiles + 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, 
path.Join("/", testFileName)) + localPath := filepath.Join(homeBasePath, "scp_download.dat") + // test to download a missing file + err = scpDownload(localPath, remoteDownPath, false, false) + assert.Error(t, err, "downloading a missing file via scp must fail") + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.NoError(t, err) + err = scpDownload(localPath, remoteDownPath, false, false) + assert.NoError(t, err) + fi, err := os.Stat(localPath) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, fi.Size()) + } + fi, err = os.Stat(filepath.Join(user.GetHomeDir(), testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, encryptedFileSize, fi.Size()) + } + err = os.Remove(localPath) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + // now overwrite the existing file + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) +} + +func TestSCPRecursiveCryptFs(t *testing.T) { + if scpPath == "" { + t.Skip("scp command not found, unable to execute this test") + } + usePubKey := true + u := getTestUserWithCryptFs(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testBaseDirName := "atestdir" + testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName) + testBaseDirDownName := "test_dir_down" //nolint:goconst + testBaseDirDownPath := 
filepath.Join(homeBasePath, testBaseDirDownName) + testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName) + testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName) + testFileSize := int64(131074) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = createTestFile(testFilePath1, testFileSize) + assert.NoError(t, err) + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testBaseDirName)) + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + err = scpUpload(testBaseDirPath, remoteUpPath, true, false) + assert.NoError(t, err) + // overwrite existing dir + err = scpUpload(testBaseDirPath, remoteUpPath, true, false) + assert.NoError(t, err) + err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true) + assert.NoError(t, err) + // test download without passing -r + err = scpDownload(testBaseDirDownPath, remoteDownPath, true, false) + assert.Error(t, err, "recursive download without -r must fail") + + fi, err := os.Stat(filepath.Join(testBaseDirDownPath, testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, fi.Size()) + } + fi, err = os.Stat(filepath.Join(testBaseDirDownPath, testBaseDirName, testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, fi.Size()) + } + // upload to a non existent dir + remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/non_existent_dir") + err = scpUpload(testBaseDirPath, remoteUpPath, true, false) + assert.Error(t, err, "uploading via scp to a non existent dir must fail") + + err = os.RemoveAll(testBaseDirPath) + assert.NoError(t, err) + err = os.RemoveAll(testBaseDirDownPath) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func getEncryptedFileSize(size int64) (int64, error) { + encSize, err := sio.EncryptedSize(uint64(size)) 
+ return int64(encSize) + 33, err +} + +func getTestUserWithCryptFs(usePubKey bool) dataprovider.User { + u := getTestUser(usePubKey) + u.FsConfig.Provider = sdk.CryptedFilesystemProvider + u.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(testPassphrase) + return u +} diff --git a/internal/sftpd/handler.go b/internal/sftpd/handler.go new file mode 100644 index 00000000..ca8575f5 --- /dev/null +++ b/internal/sftpd/handler.go @@ -0,0 +1,617 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package sftpd + +import ( + "io" + "net" + "os" + "path" + "strings" + "time" + + "github.com/pkg/sftp" + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +// Connection details for an authenticated user +type Connection struct { + *common.BaseConnection + // client's version string + ClientVersion string + // Remote address for this connection + RemoteAddr net.Addr + LocalAddr net.Addr + channel io.ReadWriteCloser + command string +} + +// GetClientVersion returns the connected client's version +func (c *Connection) GetClientVersion() string { + return c.ClientVersion +} + +// GetLocalAddress returns local connection address +func (c *Connection) GetLocalAddress() string { + if c.LocalAddr == nil { + return "" + } + return c.LocalAddr.String() +} + +// GetRemoteAddress returns the connected client's address +func (c *Connection) GetRemoteAddress() string { + if c.RemoteAddr == nil { + return "" + } + return c.RemoteAddr.String() +} + +// GetCommand returns the SSH command, if any +func (c *Connection) GetCommand() string { + return c.command +} + +// Fileread creates a reader for a file on the system and returns the reader back. 
+func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { + c.UpdateLastActivity() + updateRequestPaths(request) + + if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(request.Filepath)) { + return nil, sftp.ErrSSHFxPermissionDenied + } + if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { + c.Log(logger.LevelInfo, "denying file read due to transfer count limits") + return nil, c.GetPermissionDeniedError() + } + transferQuota := c.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + c.Log(logger.LevelInfo, "denying file read due to quota limits") + return nil, c.GetReadQuotaExceededError() + } + + if ok, policy := c.User.IsFileAllowed(request.Filepath); !ok { + c.Log(logger.LevelWarn, "reading file %q is not allowed", request.Filepath) + return nil, c.GetErrorForDeniedFile(policy) + } + + fs, p, err := c.GetFsAndResolvedPath(request.Filepath) + if err != nil { + return nil, err + } + + if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreDownload, p, request.Filepath, 0, 0); err != nil { + c.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", request.Filepath, err) + return nil, c.GetPermissionDeniedError() + } + + file, r, cancelFn, err := fs.Open(p, 0) + if err != nil { + c.Log(logger.LevelError, "could not open file %q for reading: %+v", p, err) + return nil, c.GetFsError(fs, err) + } + + baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, p, request.Filepath, common.TransferDownload, + 0, 0, 0, 0, false, fs, transferQuota) + t := newTransfer(baseTransfer, nil, r, nil) + + return t, nil +} + +// OpenFile implements OpenFileWriter interface +func (c *Connection) OpenFile(request *sftp.Request) (sftp.WriterAtReaderAt, error) { + return c.handleFilewrite(request) +} + +// Filewrite handles the write actions for a file on the system. 
+func (c *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) { + return c.handleFilewrite(request) +} + +func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReaderAt, error) { //nolint:gocyclo + c.UpdateLastActivity() + updateRequestPaths(request) + + if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil { + c.Log(logger.LevelInfo, "denying file write due to transfer count limits") + return nil, c.GetPermissionDeniedError() + } + + if ok, _ := c.User.IsFileAllowed(request.Filepath); !ok { + c.Log(logger.LevelWarn, "writing file %q is not allowed", request.Filepath) + return nil, c.GetPermissionDeniedError() + } + + fs, p, err := c.GetFsAndResolvedPath(request.Filepath) + if err != nil { + return nil, err + } + + filePath := p + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + filePath = fs.GetAtomicUploadPath(p) + } + + var errForRead error + if !vfs.HasOpenRWSupport(fs) && request.Pflags().Read { + // read and write mode is only supported for local filesystem + errForRead = sftp.ErrSSHFxOpUnsupported + } + if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(request.Filepath)) { + // we can try to read only for local fs here, see above. 
+ // os.ErrPermission will become sftp.ErrSSHFxPermissionDenied when sent to + // the client + errForRead = os.ErrPermission + } + + stat, statErr := fs.Lstat(p) + if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { + if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(request.Filepath)) { + return nil, sftp.ErrSSHFxPermissionDenied + } + return c.handleSFTPUploadToNewFile(fs, request.Pflags(), p, filePath, request.Filepath, errForRead) + } + + if statErr != nil { + c.Log(logger.LevelError, "error performing file stat %q: %+v", p, statErr) + return nil, c.GetFsError(fs, statErr) + } + + // This happen if we upload a file that has the same name of an existing directory + if stat.IsDir() { + c.Log(logger.LevelError, "attempted to open a directory for writing to: %q", p) + return nil, sftp.ErrSSHFxOpUnsupported + } + + if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(request.Filepath)) { + return nil, sftp.ErrSSHFxPermissionDenied + } + + return c.handleSFTPUploadToExistingFile(fs, request.Pflags(), p, filePath, stat.Size(), request.Filepath, errForRead) +} + +// Filecmd hander for basic SFTP system calls related to files, but not anything to do with reading +// or writing to those files. 
+func (c *Connection) Filecmd(request *sftp.Request) error { + c.UpdateLastActivity() + updateRequestPaths(request) + + switch request.Method { + case "Setstat": + return c.handleSFTPSetstat(request) + case "Rename": + if err := c.Rename(request.Filepath, request.Target); err != nil { + return err + } + case "Rmdir": + return c.RemoveDir(request.Filepath) + case "Mkdir": + err := c.CreateDir(request.Filepath, true) + if err != nil { + return err + } + case "Symlink": + if err := c.CreateSymlink(request.Filepath, request.Target); err != nil { + return err + } + case "Remove": + return c.handleSFTPRemove(request) + default: + return sftp.ErrSSHFxOpUnsupported + } + + return sftp.ErrSSHFxOk +} + +// Filelist is the handler for SFTP filesystem list calls. This will handle calls to list the contents of +// a directory as well as perform file/folder stat calls. +func (c *Connection) Filelist(request *sftp.Request) (sftp.ListerAt, error) { + c.UpdateLastActivity() + updateRequestPaths(request) + + switch request.Method { + case "List": + lister, err := c.ListDir(request.Filepath) + if err != nil { + return nil, err + } + modTime := time.Unix(0, 0) + if request.Filepath != "/" { + lister.Prepend(vfs.NewFileInfo("..", true, 0, modTime, false)) + } + lister.Prepend(vfs.NewFileInfo(".", true, 0, modTime, false)) + return lister, nil + case "Stat": + if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) { + return nil, sftp.ErrSSHFxPermissionDenied + } + + s, err := c.DoStat(request.Filepath, 0, true) + if err != nil { + return nil, err + } + + return listerAt([]os.FileInfo{s}), nil + default: + return nil, sftp.ErrSSHFxOpUnsupported + } +} + +// Readlink implements the ReadlinkFileLister interface +func (c *Connection) Readlink(filePath string) (string, error) { + filePath = util.CleanPath(filePath) + if err := c.canReadLink(filePath); err != nil { + return "", err + } + + fs, p, err := c.GetFsAndResolvedPath(filePath) + if err != nil { + return "", err + 
} + + s, err := fs.Readlink(p) + if err != nil { + c.Log(logger.LevelDebug, "error running readlink on path %q: %+v", p, err) + return "", c.GetFsError(fs, err) + } + + if err := c.canReadLink(s); err != nil { + return "", err + } + return s, nil +} + +// Lstat implements LstatFileLister interface +func (c *Connection) Lstat(request *sftp.Request) (sftp.ListerAt, error) { + c.UpdateLastActivity() + updateRequestPaths(request) + + if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(request.Filepath)) { + return nil, sftp.ErrSSHFxPermissionDenied + } + + s, err := c.DoStat(request.Filepath, 1, true) + if err != nil { + return nil, err + } + + return listerAt([]os.FileInfo{s}), nil +} + +// RealPath implements the RealPathFileLister interface +func (c *Connection) RealPath(p string) (string, error) { + if c.User.Filters.StartDirectory == "" { + p = util.CleanPath(p) + } else { + p = util.CleanPathWithBase(c.User.Filters.StartDirectory, p) + } + if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(p)) { + return "", sftp.ErrSSHFxPermissionDenied + } + fs, fsPath, err := c.GetFsAndResolvedPath(p) + if err != nil { + return "", err + } + if realPather, ok := fs.(vfs.FsRealPather); ok { + realPath, err := realPather.RealPath(fsPath) + if err != nil { + return "", c.GetFsError(fs, err) + } + return realPath, nil + } + return p, nil +} + +// StatVFS implements StatVFSFileCmder interface +func (c *Connection) StatVFS(r *sftp.Request) (*sftp.StatVFS, error) { + c.UpdateLastActivity() + updateRequestPaths(r) + + // we are assuming that r.Filepath is a dir, this could be wrong but should + // not produce any side effect here. 
+ // we don't consider c.User.Filters.MaxUploadFileSize, we return disk stats here + // not the limit for a single file upload + quotaResult, _ := c.HasSpace(true, true, path.Join(r.Filepath, "fakefile.txt")) + + fs, p, err := c.GetFsAndResolvedPath(r.Filepath) + if err != nil { + return nil, err + } + + if !quotaResult.HasSpace { + return c.getStatVFSFromQuotaResult(fs, p, quotaResult) + } + + if quotaResult.QuotaSize == 0 && quotaResult.QuotaFiles == 0 { + // no quota restrictions + statvfs, err := fs.GetAvailableDiskSize(p) + if err == vfs.ErrStorageSizeUnavailable { + return c.getStatVFSFromQuotaResult(fs, p, quotaResult) + } + return statvfs, err + } + + // there is free space but some limits are configured + return c.getStatVFSFromQuotaResult(fs, p, quotaResult) +} + +func (c *Connection) canReadLink(name string) error { + if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(name)) { + return sftp.ErrSSHFxPermissionDenied + } + ok, policy := c.User.IsFileAllowed(name) + if !ok && policy == sdk.DenyPolicyHide { + return sftp.ErrSSHFxNoSuchFile + } + return nil +} + +func (c *Connection) handleSFTPSetstat(request *sftp.Request) error { + attrs := common.StatAttributes{ + Flags: 0, + } + if request.Attributes() != nil { + if request.AttrFlags().Permissions { + attrs.Flags |= common.StatAttrPerms + attrs.Mode = request.Attributes().FileMode() + } + if request.AttrFlags().UidGid { + attrs.Flags |= common.StatAttrUIDGID + attrs.UID = int(request.Attributes().UID) + attrs.GID = int(request.Attributes().GID) + } + if request.AttrFlags().Acmodtime { + attrs.Flags |= common.StatAttrTimes + attrs.Atime = time.Unix(int64(request.Attributes().Atime), 0) + attrs.Mtime = time.Unix(int64(request.Attributes().Mtime), 0) + } + if request.AttrFlags().Size { + attrs.Flags |= common.StatAttrSize + attrs.Size = int64(request.Attributes().Size) + } + } + + return c.SetStat(request.Filepath, &attrs) +} + +func (c *Connection) handleSFTPRemove(request *sftp.Request) error { + fs, 
fsPath, err := c.GetFsAndResolvedPath(request.Filepath) + if err != nil { + return err + } + + var fi os.FileInfo + if fi, err = fs.Lstat(fsPath); err != nil { + c.Log(logger.LevelDebug, "failed to remove file %q: stat error: %+v", fsPath, err) + return c.GetFsError(fs, err) + } + if fi.IsDir() && fi.Mode()&os.ModeSymlink == 0 { + c.Log(logger.LevelDebug, "cannot remove %q is not a file/symlink", fsPath) + return sftp.ErrSSHFxFailure + } + + return c.RemoveFile(fs, fsPath, request.Filepath, fi) +} + +func (c *Connection) handleSFTPUploadToNewFile(fs vfs.Fs, pflags sftp.FileOpenFlags, resolvedPath, filePath, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) { + diskQuota, transferQuota := c.HasSpace(true, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { + c.Log(logger.LevelInfo, "denying file write due to quota limits") + return nil, c.GetQuotaExceededError() + } + + if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, 0, 0); err != nil { + c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err) + return nil, c.GetPermissionDeniedError() + } + + osFlags := getOSOpenFlags(pflags) + file, w, cancelFn, err := fs.Create(filePath, osFlags, c.GetCreateChecks(requestPath, true, false)) + if err != nil { + c.Log(logger.LevelError, "error creating file %q, os flags %d, pflags %+v: %+v", resolvedPath, osFlags, pflags, err) + return nil, c.GetFsError(fs, err) + } + + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) + + // we can get an error only for resume + maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, 0, fs.IsUploadResumeSupported()) + + baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, + common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs, transferQuota) + t := newTransfer(baseTransfer, w, nil, errForRead) + + return t, nil +} + +func (c 
*Connection) handleSFTPUploadToExistingFile(fs vfs.Fs, pflags sftp.FileOpenFlags, resolvedPath, filePath string, + fileSize int64, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) { + var err error + diskQuota, transferQuota := c.HasSpace(false, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { + c.Log(logger.LevelInfo, "denying file write due to quota limits") + return nil, c.GetQuotaExceededError() + } + + osFlags := getOSOpenFlags(pflags) + minWriteOffset := int64(0) + isTruncate := osFlags&os.O_TRUNC != 0 + // for upload resumes OpenSSH sets the APPEND flag while WinSCP does not set it, + // so we suppose this is an upload resume if the TRUNCATE flag is not set + isResume := !isTruncate + // if there is a size limit the remaining size cannot be 0 here, since quotaResult.HasSpace + // will return false in this case and we deny the upload before. + // For Cloud FS GetMaxWriteSize will return unsupported operation + maxWriteSize, err := c.GetMaxWriteSize(diskQuota, isResume, fileSize, vfs.IsUploadResumeSupported(fs, fileSize)) + if err != nil { + c.Log(logger.LevelDebug, "unable to get max write size for file %q is resume? 
%t: %v", + requestPath, isResume, err) + return nil, err + } + + if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, fileSize, osFlags); err != nil { + c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err) + return nil, c.GetPermissionDeniedError() + } + + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + _, _, err = fs.Rename(resolvedPath, filePath, 0) + if err != nil { + c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v", + resolvedPath, filePath, err) + return nil, c.GetFsError(fs, err) + } + } + + file, w, cancelFn, err := fs.Create(filePath, osFlags, c.GetCreateChecks(requestPath, false, isResume)) + if err != nil { + c.Log(logger.LevelError, "error opening existing file, os flags %v, pflags: %+v, source: %q, err: %+v", + osFlags, pflags, filePath, err) + return nil, c.GetFsError(fs, err) + } + + initialSize := int64(0) + truncatedSize := int64(0) // bytes truncated and not included in quota + if isResume { + c.Log(logger.LevelDebug, "resuming upload requested, file path %q initial size: %d, has append flag %t", + filePath, fileSize, pflags.Append) + // enforce min write offset only if the client passed the APPEND flag or the filesystem + // supports emulated resume + if pflags.Append || !fs.IsUploadResumeSupported() { + minWriteOffset = fileSize + } + initialSize = fileSize + } else { + if isTruncate && vfs.HasTruncateSupport(fs) { + c.updateQuotaAfterTruncate(requestPath, fileSize) + } else { + initialSize = fileSize + truncatedSize = fileSize + } + } + + vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID()) + + baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, truncatedSize, false, fs, transferQuota) + t := newTransfer(baseTransfer, w, nil, 
errForRead) + + return t, nil +} + +// Disconnect disconnects the client by closing the channel +func (c *Connection) Disconnect() error { + if c.channel == nil { + c.Log(logger.LevelWarn, "cannot disconnect a nil channel") + return nil + } + return c.channel.Close() +} + +func (c *Connection) getStatVFSFromQuotaResult(fs vfs.Fs, name string, quotaResult vfs.QuotaCheckResult) (*sftp.StatVFS, error) { + s, err := fs.GetAvailableDiskSize(name) + if err == nil { + if quotaResult.QuotaSize == 0 || quotaResult.QuotaSize > int64(s.TotalSpace()) { + quotaResult.QuotaSize = int64(s.TotalSpace()) + } + if quotaResult.QuotaFiles == 0 || quotaResult.QuotaFiles > int(s.Files) { + quotaResult.QuotaFiles = int(s.Files) + } + } else if err != vfs.ErrStorageSizeUnavailable { + return nil, err + } + // if we are unable to get quota size or quota files we add some arbitrary values + if quotaResult.QuotaSize == 0 { + quotaResult.QuotaSize = quotaResult.UsedSize + 8*1024*1024*1024*1024 // 8TB + } + if quotaResult.QuotaFiles == 0 { + quotaResult.QuotaFiles = quotaResult.UsedFiles + 1000000 // 1 million + } + + bsize := uint64(4096) + for bsize > uint64(quotaResult.QuotaSize) { + bsize /= 4 + } + blocks := uint64(quotaResult.QuotaSize) / bsize + bfree := uint64(quotaResult.QuotaSize-quotaResult.UsedSize) / bsize + files := uint64(quotaResult.QuotaFiles) + ffree := uint64(quotaResult.QuotaFiles - quotaResult.UsedFiles) + if !quotaResult.HasSpace { + bfree = 0 + ffree = 0 + } + + return &sftp.StatVFS{ + Bsize: bsize, + Frsize: bsize, + Blocks: blocks, + Bfree: bfree, + Bavail: bfree, + Files: files, + Ffree: ffree, + Favail: ffree, + Namemax: 255, + }, nil +} + +func (c *Connection) updateQuotaAfterTruncate(requestPath string, fileSize int64) { + vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) + if err == nil { + dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false) + return + } + dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) 
//nolint:errcheck
+}
+
+func getOSOpenFlags(requestFlags sftp.FileOpenFlags) (flags int) { // NOTE(review): named result "flags" is unused, osFlags is returned directly
+	var osFlags int
+	if requestFlags.Read && requestFlags.Write {
+		osFlags |= os.O_RDWR
+	} else if requestFlags.Write {
+		osFlags |= os.O_WRONLY
+	}
+	// we ignore Append flag since pkg/sftp use WriteAt that cannot work with os.O_APPEND
+	/*if requestFlags.Append {
+		osFlags |= os.O_APPEND
+	}*/
+	if requestFlags.Creat {
+		osFlags |= os.O_CREATE
+	}
+	if requestFlags.Trunc {
+		osFlags |= os.O_TRUNC
+	}
+	if requestFlags.Excl {
+		osFlags |= os.O_EXCL
+	}
+	return osFlags
+}
+
+func updateRequestPaths(request *sftp.Request) { // normalizes request paths to cleaned, slash-separated form
+	if request.Method == "Symlink" {
+		request.Filepath = path.Clean(strings.ReplaceAll(request.Filepath, "\\", "/")) // NOTE(review): plain path.Clean here, presumably to avoid forcing the link path absolute - confirm
+	} else {
+		request.Filepath = util.CleanPath(request.Filepath)
+	}
+
+	if request.Target != "" {
+		request.Target = util.CleanPath(request.Target)
+	}
+}
diff --git a/internal/sftpd/httpfs_test.go b/internal/sftpd/httpfs_test.go
new file mode 100644
index 00000000..5954bf3e
--- /dev/null
+++ b/internal/sftpd/httpfs_test.go
@@ -0,0 +1,372 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +package sftpd_test + +import ( + "fmt" + "io/fs" + "math" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/sftpgo/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpdtest" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + httpFsPort = 12345 + defaultHTTPFsUsername = "httpfs_user" +) + +var ( + httpFsSocketPath = filepath.Join(os.TempDir(), "httpfs.sock") +) + +func TestBasicHTTPFsHandling(t *testing.T) { + usePubKey := true + u := getTestUserWithHTTPFs(usePubKey) + u.QuotaSize = 6553600 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + expectedQuotaSize := user.UsedQuotaSize + testFileSize*2 + expectedQuotaFiles := user.UsedQuotaFiles + 2 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client) + assert.Error(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + info, err := client.Stat(testFileName) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize, info.Size()) + } + contents, err := client.ReadDir("/") + assert.NoError(t, err) + if assert.Len(t, contents, 1) { + assert.Equal(t, testFileName, contents[0].Name()) + } + dirName := "test dirname" + err = client.Mkdir(dirName) + assert.NoError(t, err) + contents, err = client.ReadDir(".") + assert.NoError(t, err) + assert.Len(t, contents, 2) + contents, err = client.ReadDir(dirName) + assert.NoError(t, err) + assert.Len(t, 
contents, 0) + err = sftpUploadFile(testFilePath, path.Join(dirName, testFileName), testFileSize, client) + assert.NoError(t, err) + contents, err = client.ReadDir(dirName) + assert.NoError(t, err) + assert.Len(t, contents, 1) + dirRenamed := dirName + "_renamed" + err = client.Rename(dirName, dirRenamed) + assert.NoError(t, err) + info, err = client.Stat(dirRenamed) + if assert.NoError(t, err) { + assert.True(t, info.IsDir()) + } + // mode 0666 and 0444 works on Windows too + newPerm := os.FileMode(0444) + err = client.Chmod(testFileName, newPerm) + assert.NoError(t, err) + info, err = client.Stat(testFileName) + assert.NoError(t, err) + assert.Equal(t, newPerm, info.Mode().Perm()) + newPerm = os.FileMode(0666) + err = client.Chmod(testFileName, newPerm) + assert.NoError(t, err) + info, err = client.Stat(testFileName) + assert.NoError(t, err) + assert.Equal(t, newPerm, info.Mode().Perm()) + // chtimes + acmodTime := time.Now().Add(-36 * time.Hour) + err = client.Chtimes(testFileName, acmodTime, acmodTime) + assert.NoError(t, err) + info, err = client.Stat(testFileName) + if assert.NoError(t, err) { + diff := math.Abs(info.ModTime().Sub(acmodTime).Seconds()) + assert.LessOrEqual(t, diff, float64(1)) + } + _, err = client.StatVFS("/") + assert.NoError(t, err) + + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + // execute a quota scan + _, err = httpdtest.StartQuotaScan(user, http.StatusAccepted) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + scans, _, err := httpdtest.GetQuotaScans(http.StatusOK) + if err == nil { + return len(scans) == 0 + } + return false + }, 1*time.Second, 50*time.Millisecond) + + err = 
client.Remove(testFileName) + assert.NoError(t, err) + _, err = client.Lstat(testFileName) + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize-testFileSize, user.UsedQuotaSize) + // truncate + err = client.Truncate(path.Join(dirRenamed, testFileName), 100) + assert.NoError(t, err) + info, err = client.Stat(path.Join(dirRenamed, testFileName)) + if assert.NoError(t, err) { + assert.Equal(t, int64(100), info.Size()) + } + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) + assert.Equal(t, int64(100), user.UsedQuotaSize) + // update quota + _, err = httpdtest.StartQuotaScan(user, http.StatusAccepted) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + scans, _, err := httpdtest.GetQuotaScans(http.StatusOK) + if err == nil { + return len(scans) == 0 + } + return false + }, 1*time.Second, 50*time.Millisecond) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) + assert.Equal(t, int64(100), user.UsedQuotaSize) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestHTTPFsVirtualFolder(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + folderName := "httpfsfolder" + vdirPath := "/vdir/http fs" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + }) + f := vfs.BaseVirtualFolder{ + Name: folderName, + FsConfig: vfs.Filesystem{ + 
Provider: sdk.HTTPFilesystemProvider, + HTTPConfig: vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: fmt.Sprintf("http://127.0.0.1:%d/api/v1", httpFsPort), + Username: defaultHTTPFsUsername, + EqualityCheckMode: 1, + }, + }, + }, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client) + assert.NoError(t, err) + _, err = client.Stat(path.Join(vdirPath, testFileName)) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) +} + +func TestHTTPFsWalk(t *testing.T) { + user := getTestUserWithHTTPFs(false) + user.FsConfig.HTTPConfig.EqualityCheckMode = 1 + httpFs, err := user.GetFilesystem("") + require.NoError(t, err) + basePath := filepath.Join(os.TempDir(), "httpfs", user.FsConfig.HTTPConfig.Username) + err = os.RemoveAll(basePath) + assert.NoError(t, err) + + var walkedPaths []string + err = httpFs.Walk("/", func(walkedPath string, _ fs.FileInfo, err error) error { + if err != nil { + return err + } + walkedPaths 
= append(walkedPaths, httpFs.GetRelativePath(walkedPath)) + return nil + }) + require.NoError(t, err) + require.Len(t, walkedPaths, 1) + require.Contains(t, walkedPaths, "/") + // now add some files/folders + for i := 0; i < 10; i++ { + err = os.WriteFile(filepath.Join(basePath, fmt.Sprintf("file%d", i)), nil, os.ModePerm) + assert.NoError(t, err) + err = os.Mkdir(filepath.Join(basePath, fmt.Sprintf("dir%d", i)), os.ModePerm) + assert.NoError(t, err) + for j := 0; j < 5; j++ { + err = os.WriteFile(filepath.Join(basePath, fmt.Sprintf("dir%d", i), fmt.Sprintf("subfile%d", j)), nil, os.ModePerm) + assert.NoError(t, err) + } + } + walkedPaths = nil + err = httpFs.Walk("/", func(walkedPath string, _ fs.FileInfo, err error) error { + if err != nil { + return err + } + walkedPaths = append(walkedPaths, httpFs.GetRelativePath(walkedPath)) + return nil + }) + require.NoError(t, err) + require.Len(t, walkedPaths, 71) + require.Contains(t, walkedPaths, "/") + for i := 0; i < 10; i++ { + require.Contains(t, walkedPaths, path.Join("/", fmt.Sprintf("file%d", i))) + require.Contains(t, walkedPaths, path.Join("/", fmt.Sprintf("dir%d", i))) + for j := 0; j < 5; j++ { + require.Contains(t, walkedPaths, path.Join("/", fmt.Sprintf("dir%d", i), fmt.Sprintf("subfile%d", j))) + } + } + + err = os.RemoveAll(basePath) + assert.NoError(t, err) +} + +func TestHTTPFsOverUNIXSocket(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("UNIX domain sockets are not supported on Windows") + } + assert.Eventually(t, func() bool { + _, err := os.Stat(httpFsSocketPath) + return err == nil + }, 1*time.Second, 50*time.Millisecond) + usePubKey := true + u := getTestUserWithHTTPFs(usePubKey) + u.FsConfig.HTTPConfig.Endpoint = fmt.Sprintf("http://unix?socket_path=%s&api_prefix=%s", + url.QueryEscape(httpFsSocketPath), url.QueryEscape("/api/v1")) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if 
assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = checkBasicSFTP(client) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.NoError(t, err) + err = client.Mkdir(testFileName) + assert.NoError(t, err) + err = client.RemoveDirectory(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func getTestUserWithHTTPFs(usePubKey bool) dataprovider.User { + u := getTestUser(usePubKey) + u.FsConfig.Provider = sdk.HTTPFilesystemProvider + u.FsConfig.HTTPConfig = vfs.HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: fmt.Sprintf("http://127.0.0.1:%d/api/v1", httpFsPort), + Username: defaultHTTPFsUsername, + }, + } + return u +} + +func startHTTPFs() { + if runtime.GOOS != osWindows { + go func() { + if err := httpdtest.StartTestHTTPFsOverUnixSocket(httpFsSocketPath); err != nil { + logger.ErrorToConsole("could not start HTTPfs test server over UNIX socket: %v", err) + os.Exit(1) + } + }() + } + go func() { + if err := httpdtest.StartTestHTTPFs(httpFsPort, nil); err != nil { + logger.ErrorToConsole("could not start HTTPfs test server: %v", err) + os.Exit(1) + } + }() + waitTCPListening(fmt.Sprintf(":%d", httpFsPort)) +} diff --git a/internal/sftpd/internal_test.go b/internal/sftpd/internal_test.go new file mode 100644 index 00000000..44e93faa --- /dev/null +++ b/internal/sftpd/internal_test.go @@ -0,0 +1,1886 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU 
Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+package sftpd
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"io/fs"
+	"net"
+	"os"
+	"path/filepath"
+	"runtime"
+	"slices"
+	"testing"
+	"time"
+
+	"github.com/eikenb/pipeat"
+	"github.com/pkg/sftp"
+	"github.com/rs/xid"
+	"github.com/sftpgo/sdk"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/crypto/ssh"
+
+	"github.com/drakkan/sftpgo/v2/internal/common"
+	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
+	"github.com/drakkan/sftpgo/v2/internal/kms"
+	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/internal/vfs"
+)
+
+const (
+	osWindows = "windows"
+)
+
+var (
+	configDir = filepath.Join(".", "..", "..")
+)
+
+type MockChannel struct { // in-memory SSH channel test double backed by byte buffers
+	Buffer        *bytes.Buffer
+	StdErrBuffer  *bytes.Buffer
+	ReadError     error // returned by Read when set
+	WriteError    error // returned by Write when set
+	ShortWriteErr bool  // when set, Write reports 0 bytes written without error
+}
+
+func (c *MockChannel) Read(data []byte) (int, error) {
+	if c.ReadError != nil {
+		return 0, c.ReadError
+	}
+	return c.Buffer.Read(data)
+}
+
+func (c *MockChannel) Write(data []byte) (int, error) {
+	if c.WriteError != nil {
+		return 0, c.WriteError
+	}
+	if c.ShortWriteErr { // simulate a short write
+		return 0, nil
+	}
+	return c.Buffer.Write(data)
+}
+
+func (c *MockChannel) Close() error {
+	return nil
+}
+
+func (c *MockChannel) CloseWrite() error {
+	return nil
+}
+
+func (c *MockChannel) SendRequest(_ string, _ bool, _ []byte) (bool, error) {
+	return true, nil
+}
+
+func (c *MockChannel) Stderr() io.ReadWriter {
+	return c.StdErrBuffer
+}
+
+// MockOsFs mockable OsFs
+type MockOsFs struct {
+	vfs.Fs                       // embedded real OsFs, delegates everything not overridden below
+	err                     error // injected failure for Remove/Rename
+	statErr                 error // injected failure for Stat/Lstat
+	isAtomicUploadSupported bool
+}
+
+// Name returns the name for the Fs implementation
+func (fs MockOsFs) Name() string {
+	return "mockOsFs"
+}
+
+// IsUploadResumeSupported returns true if resuming uploads is supported
+func (MockOsFs) IsUploadResumeSupported() bool {
+	return false
+}
+
+// IsConditionalUploadResumeSupported returns if resuming uploads is supported
+// for the specified size
+func (MockOsFs) IsConditionalUploadResumeSupported(_ int64) bool {
+	return false
+}
+
+// IsAtomicUploadSupported returns true if atomic upload is supported
+func (fs MockOsFs) IsAtomicUploadSupported() bool {
+	return fs.isAtomicUploadSupported
+}
+
+// Stat returns a FileInfo describing the named file
+func (fs MockOsFs) Stat(name string) (os.FileInfo, error) {
+	if fs.statErr != nil { // injected error takes precedence over the real stat
+		return nil, fs.statErr
+	}
+	return os.Stat(name)
+}
+
+// Lstat returns a FileInfo describing the named file
+func (fs MockOsFs) Lstat(name string) (os.FileInfo, error) {
+	if fs.statErr != nil {
+		return nil, fs.statErr
+	}
+	return os.Lstat(name)
+}
+
+// Remove removes the named file or (empty) directory.
+func (fs MockOsFs) Remove(name string, _ bool) error { + if fs.err != nil { + return fs.err + } + return os.Remove(name) +} + +// Rename renames (moves) source to target +func (fs MockOsFs) Rename(source, target string, _ int) (int, int64, error) { + if fs.err != nil { + return -1, -1, fs.err + } + err := os.Rename(source, target) + return -1, -1, err +} + +func newMockOsFs(err, statErr error, atomicUpload bool, connectionID, rootDir string) vfs.Fs { + return &MockOsFs{ + Fs: vfs.NewOsFs(connectionID, rootDir, "", nil), + err: err, + statErr: statErr, + isAtomicUploadSupported: atomicUpload, + } +} + +func TestRemoveNonexistentQuotaScan(t *testing.T) { + assert.False(t, common.QuotaScans.RemoveUserQuotaScan("username")) +} + +func TestGetOSOpenFlags(t *testing.T) { + var flags sftp.FileOpenFlags + flags.Write = true + flags.Excl = true + osFlags := getOSOpenFlags(flags) + assert.NotEqual(t, 0, osFlags&os.O_WRONLY) + assert.NotEqual(t, 0, osFlags&os.O_EXCL) + + flags.Append = true + // append flag should be ignored to allow resume + assert.NotEqual(t, 0, osFlags&os.O_WRONLY) + assert.NotEqual(t, 0, osFlags&os.O_EXCL) +} + +func TestUploadResumeInvalidOffset(t *testing.T) { + testfile := "testfile" //nolint:goconst + file, err := os.Create(testfile) + assert.NoError(t, err) + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "testuser", + }, + } + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) + baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, + common.TransferUpload, 10, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + transfer := newTransfer(baseTransfer, nil, nil, nil) + _, err = transfer.WriteAt([]byte("test"), 0) + assert.Error(t, err, "upload with invalid offset must fail") + if assert.Error(t, transfer.ErrTransfer) { + assert.EqualError(t, err, transfer.ErrTransfer.Error()) + assert.Contains(t, transfer.ErrTransfer.Error(), 
"invalid write offset") + } + + err = transfer.Close() + if assert.Error(t, err) { + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + } + + err = os.Remove(testfile) + assert.NoError(t, err) +} + +func TestReadWriteErrors(t *testing.T) { + testfile := "testfile" + file, err := os.Create(testfile) + assert.NoError(t, err) + + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "testuser", + }, + } + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) + baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + transfer := newTransfer(baseTransfer, nil, nil, nil) + err = file.Close() + assert.NoError(t, err) + _, err = transfer.WriteAt([]byte("test"), 0) + assert.Error(t, err, "writing to closed file must fail") + buf := make([]byte, 32768) + _, err = transfer.ReadAt(buf, 0) + assert.Error(t, err, "reading from a closed file must fail") + err = transfer.Close() + assert.Error(t, err) + + r, _, err := pipeat.Pipe() + assert.NoError(t, err) + baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + transfer = newTransfer(baseTransfer, nil, vfs.NewPipeReader(r), nil) + err = transfer.Close() + assert.NoError(t, err) + _, err = transfer.ReadAt(buf, 0) + assert.Error(t, err, "reading from a closed pipe must fail") + + r, w, err := pipeat.Pipe() + assert.NoError(t, err) + pipeWriter := vfs.NewPipeWriter(w) + baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), file.Name(), testfile, common.TransferDownload, + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + transfer = newTransfer(baseTransfer, pipeWriter, nil, nil) + + err = r.Close() + assert.NoError(t, err) + errFake := fmt.Errorf("fake upload error") + go func() { + time.Sleep(100 * 
time.Millisecond) + pipeWriter.Done(errFake) + }() + err = transfer.closeIO() + assert.EqualError(t, err, errFake.Error()) + _, err = transfer.WriteAt([]byte("test"), 0) + assert.Error(t, err, "writing to closed pipe must fail") + err = transfer.BaseTransfer.Close() + assert.EqualError(t, err, errFake.Error()) + + err = os.Remove(testfile) + assert.NoError(t, err) + assert.Len(t, conn.GetTransfers(), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestUnsupportedListOP(t *testing.T) { + conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", dataprovider.User{}) + sftpConn := Connection{ + BaseConnection: conn, + } + request := sftp.NewRequest("Unsupported", "") + _, err := sftpConn.Filelist(request) + assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error()) +} + +func TestTransferCancelFn(t *testing.T) { + testfile := "testfile" + file, err := os.Create(testfile) + assert.NoError(t, err) + isCancelled := false + cancelFn := func() { + isCancelled = true + } + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "testuser", + }, + } + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + conn := common.NewBaseConnection("", common.ProtocolSFTP, "", "", user) + baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), file.Name(), testfile, common.TransferDownload, + 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + transfer := newTransfer(baseTransfer, nil, nil, nil) + + errFake := errors.New("fake error, this will trigger cancelFn") + transfer.TransferError(errFake) + err = transfer.Close() + if assert.Error(t, err) { + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + } + if assert.Error(t, transfer.ErrTransfer) { + assert.EqualError(t, transfer.ErrTransfer, errFake.Error()) + } + assert.True(t, isCancelled, "cancelFn not called!") + + err = os.Remove(testfile) + assert.NoError(t, err) +} + +func TestUploadFiles(t *testing.T) { + common.Config.UploadMode = common.UploadModeAtomic + fs := 
vfs.NewOsFs("123", os.TempDir(), "", nil) + u := dataprovider.User{} + c := Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", u), + } + var flags sftp.FileOpenFlags + flags.Write = true + flags.Trunc = true + _, err := c.handleSFTPUploadToExistingFile(fs, flags, "missing_path", "other_missing_path", 0, "/missing_path", nil) + assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid") + + common.Config.UploadMode = common.UploadModeStandard + _, err = c.handleSFTPUploadToExistingFile(fs, flags, "missing_path", "other_missing_path", 0, "/missing_path", nil) + assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid") + + missingFile := "missing/relative/file.txt" + if runtime.GOOS == osWindows { + missingFile = "missing\\relative\\file.txt" + } + _, err = c.handleSFTPUploadToNewFile(fs, flags, ".", missingFile, "/missing", nil) + assert.Error(t, err, "upload new file in missing path must fail") + + fs = newMockOsFs(nil, nil, false, "123", os.TempDir()) + f, err := os.CreateTemp("", "temp") + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + + tr, err := c.handleSFTPUploadToExistingFile(fs, flags, f.Name(), f.Name(), 123, f.Name(), nil) + if assert.NoError(t, err) { + transfer := tr.(*transfer) + transfers := c.GetTransfers() + if assert.Equal(t, 1, len(transfers)) { + assert.Equal(t, transfers[0].ID, transfer.GetID()) + assert.Equal(t, int64(123), transfer.InitialSize) + err = transfer.Close() + assert.NoError(t, err) + assert.Equal(t, 0, len(c.GetTransfers())) + } + } + err = os.Remove(f.Name()) + assert.NoError(t, err) + common.Config.UploadMode = common.UploadModeAtomicWithResume +} + +func TestWithInvalidHome(t *testing.T) { + u := dataprovider.User{} + u.HomeDir = "home_rel_path" //nolint:goconst + _, err := loginUser(&u, dataprovider.LoginMethodPassword, "", nil) + assert.Error(t, err, "login a user with an invalid home_dir must fail") + + 
u.HomeDir = os.TempDir() + fs, err := u.GetFilesystem("123") + assert.NoError(t, err) + c := Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", u), + } + resolved, err := fs.ResolvePath("../upper_path") + assert.NoError(t, err) + assert.Equal(t, filepath.Join(u.HomeDir, "upper_path"), resolved) + _, err = c.StatVFS(&sftp.Request{ + Method: "StatVFS", + Filepath: "../unresolvable-path", + }) + assert.Error(t, err) +} + +func TestResolveWithRootDir(t *testing.T) { + u := dataprovider.User{} + if runtime.GOOS == osWindows { + u.HomeDir = "C:\\" + } else { + u.HomeDir = "/" + } + fs, err := u.GetFilesystem("") + assert.NoError(t, err) + rel, err := filepath.Rel(u.HomeDir, os.TempDir()) + assert.NoError(t, err) + p, err := fs.ResolvePath(rel) + assert.NoError(t, err, "path %v", p) +} + +func TestSFTPGetUsedQuota(t *testing.T) { + u := dataprovider.User{} + u.HomeDir = "home_rel_path" + u.Username = "test_invalid_user" + u.QuotaSize = 4096 + u.QuotaFiles = 1 + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{dataprovider.PermAny} + connection := Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", u), + } + quotaResult, _ := connection.HasSpace(false, false, "/") + assert.False(t, quotaResult.HasSpace) +} + +func TestSupportedSSHCommands(t *testing.T) { + cmds := GetSupportedSSHCommands() + assert.Equal(t, len(supportedSSHCommands), len(cmds)) + + for _, c := range cmds { + assert.True(t, slices.Contains(supportedSSHCommands, c)) + } +} + +func TestSSHCommandPath(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + } + connection := &Connection{ + channel: &mockSSHChannel, + BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, "", "", dataprovider.User{}), + } + sshCommand := sshCommand{ + command: "test", + 
connection: connection, + args: []string{}, + } + assert.Equal(t, "", sshCommand.getDestPath()) + + sshCommand.args = []string{"-t", "/tmp/../path"} + assert.Equal(t, "/path", sshCommand.getDestPath()) + + sshCommand.args = []string{"-t", "/tmp/"} + assert.Equal(t, "/tmp/", sshCommand.getDestPath()) + + sshCommand.args = []string{"-t", "tmp/"} + assert.Equal(t, "/tmp/", sshCommand.getDestPath()) + + sshCommand.args = []string{"-t", "/tmp/../../../path"} + assert.Equal(t, "/path", sshCommand.getDestPath()) + + sshCommand.args = []string{"-t", ".."} + assert.Equal(t, "/", sshCommand.getDestPath()) + + sshCommand.args = []string{"-t", "."} + assert.Equal(t, "/", sshCommand.getDestPath()) + + sshCommand.args = []string{"-t", "//"} + assert.Equal(t, "/", sshCommand.getDestPath()) + + sshCommand.args = []string{"-t", "../.."} + assert.Equal(t, "/", sshCommand.getDestPath()) + + sshCommand.args = []string{"-t", "/.."} + assert.Equal(t, "/", sshCommand.getDestPath()) + + sshCommand.args = []string{"-f", "/a space.txt"} + assert.Equal(t, "/a space.txt", sshCommand.getDestPath()) +} + +func TestSSHParseCommandPayload(t *testing.T) { + cmd := "command -a -f /ab\\ à/some\\ spaces\\ \\ \\(\\).txt" + name, args, _ := parseCommandPayload(cmd) + assert.Equal(t, "command", name) + assert.Equal(t, 3, len(args)) + assert.Equal(t, "/ab à/some spaces ().txt", args[2]) + + _, _, err := parseCommandPayload("") + assert.Error(t, err, "parsing invalid command must fail") +} + +func TestSSHCommandErrors(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + readErr := fmt.Errorf("test read error") + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + } + server, client := net.Pipe() + defer func() { + err := server.Close() + assert.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + user := dataprovider.User{} + user.Permissions = 
map[string][]string{ + "/": {dataprovider.PermAny}, + } + connection := Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, "", "", user), + channel: &mockSSHChannel, + } + cmd := sshCommand{ + command: "md5sum", + connection: &connection, + args: []string{}, + } + err := cmd.handle() + assert.Error(t, err, "ssh command must fail, we are sending a fake error") + + cmd = sshCommand{ + command: "md5sum", + connection: &connection, + args: []string{"/../../test_file_ftp.dat"}, + } + err = cmd.handle() + assert.Error(t, err, "ssh command must fail, we are requesting an invalid path") + + user = dataprovider.User{} + user.Permissions = map[string][]string{ + "/": {dataprovider.PermAny}, + } + user.HomeDir = filepath.Clean(os.TempDir()) + user.QuotaFiles = 1 + user.UsedQuotaFiles = 2 + cmd.connection.User = user + _, err = cmd.connection.User.GetFilesystem("123") + assert.NoError(t, err) + + cmd = sshCommand{ + command: "sftpgo-remove", + connection: &connection, + args: []string{"/../../src"}, + } + err = cmd.handle() + assert.Error(t, err, "ssh command must fail, we are requesting an invalid path") + + cmd = sshCommand{ + command: "sftpgo-copy", + connection: &connection, + args: []string{"/../../test_src", "."}, + } + err = cmd.handle() + assert.Error(t, err, "ssh command must fail, we are requesting an invalid path") + + err = common.Initialize(common.Config, 0) + assert.NoError(t, err) +} + +func TestCommandsWithExtensionsFilter(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + } + server, client := net.Pipe() + defer server.Close() + defer client.Close() + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "test", + HomeDir: os.TempDir(), + Status: 1, + }, + } + user.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/subdir", + AllowedPatterns: []string{".jpg"}, + 
DeniedPatterns: []string{}, + }, + } + + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSSH, "", "", user), + channel: &mockSSHChannel, + } + cmd := sshCommand{ + command: "md5sum", + connection: connection, + args: []string{"subdir/test.png"}, + } + err := cmd.handleHashCommands() + assert.EqualError(t, err, common.ErrPermissionDenied.Error()) +} + +func TestSSHCommandsRemoteFs(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + } + user := dataprovider.User{} + user.FsConfig = vfs.Filesystem{ + Provider: sdk.S3FilesystemProvider, + S3Config: vfs.S3FsConfig{ + BaseS3FsConfig: sdk.BaseS3FsConfig{ + Bucket: "s3bucket", + Endpoint: "endpoint", + Region: "eu-west-1", + }, + }, + } + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", user), + channel: &mockSSHChannel, + } + cmd := sshCommand{ + command: "md5sum", + connection: connection, + args: []string{}, + } + + err := cmd.handleSFTPGoCopy() + assert.Error(t, err) + cmd = sshCommand{ + command: "sftpgo-remove", + connection: connection, + args: []string{}, + } + err = cmd.handleSFTPGoRemove() + assert.Error(t, err) +} + +func TestSSHCmdGetFsErrors(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + } + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: "relative path", + }, + } + user.Permissions = map[string][]string{} + user.Permissions["/"] = []string{dataprovider.PermAny} + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", user), + channel: &mockSSHChannel, + } + cmd := sshCommand{ + command: "sftpgo-remove", + connection: connection, + args: []string{"path"}, + } + err := 
cmd.handleSFTPGoRemove() + assert.Error(t, err) + + cmd = sshCommand{ + command: "sftpgo-copy", + connection: connection, + args: []string{"path1", "path2"}, + } + err = cmd.handleSFTPGoCopy() + assert.Error(t, err) + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestCommandGetFsError(t *testing.T) { + user := dataprovider.User{ + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + }, + } + + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + } + conn := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", user), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: conn, + args: []string{"-t", "/tmp"}, + }, + } + + err := scpCommand.handleRecursiveUpload() + assert.Error(t, err) + err = scpCommand.handleDownload("") + assert.Error(t, err) +} + +func TestSCPFileMode(t *testing.T) { + mode := getFileModeAsString(0, true) + assert.Equal(t, "0755", mode) + + mode = getFileModeAsString(0700, true) + assert.Equal(t, "0700", mode) + + mode = getFileModeAsString(0750, true) + assert.Equal(t, "0750", mode) + + mode = getFileModeAsString(0777, true) + assert.Equal(t, "0777", mode) + + mode = getFileModeAsString(0640, false) + assert.Equal(t, "0640", mode) + + mode = getFileModeAsString(0600, false) + assert.Equal(t, "0600", mode) + + mode = getFileModeAsString(0, false) + assert.Equal(t, "0644", mode) + + fileMode := uint32(0777) + fileMode = fileMode | uint32(os.ModeSetgid) + fileMode = fileMode | uint32(os.ModeSetuid) + fileMode = fileMode | uint32(os.ModeSticky) + mode = getFileModeAsString(os.FileMode(fileMode), false) + assert.Equal(t, "7777", mode) + + fileMode = uint32(0644) + fileMode = fileMode | uint32(os.ModeSetgid) + mode = getFileModeAsString(os.FileMode(fileMode), false) + 
assert.Equal(t, "4644", mode) + + fileMode = uint32(0600) + fileMode = fileMode | uint32(os.ModeSetuid) + mode = getFileModeAsString(os.FileMode(fileMode), false) + assert.Equal(t, "2600", mode) + + fileMode = uint32(0044) + fileMode = fileMode | uint32(os.ModeSticky) + mode = getFileModeAsString(os.FileMode(fileMode), false) + assert.Equal(t, "1044", mode) +} + +func TestSCPUploadError(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + writeErr := fmt.Errorf("test write error") + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: writeErr, + } + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: filepath.Join(os.TempDir()), + Permissions: make(map[string][]string), + }, + } + user.Permissions["/"] = []string{dataprovider.PermAny} + + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", user), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-t", "/"}, + }, + } + err := scpCommand.handle() + assert.EqualError(t, err, writeErr.Error()) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer([]byte("D0755 0 testdir\n")), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: writeErr, + } + err = scpCommand.handleRecursiveUpload() + assert.EqualError(t, err, writeErr.Error()) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer([]byte("D0755 a testdir\n")), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: nil, + } + err = scpCommand.handleRecursiveUpload() + assert.Error(t, err) +} + +func TestSCPInvalidEndDir(t *testing.T) { + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer([]byte("E\n")), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + } + connection := &Connection{ + BaseConnection: 
common.NewBaseConnection("", common.ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: os.TempDir(), + }, + }), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-t", "/tmp"}, + }, + } + err := scpCommand.handleRecursiveUpload() + assert.EqualError(t, err, "unacceptable end dir command") +} + +func TestSCPParseUploadMessage(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + } + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: os.TempDir(), + }, + }), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-t", "/tmp"}, + }, + } + _, _, err := scpCommand.parseUploadMessage(fs, "invalid") + assert.Error(t, err, "parsing invalid upload message must fail") + + _, _, err = scpCommand.parseUploadMessage(fs, "D0755 0") + assert.Error(t, err, "parsing incomplete upload message must fail") + + _, _, err = scpCommand.parseUploadMessage(fs, "D0755 invalidsize testdir") + assert.Error(t, err, "parsing upload message with invalid size must fail") + + _, _, err = scpCommand.parseUploadMessage(fs, "D0755 0 ") + assert.Error(t, err, "parsing upload message with invalid name must fail") +} + +func TestSCPProtocolMessages(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + readErr := fmt.Errorf("test read error") + writeErr := fmt.Errorf("test write error") + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: writeErr, + } + connection := 
&Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{}), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-t", "/tmp"}, + }, + } + _, err := scpCommand.readProtocolMessage() + assert.EqualError(t, err, readErr.Error()) + + err = scpCommand.sendConfirmationMessage() + assert.EqualError(t, err, writeErr.Error()) + + err = scpCommand.sendProtocolMessage("E\n") + assert.EqualError(t, err, writeErr.Error()) + + _, err = scpCommand.getNextUploadProtocolMessage() + assert.EqualError(t, err, readErr.Error()) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer([]byte("T1183832947 0 1183833773 0\n")), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: writeErr, + } + scpCommand.connection.channel = &mockSSHChannel + _, err = scpCommand.getNextUploadProtocolMessage() + assert.EqualError(t, err, writeErr.Error()) + + respBuffer := []byte{0x02} + protocolErrorMsg := "protocol error msg" + respBuffer = append(respBuffer, protocolErrorMsg...) 
+ respBuffer = append(respBuffer, 0x0A) + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(respBuffer), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: nil, + } + scpCommand.connection.channel = &mockSSHChannel + err = scpCommand.readConfirmationMessage() + if assert.Error(t, err) { + assert.Equal(t, protocolErrorMsg, err.Error()) + } + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(respBuffer), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: writeErr, + } + scpCommand.connection.channel = &mockSSHChannel + + err = scpCommand.downloadDirs(nil, nil) + assert.ErrorIs(t, err, writeErr) +} + +func TestSCPTestDownloadProtocolMessages(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + readErr := fmt.Errorf("test read error") + writeErr := fmt.Errorf("test write error") + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: writeErr, + } + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{}), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-f", "-p", "/tmp"}, + }, + } + path := "testDir" + err := os.Mkdir(path, os.ModePerm) + assert.NoError(t, err) + stat, err := os.Stat(path) + assert.NoError(t, err) + err = scpCommand.sendDownloadProtocolMessages(path, stat) + assert.EqualError(t, err, writeErr.Error()) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: nil, + } + + err = scpCommand.sendDownloadProtocolMessages(path, stat) + assert.EqualError(t, err, readErr.Error()) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: 
writeErr, + } + scpCommand.args = []string{"-f", "/tmp"} + scpCommand.connection.channel = &mockSSHChannel + err = scpCommand.sendDownloadProtocolMessages(path, stat) + assert.EqualError(t, err, writeErr.Error()) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: nil, + } + scpCommand.connection.channel = &mockSSHChannel + err = scpCommand.sendDownloadProtocolMessages(path, stat) + assert.EqualError(t, err, readErr.Error()) + + err = os.Remove(path) + assert.NoError(t, err) +} + +func TestSCPCommandHandleErrors(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + readErr := fmt.Errorf("test read error") + writeErr := fmt.Errorf("test write error") + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: writeErr, + } + server, client := net.Pipe() + defer func() { + err := server.Close() + assert.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{}), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-f", "/tmp"}, + }, + } + err := scpCommand.handle() + assert.EqualError(t, err, readErr.Error()) + scpCommand.args = []string{"-i", "/tmp"} + err = scpCommand.handle() + assert.Error(t, err, "invalid scp command must fail") +} + +func TestSCPErrorsMockFs(t *testing.T) { + errFake := errors.New("fake error") + u := dataprovider.User{} + u.Username = "test" + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{dataprovider.PermAny} + u.HomeDir = os.TempDir() + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), 
+ StdErrBuffer: bytes.NewBuffer(stdErrBuf), + } + server, client := net.Pipe() + defer func() { + err := server.Close() + assert.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + connection := &Connection{ + channel: &mockSSHChannel, + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", u), + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-r", "-t", "/tmp"}, + }, + } + testfile := filepath.Join(u.HomeDir, "testfile") + err := os.WriteFile(testfile, []byte("test"), os.ModePerm) + assert.NoError(t, err) + + fs := newMockOsFs(errFake, nil, true, "123", os.TempDir()) + err = scpCommand.handleUploadFile(fs, testfile, testfile, 0, false, 4, "/testfile") + assert.NoError(t, err) + err = os.Remove(testfile) + assert.NoError(t, err) +} + +func TestSCPRecursiveDownloadErrors(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + readErr := fmt.Errorf("test read error") + writeErr := fmt.Errorf("test write error") + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: writeErr, + } + server, client := net.Pipe() + defer func() { + err := server.Close() + assert.NoError(t, err) + }() + defer func() { + err := client.Close() + assert.NoError(t, err) + }() + fs := vfs.NewOsFs("123", os.TempDir(), "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: os.TempDir(), + }, + }), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-r", "-f", "/tmp"}, + }, + } + path := "testDir" + err := os.Mkdir(path, os.ModePerm) + assert.NoError(t, err) + stat, err := os.Stat(path) + assert.NoError(t, err) + err = 
scpCommand.handleRecursiveDownload(fs, "invalid_dir", "invalid_dir", stat) + assert.EqualError(t, err, writeErr.Error()) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: nil, + } + scpCommand.connection.channel = &mockSSHChannel + err = scpCommand.handleRecursiveDownload(fs, "invalid_dir", "invalid_dir", stat) + assert.Error(t, err, "recursive upload download must fail for a non existing dir") + + err = os.Remove(path) + assert.NoError(t, err) +} + +func TestSCPRecursiveUploadErrors(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + readErr := fmt.Errorf("test read error") + writeErr := fmt.Errorf("test write error") + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: writeErr, + } + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{}), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-r", "-t", "/tmp"}, + }, + } + err := scpCommand.handleRecursiveUpload() + assert.Error(t, err, "recursive upload must fail, we send a fake error message") + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: nil, + } + scpCommand.connection.channel = &mockSSHChannel + err = scpCommand.handleRecursiveUpload() + assert.Error(t, err, "recursive upload must fail, we send a fake error message") +} + +func TestSCPCreateDirs(t *testing.T) { + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + u := dataprovider.User{} + u.HomeDir = "home_rel_path" + u.Username = "test" + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{dataprovider.PermAny} + mockSSHChannel := MockChannel{ + Buffer: 
bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: nil, + } + fs, err := u.GetFilesystem("123") + assert.NoError(t, err) + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", u), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-r", "-t", "/tmp"}, + }, + } + err = scpCommand.handleCreateDir(fs, "invalid_dir") + assert.Error(t, err, "create invalid dir must fail") +} + +func TestSCPDownloadFileData(t *testing.T) { + testfile := "testfile" + buf := make([]byte, 65535) + readErr := fmt.Errorf("test read error") + writeErr := fmt.Errorf("test write error") + stdErrBuf := make([]byte, 65535) + mockSSHChannelReadErr := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: nil, + } + mockSSHChannelWriteErr := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: writeErr, + } + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}}), + channel: &mockSSHChannelReadErr, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-r", "-f", "/tmp"}, + }, + } + err := os.WriteFile(testfile, []byte("test"), os.ModePerm) + assert.NoError(t, err) + stat, err := os.Stat(testfile) + assert.NoError(t, err) + err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil) + assert.EqualError(t, err, readErr.Error()) + + scpCommand.connection.channel = &mockSSHChannelWriteErr + err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil) + assert.EqualError(t, err, writeErr.Error()) + + scpCommand.args = []string{"-r", "-p", "-f", "/tmp"} + 
err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil) + assert.EqualError(t, err, writeErr.Error()) + + scpCommand.connection.channel = &mockSSHChannelReadErr + err = scpCommand.sendDownloadFileData(fs, testfile, stat, nil) + assert.EqualError(t, err, readErr.Error()) + + err = os.Remove(testfile) + assert.NoError(t, err) +} + +func TestSCPUploadFiledata(t *testing.T) { + testfile := "testfile" + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + readErr := fmt.Errorf("test read error") + writeErr := fmt.Errorf("test write error") + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: writeErr, + } + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "testuser", + }, + } + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", user), + channel: &mockSSHChannel, + } + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + args: []string{"-r", "-t", "/tmp"}, + }, + } + file, err := os.Create(testfile) + assert.NoError(t, err) + + baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), file.Name(), + "/"+testfile, common.TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) + transfer := newTransfer(baseTransfer, nil, nil, nil) + + err = scpCommand.getUploadFileData(2, transfer) + assert.Error(t, err, "upload must fail, we send a fake write error message") + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: readErr, + WriteError: nil, + } + scpCommand.connection.channel = &mockSSHChannel + file, err = os.Create(testfile) + assert.NoError(t, err) + transfer.File = file + transfer.isFinished = false + transfer.Connection.AddTransfer(transfer) + err = scpCommand.getUploadFileData(2, 
transfer) + assert.Error(t, err, "upload must fail, we send a fake read error message") + + respBuffer := []byte("12") + respBuffer = append(respBuffer, 0x02) + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(respBuffer), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: nil, + } + scpCommand.connection.channel = &mockSSHChannel + file, err = os.Create(testfile) + assert.NoError(t, err) + baseTransfer.File = file + transfer = newTransfer(baseTransfer, nil, nil, nil) + transfer.Connection.AddTransfer(transfer) + err = scpCommand.getUploadFileData(2, transfer) + assert.Error(t, err, "upload must fail, we have not enough data to read") + + // the file is already closed so we have an error on transfer closing + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: nil, + } + + transfer.Connection.AddTransfer(transfer) + err = scpCommand.getUploadFileData(0, transfer) + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrTransferClosed.Error()) + } + transfer.Connection.RemoveTransfer(transfer) + + mockSSHChannel = MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + ReadError: nil, + WriteError: nil, + } + + transfer.Connection.AddTransfer(transfer) + err = scpCommand.getUploadFileData(2, transfer) + assert.ErrorContains(t, err, os.ErrClosed.Error()) + transfer.Connection.RemoveTransfer(transfer) + + err = os.Remove(testfile) + assert.NoError(t, err) + + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestUploadError(t *testing.T) { + common.Config.UploadMode = common.UploadModeAtomic + + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "testuser", + }, + } + fs := vfs.NewOsFs("", os.TempDir(), "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSCP, "", "", user), + } + + testfile := "testfile" + fileTempName := 
"temptestfile" + file, err := os.Create(fileTempName) + assert.NoError(t, err) + baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, file.Name(), + testfile, common.TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) + transfer := newTransfer(baseTransfer, nil, nil, nil) + + errFake := errors.New("fake error") + transfer.TransferError(errFake) + err = transfer.Close() + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + if assert.Error(t, transfer.ErrTransfer) { + assert.EqualError(t, transfer.ErrTransfer, errFake.Error()) + } + assert.Equal(t, int64(0), transfer.BytesReceived.Load()) + + assert.NoFileExists(t, testfile) + assert.NoFileExists(t, fileTempName) + + common.Config.UploadMode = common.UploadModeAtomicWithResume +} + +func TestTransferFailingReader(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "testuser", + HomeDir: os.TempDir(), + }, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret("crypt secret"), + }, + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + + fs := newMockOsFs(nil, nil, true, "", os.TempDir()) + connection := &Connection{ + BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, "", "", user), + } + + request := sftp.NewRequest("Open", "afile.txt") + request.Flags = 27 // read,write,create,truncate + + transfer, err := connection.handleFilewrite(request) + require.NoError(t, err) + buf := make([]byte, 32) + _, err = transfer.ReadAt(buf, 0) + assert.ErrorIs(t, err, sftp.ErrSSHFxOpUnsupported) + if c, ok := transfer.(io.Closer); ok { + err = c.Close() + assert.NoError(t, err) + } + + fsPath := filepath.Join(os.TempDir(), "afile.txt") + + r, _, err := pipeat.Pipe() + assert.NoError(t, err) + baseTransfer := common.NewBaseTransfer(nil, 
connection.BaseConnection, nil, fsPath, fsPath, filepath.Base(fsPath), + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + errRead := errors.New("read is not allowed") + tr := newTransfer(baseTransfer, nil, vfs.NewPipeReader(r), errRead) + _, err = tr.ReadAt(buf, 0) + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + + err = tr.Close() + assert.NoError(t, err) + + tr = newTransfer(baseTransfer, nil, nil, errRead) + _, err = tr.ReadAt(buf, 0) + assert.ErrorIs(t, err, sftp.ErrSSHFxFailure) + + err = tr.Close() + assert.NoError(t, err) + + err = os.Remove(fsPath) + assert.NoError(t, err) + assert.Len(t, connection.GetTransfers(), 0) +} + +func TestConfigsFromProvider(t *testing.T) { + err := dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) + c := Configuration{} + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Len(t, c.HostKeyAlgorithms, 0) + assert.Len(t, c.KexAlgorithms, 0) + assert.Len(t, c.Ciphers, 0) + assert.Len(t, c.MACs, 0) + assert.Len(t, c.PublicKeyAlgorithms, 0) + configs := dataprovider.Configs{ + SFTPD: &dataprovider.SFTPDConfigs{ + HostKeyAlgos: []string{ssh.KeyAlgoRSA}, + KexAlgorithms: []string{ssh.InsecureKeyExchangeDHGEXSHA1}, + Ciphers: []string{ssh.InsecureCipherAES128CBC}, + MACs: []string{ssh.HMACSHA512ETM}, + PublicKeyAlgos: []string{ssh.InsecureKeyAlgoDSA}, //nolint:staticcheck + }, + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + err = c.loadFromProvider() + assert.NoError(t, err) + expectedHostKeyAlgos := append(preferredHostKeyAlgos, configs.SFTPD.HostKeyAlgos...) + expectedKEXs := append(preferredKexAlgos, configs.SFTPD.KexAlgorithms...) + expectedCiphers := append(preferredCiphers, configs.SFTPD.Ciphers...) + expectedMACs := append(preferredMACs, configs.SFTPD.MACs...) + expectedPublicKeyAlgos := append(preferredPublicKeyAlgos, configs.SFTPD.PublicKeyAlgos...) 
+ assert.Equal(t, expectedHostKeyAlgos, c.HostKeyAlgorithms) + assert.Equal(t, expectedKEXs, c.KexAlgorithms) + assert.Equal(t, expectedCiphers, c.Ciphers) + assert.Equal(t, expectedMACs, c.MACs) + assert.Equal(t, expectedPublicKeyAlgos, c.PublicKeyAlgorithms) + + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) +} + +func TestSupportedSecurityOptions(t *testing.T) { + c := Configuration{ + KexAlgorithms: supportedKexAlgos, + MACs: supportedMACs, + Ciphers: supportedCiphers, + } + var defaultKexs []string + for _, k := range supportedKexAlgos { + defaultKexs = append(defaultKexs, k) + if k == ssh.KeyExchangeCurve25519 { + defaultKexs = append(defaultKexs, keyExchangeCurve25519SHA256LibSSH) + } + } + serverConfig := &ssh.ServerConfig{} + err := c.configureSecurityOptions(serverConfig) + assert.NoError(t, err) + assert.Equal(t, supportedCiphers, serverConfig.Ciphers) + assert.Equal(t, supportedMACs, serverConfig.MACs) + assert.Equal(t, defaultKexs, serverConfig.KeyExchanges) + c.KexAlgorithms = append(c.KexAlgorithms, "not a kex") + err = c.configureSecurityOptions(serverConfig) + assert.Error(t, err) + c.KexAlgorithms = append(supportedKexAlgos, "diffie-hellman-group18-sha512") + c.MACs = []string{ + " hmac-sha2-256-etm@openssh.com ", " hmac-sha2-512-etm@openssh.com", + "hmac-sha2-256", "hmac-sha2-512 ", + "hmac-sha1 ", " hmac-sha1-96", + } + err = c.configureSecurityOptions(serverConfig) + assert.NoError(t, err) + assert.Equal(t, supportedCiphers, serverConfig.Ciphers) + assert.Equal(t, supportedMACs, serverConfig.MACs) + assert.Equal(t, defaultKexs, serverConfig.KeyExchanges) +} + +func TestLoadHostKeys(t *testing.T) { + serverConfig := &ssh.ServerConfig{} + c := Configuration{} + c.HostKeys = []string{".", "missing file"} + err := c.checkAndLoadHostKeys(configDir, serverConfig) + assert.Error(t, err) + testfile := filepath.Join(os.TempDir(), "invalidkey") + err = os.WriteFile(testfile, []byte("some bytes"), os.ModePerm) + 
assert.NoError(t, err) + c.HostKeys = []string{testfile} + err = c.checkAndLoadHostKeys(configDir, serverConfig) + assert.Error(t, err) + err = os.Remove(testfile) + assert.NoError(t, err) + keysDir := filepath.Join(os.TempDir(), "keys") + err = os.MkdirAll(keysDir, os.ModePerm) + assert.NoError(t, err) + rsaKeyName := filepath.Join(keysDir, defaultPrivateRSAKeyName) + ecdsaKeyName := filepath.Join(keysDir, defaultPrivateECDSAKeyName) + ed25519KeyName := filepath.Join(keysDir, defaultPrivateEd25519KeyName) + nonDefaultKeyName := filepath.Join(keysDir, "akey") + c.HostKeys = []string{nonDefaultKeyName, rsaKeyName, ecdsaKeyName, ed25519KeyName} + err = c.checkAndLoadHostKeys(configDir, serverConfig) + assert.Error(t, err) + c.HostKeyAlgorithms = []string{ssh.KeyAlgoRSASHA256} + c.HostKeys = []string{ecdsaKeyName} + err = c.checkAndLoadHostKeys(configDir, serverConfig) + assert.Error(t, err) + c.HostKeyAlgorithms = preferredHostKeyAlgos + err = c.checkAndLoadHostKeys(configDir, serverConfig) + assert.NoError(t, err) + assert.FileExists(t, rsaKeyName) + assert.FileExists(t, ecdsaKeyName) + assert.FileExists(t, ed25519KeyName) + assert.NoFileExists(t, nonDefaultKeyName) + err = os.Remove(rsaKeyName) + assert.NoError(t, err) + err = os.Remove(ecdsaKeyName) + assert.NoError(t, err) + err = os.Remove(ed25519KeyName) + assert.NoError(t, err) + if runtime.GOOS != osWindows { + err = os.Chmod(keysDir, 0551) + assert.NoError(t, err) + c.HostKeys = nil + err = c.checkAndLoadHostKeys(keysDir, serverConfig) + assert.Error(t, err) + c.HostKeys = []string{rsaKeyName, ecdsaKeyName} + err = c.checkAndLoadHostKeys(configDir, serverConfig) + assert.Error(t, err) + c.HostKeys = []string{ecdsaKeyName, rsaKeyName} + err = c.checkAndLoadHostKeys(configDir, serverConfig) + assert.Error(t, err) + c.HostKeys = []string{ed25519KeyName} + err = c.checkAndLoadHostKeys(configDir, serverConfig) + assert.Error(t, err) + err = os.Chmod(keysDir, 0755) + assert.NoError(t, err) + } + err = 
os.RemoveAll(keysDir) + assert.NoError(t, err) +} + +func TestCertCheckerInitErrors(t *testing.T) { + c := Configuration{} + c.TrustedUserCAKeys = []string{".", "missing file"} + err := c.initializeCertChecker("") + assert.Error(t, err) + testfile := filepath.Join(os.TempDir(), "invalidkey") + err = os.WriteFile(testfile, []byte("some bytes"), os.ModePerm) + assert.NoError(t, err) + c.TrustedUserCAKeys = []string{testfile} + err = c.initializeCertChecker("") + assert.Error(t, err) + err = os.Remove(testfile) + assert.NoError(t, err) +} + +func TestRecoverer(t *testing.T) { + c := Configuration{} + c.AcceptInboundConnection(nil, nil) + connID := "connectionID" + connection := &Connection{ + BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, "", "", dataprovider.User{}), + } + c.handleSftpConnection(nil, connection) + sshCmd := sshCommand{ + command: "cd", + connection: connection, + } + err := sshCmd.handle() + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + scpCmd := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: connection, + }, + } + err = scpCmd.handle() + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestListernerAcceptErrors(t *testing.T) { + errFake := errors.New("a fake error") + listener := newFakeListener(errFake) + c := Configuration{} + err := c.serve(listener, nil) + require.EqualError(t, err, errFake.Error()) + err = listener.Close() + require.NoError(t, err) + + errNetFake := &fakeNetError{error: errFake} + listener = newFakeListener(errNetFake) + err = c.serve(listener, nil) + require.EqualError(t, err, errFake.Error()) + err = listener.Close() + require.NoError(t, err) +} + +type fakeNetError struct { + error + count int +} + +func (e *fakeNetError) Timeout() bool { + return false +} + +func (e *fakeNetError) Temporary() bool { + e.count++ + return 
e.count < 10 +} + +func (e *fakeNetError) Error() string { + return e.error.Error() +} + +type fakeListener struct { + server net.Conn + client net.Conn + err error +} + +func (l *fakeListener) Accept() (net.Conn, error) { + return l.client, l.err +} + +func (l *fakeListener) Close() error { + errClient := l.client.Close() + errServer := l.server.Close() + if errServer != nil { + return errServer + } + return errClient +} + +func (l *fakeListener) Addr() net.Addr { + return l.server.LocalAddr() +} + +func newFakeListener(err error) net.Listener { + server, client := net.Pipe() + + return &fakeListener{ + server: server, + client: client, + err: err, + } +} + +func TestLoadRevokedUserCertsFile(t *testing.T) { + r := revokedCertificates{ + certs: map[string]bool{}, + } + err := r.load() + assert.NoError(t, err) + r.filePath = filepath.Join(os.TempDir(), "sub", "testrevoked") + err = os.MkdirAll(filepath.Dir(r.filePath), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(r.filePath, []byte(`no json`), 0644) + assert.NoError(t, err) + err = r.load() + assert.Error(t, err) + r.filePath = filepath.Dir(r.filePath) + err = r.load() + assert.Error(t, err) + err = os.RemoveAll(r.filePath) + assert.NoError(t, err) +} + +func TestMaxUserSessions(t *testing.T) { + connection := &Connection{ + BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "user_max_sessions", + HomeDir: filepath.Clean(os.TempDir()), + MaxSessions: 1, + }, + }), + } + err := common.Connections.Add(connection) + assert.NoError(t, err) + + c := Configuration{} + c.handleSftpConnection(nil, connection) + + buf := make([]byte, 65535) + stdErrBuf := make([]byte, 65535) + mockSSHChannel := MockChannel{ + Buffer: bytes.NewBuffer(buf), + StdErrBuffer: bytes.NewBuffer(stdErrBuf), + } + + conn := &Connection{ + BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", 
dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "user_max_sessions", + HomeDir: filepath.Clean(os.TempDir()), + MaxSessions: 1, + }, + }), + channel: &mockSSHChannel, + } + + sshCmd := sshCommand{ + command: "cd", + connection: conn, + } + err = sshCmd.handle() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "too many open sessions") + } + scpCmd := scpCommand{ + sshCommand: sshCommand{ + command: "scp", + connection: conn, + }, + } + err = scpCmd.handle() + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "too many open sessions") + } + + common.Connections.Remove(connection.GetID()) + assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestCanReadSymlink(t *testing.T) { + connection := &Connection{ + BaseConnection: common.NewBaseConnection(xid.New().String(), common.ProtocolSFTP, "", "", dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "user_can_read_symlink", + HomeDir: filepath.Clean(os.TempDir()), + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + "/sub": {dataprovider.PermUpload}, + }, + }, + Filters: dataprovider.UserFilters{ + BaseUserFilters: sdk.BaseUserFilters{ + FilePatterns: []sdk.PatternsFilter{ + { + Path: "/denied", + DeniedPatterns: []string{"*.txt"}, + DenyPolicy: sdk.DenyPolicyHide, + }, + }, + }, + }, + }), + } + err := connection.canReadLink("/sub/link") + assert.ErrorIs(t, err, sftp.ErrSSHFxPermissionDenied) + + err = connection.canReadLink("/denied/file.txt") + assert.ErrorIs(t, err, sftp.ErrSSHFxNoSuchFile) +} + +func TestAuthenticationErrors(t *testing.T) { + sftpAuthError := newAuthenticationError(nil, "", "") + loginMethod := dataprovider.SSHLoginMethodPassword + username := "test user" + err := newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", util.NewRecordNotFoundError("not found")), + loginMethod, username) + assert.ErrorIs(t, err, sftpAuthError) + assert.ErrorIs(t, err, 
util.ErrNotFound) + var sftpAuthErr *authenticationError + if assert.ErrorAs(t, err, &sftpAuthErr) { + assert.Equal(t, loginMethod, sftpAuthErr.getLoginMethod()) + assert.Equal(t, username, sftpAuthErr.getUsername()) + } + err = newAuthenticationError(fmt.Errorf("cannot validate credentials: %w", fs.ErrPermission), loginMethod, username) + assert.ErrorIs(t, err, sftpAuthError) + assert.NotErrorIs(t, err, util.ErrNotFound) + err = newAuthenticationError(fmt.Errorf("cert has wrong type %d", ssh.HostCert), loginMethod, username) + assert.ErrorIs(t, err, sftpAuthError) + assert.NotErrorIs(t, err, util.ErrNotFound) + err = newAuthenticationError(errors.New("ssh: certificate signed by unrecognized authority"), loginMethod, username) + assert.ErrorIs(t, err, sftpAuthError) + assert.NotErrorIs(t, err, util.ErrNotFound) + err = newAuthenticationError(nil, loginMethod, username) + assert.ErrorIs(t, err, sftpAuthError) + assert.NotErrorIs(t, err, util.ErrNotFound) +} + +type mockCommandExecutor struct { + Output []byte + Err error +} + +func (f mockCommandExecutor) CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) { + return f.Output, f.Err +} + +func TestVerifyWithOPKSSH(t *testing.T) { + sshCert := []byte(`ssh-rsa-cert-v01@openssh.com 
AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg4+hKHVPKv183MU/Q7XD/mzDBFSc2YY3eraltxLMGJo0AAAADAQABAAABAQCe6jMoy1xCQgiZkZJ7gi6NLj4uRqz2OaUGK/OJYZTfBqK+SlS9iymAluHu9K+cc4+0qxx0gn7dRTJWINSgzvca6ayYe995EKgD1hE5krh9BH0bRrXB+hGqyslcZOgLNO+v8jYojClQbRtET2tS+xb4k33GCuL5wgla2790ZgOQgs7huQUjG0S8c1W+EYt6fI4cWE/DeEBnv9sqryS8rOb0PbM6WUd7XBadwySFWYQUX0ei56GNt12Z4gADEGlFQV/OnV0PvnTcAMGUl0rfToPgJ4jgogWKoTVWuZ9wyA/x+2LRLRvgm2a969ig937/AH0i0Wq+FzqfK7EXQ99Yf5K/AAAAAAAAAAAAAAACAAAAFGhvc3QuZXhhbXBsZS5jb20ta2V5AAAAFAAAABBob3N0LmV4YW1wbGUuY29tAAAAAGXEzYAAAAAAd8sP4wAAAAAAAAAAAAAAAAAAARcAAAAHc3NoLXJzYQAAAAMBAAEAAAEBAL4PXUPSERufZWCW/hhEnylk3IeMgaa+2HcNY5Cur77a8fYy6OYZAPF+vhJUT0akwGUpTeXAZumAgHECDrJlw1J+jo9ZVT0AKDo0wU77IzNzYxob7+dpB02NJ7DLAXmPauQ07Zc5pWJFVKtmuh7YH9pjYtNXSMOXye7k06PBGzX+ztIt7nPWvD9fR2mZeTSoljeBCGZHwdlnV2ESQlQbBoEI93RPxqxJh/UCDatQPhpDbyverr2ZvB9Y45rqsx6ZVmu5RXl3MfBU1U21W/4ia2di3PybyD4rSmVoam0efcqxo6cBKSHe26OFoTuS9zgdH0iCWL37vqOFmJ7eH91M3nMAAAEUAAAADHJzYS1zaGEyLTI1NgAAAQA/ByIegNZYJRRl413S/8LxGvTZnbxsPwaluoJ/54niGZV9P28THz7d9jXfSHPjalhH93jNPfTYXvI4opnDC37ua1Nu8KKfk40IWXnnDdZLWraUxEidIzhmfVtz8kGdGoFQ8H0EzubL7zKNOTlfSfOoDlmQVOuxT/+eh2mEp4ri0/+8J1mLfLBr8tREX0/iaNjK+RKdcyTMicKursAYMCDdu8vlaphxea+ocyHM9izSX/l33t44V13ueTqIOh2Zbl2UE2k+jk+0dc1CmV0SEoiWiIyt8TRM4yQry1vPlQLsrf28sYM/QMwnhCVhyZO3vs5F25aQWrB9d51VEzBW9/fd host.example.com`) + key, _, _, _, err := ssh.ParseAuthorizedKey(sshCert) //nolint:dogsled + require.NoError(t, err) + cert, ok := key.(*ssh.Certificate) + require.True(t, ok) + c := Configuration{} + c.executor = mockCommandExecutor{ + Err: errors.New("test error"), + } + err = c.verifyWithOPKSSH("user", cert) + assert.Error(t, err) + + c.executor = mockCommandExecutor{} + err = c.verifyWithOPKSSH("", cert) + assert.Error(t, err) + + c.executor = mockCommandExecutor{ + Output: ssh.MarshalAuthorizedKey(cert), + } + err = c.verifyWithOPKSSH("", cert) + assert.Error(t, err) + + c.executor = mockCommandExecutor{ + Output: ssh.MarshalAuthorizedKey(cert.SignatureKey), + } + err 
= c.verifyWithOPKSSH("", cert) + assert.NoError(t, err) +} diff --git a/internal/sftpd/lister.go b/internal/sftpd/lister.go new file mode 100644 index 00000000..fd0d55cc --- /dev/null +++ b/internal/sftpd/lister.go @@ -0,0 +1,36 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package sftpd + +import ( + "io" + "os" +) + +type listerAt []os.FileInfo + +// ListAt returns the number of entries copied and an io.EOF error if we made it to the end of the file list. +// Take a look at the pkg/sftp godoc for more information about how this function should work. +func (l listerAt) ListAt(f []os.FileInfo, offset int64) (int, error) { + if offset >= int64(len(l)) { + return 0, io.EOF + } + + n := copy(f, l[offset:]) + if n < len(f) { + return n, io.EOF + } + return n, nil +} diff --git a/internal/sftpd/scp.go b/internal/sftpd/scp.go new file mode 100644 index 00000000..2d91afe3 --- /dev/null +++ b/internal/sftpd/scp.go @@ -0,0 +1,849 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package sftpd + +import ( + "errors" + "fmt" + "io" + "math" + "os" + "path" + "path/filepath" + "runtime/debug" + "strconv" + "strings" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +var ( + okMsg = []byte{0x00} + warnMsg = []byte{0x01} // must be followed by an optional message and a newline + errMsg = []byte{0x02} // must be followed by an optional message and a newline + newLine = []byte{0x0A} +) + +type scpCommand struct { + sshCommand +} + +func (c *scpCommand) handle() (err error) { + defer func() { + if r := recover(); r != nil { + logger.Error(logSender, "", "panic in handle scp command: %q stack trace: %v", r, string(debug.Stack())) + err = common.ErrGenericFailure + } + }() + if err := common.Connections.Add(c.connection); err != nil { + defer c.connection.CloseFS() //nolint:errcheck + logger.Info(logSender, "", "unable to add SCP connection: %v", err) + return c.sendErrorResponse(err) + } + defer common.Connections.Remove(c.connection.GetID()) + + destPath := c.getDestPath() + c.connection.Log(logger.LevelDebug, "handle scp command, args: %v user: %s, dest path: %q", + c.args, c.connection.User.Username, destPath) + if c.hasFlag("t") { + // -t means "to", so upload + err = c.sendConfirmationMessage() + if err != nil { + return err + } + err = c.handleRecursiveUpload() + if err != nil { + return err + } + } else if c.hasFlag("f") { + // -f means "from" so download + err = c.readConfirmationMessage() + if err != nil { + return err + } + err = c.handleDownload(destPath) + if err != nil { + return err + } + } else { + err = fmt.Errorf("scp command not supported, args: %v", c.args) + 
c.connection.Log(logger.LevelDebug, "unsupported scp command, args: %v", c.args) + } + c.sendExitStatus(err) + return err +} + +func (c *scpCommand) handleRecursiveUpload() error { + numDirs := 0 + destPath := c.getDestPath() + for { + fs, err := c.connection.User.GetFilesystemForPath(destPath, c.connection.ID) + if err != nil { + c.connection.Log(logger.LevelError, "error uploading file %q: %+v", destPath, err) + c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %q", destPath)) + return err + } + command, err := c.getNextUploadProtocolMessage() + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + c.sendErrorMessage(fs, err) + return err + } + if strings.HasPrefix(command, "E") { + numDirs-- + c.connection.Log(logger.LevelDebug, "received end dir command, num dirs: %v", numDirs) + if numDirs < 0 { + err = errors.New("unacceptable end dir command") + c.sendErrorMessage(nil, err) + return err + } + // the destination dir is now the parent directory + destPath = path.Join(destPath, "..") + } else { + sizeToRead, name, err := c.parseUploadMessage(fs, command) + if err != nil { + return err + } + if strings.HasPrefix(command, "D") { + numDirs++ + destPath = path.Join(destPath, name) + fs, err = c.connection.User.GetFilesystemForPath(destPath, c.connection.ID) + if err != nil { + c.connection.Log(logger.LevelError, "error uploading file %q: %+v", destPath, err) + c.sendErrorMessage(nil, fmt.Errorf("unable to get fs for path %q", destPath)) + return err + } + err = c.handleCreateDir(fs, destPath) + if err != nil { + return err + } + c.connection.Log(logger.LevelDebug, "received start dir command, num dirs: %v destPath: %q", numDirs, destPath) + } else if strings.HasPrefix(command, "C") { + err = c.handleUpload(c.getFileUploadDestPath(fs, destPath, name), sizeToRead) + if err != nil { + return err + } + } + } + err = c.sendConfirmationMessage() + if err != nil { + return err + } + } +} + +func (c *scpCommand) handleCreateDir(fs vfs.Fs, dirPath 
string) error { + c.connection.UpdateLastActivity() + + p, err := fs.ResolvePath(dirPath) + if err != nil { + c.connection.Log(logger.LevelError, "error creating dir: %q, invalid file path, err: %v", dirPath, err) + c.sendErrorMessage(fs, err) + return err + } + if !c.connection.User.HasPerm(dataprovider.PermCreateDirs, path.Dir(dirPath)) { + c.connection.Log(logger.LevelError, "error creating dir: %q, permission denied", dirPath) + c.sendErrorMessage(fs, common.ErrPermissionDenied) + return common.ErrPermissionDenied + } + + info, err := c.connection.DoStat(dirPath, 1, true) + if err == nil && info.IsDir() { + return nil + } + + err = c.createDir(fs, p) + if err != nil { + return err + } + c.connection.Log(logger.LevelDebug, "created dir %q", dirPath) + return nil +} + +// we need to close the transfer if we have an error +func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *transfer) error { + err := c.sendConfirmationMessage() + if err != nil { + transfer.TransferError(err) + transfer.Close() + return err + } + + if sizeToRead > 0 { + // we could replace this method with io.CopyN implementing "Write" method in transfer struct + remaining := sizeToRead + buf := make([]byte, int64(math.Min(32768, float64(sizeToRead)))) + for { + n, err := c.connection.channel.Read(buf) + if err != nil { + transfer.TransferError(err) + transfer.Close() + c.sendErrorMessage(transfer.Fs, err) + return err + } + _, err = transfer.WriteAt(buf[:n], sizeToRead-remaining) + if err != nil { + transfer.Close() + c.sendErrorMessage(transfer.Fs, err) + return err + } + remaining -= int64(n) + if remaining <= 0 { + break + } + if remaining < int64(len(buf)) { + buf = make([]byte, remaining) + } + } + } + err = c.readConfirmationMessage() + if err != nil { + transfer.TransferError(err) + transfer.Close() + return err + } + err = transfer.Close() + if err != nil { + c.sendErrorMessage(transfer.Fs, err) + return err + } + return nil +} + +func (c *scpCommand) handleUploadFile(fs 
vfs.Fs, resolvedPath, filePath string, sizeToRead int64, isNewFile bool, fileSize int64, requestPath string) error { + if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil { + err := fmt.Errorf("denying file write due to transfer count limits") + c.connection.Log(logger.LevelInfo, "denying file write due to transfer count limits") + c.sendErrorMessage(nil, err) + return err + } + diskQuota, transferQuota := c.connection.HasSpace(isNewFile, false, requestPath) + if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() { + err := fmt.Errorf("denying file write due to quota limits") + c.connection.Log(logger.LevelError, "error uploading file: %q, err: %v", filePath, err) + c.sendErrorMessage(nil, err) + return err + } + _, err := common.ExecutePreAction(c.connection.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, + fileSize, os.O_TRUNC) + if err != nil { + c.connection.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err) + err = c.connection.GetPermissionDeniedError() + c.sendErrorMessage(fs, err) + return err + } + + maxWriteSize, _ := c.connection.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported()) + + file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.connection.GetCreateChecks(requestPath, isNewFile, false)) + if err != nil { + c.connection.Log(logger.LevelError, "error creating file %q: %v", resolvedPath, err) + c.sendErrorMessage(fs, err) + return err + } + + initialSize := int64(0) + truncatedSize := int64(0) // bytes truncated and not included in quota + if !isNewFile { + if vfs.HasTruncateSupport(fs) { + vfolder, err := c.connection.User.GetVirtualFolderForPath(path.Dir(requestPath)) + if err == nil { + dataprovider.UpdateUserFolderQuota(&vfolder, &c.connection.User, 0, -fileSize, false) + } else { + dataprovider.UpdateUserQuota(&c.connection.User, 0, -fileSize, false) //nolint:errcheck + } + } else { + 
initialSize = fileSize + truncatedSize = initialSize + } + if maxWriteSize > 0 { + maxWriteSize += fileSize + } + } + + vfs.SetPathPermissions(fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID()) + + baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, filePath, requestPath, + common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, isNewFile, fs, transferQuota) + t := newTransfer(baseTransfer, w, nil, nil) + + return c.getUploadFileData(sizeToRead, t) +} + +func (c *scpCommand) handleUpload(uploadFilePath string, sizeToRead int64) error { + c.connection.UpdateLastActivity() + + fs, p, err := c.connection.GetFsAndResolvedPath(uploadFilePath) + if err != nil { + c.connection.Log(logger.LevelError, "error uploading file: %q, err: %v", uploadFilePath, err) + c.sendErrorMessage(nil, err) + return err + } + + if ok, _ := c.connection.User.IsFileAllowed(uploadFilePath); !ok { + c.connection.Log(logger.LevelWarn, "writing file %q is not allowed", uploadFilePath) + c.sendErrorMessage(fs, c.connection.GetPermissionDeniedError()) + return common.ErrPermissionDenied + } + + filePath := p + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + filePath = fs.GetAtomicUploadPath(p) + } + stat, statErr := fs.Lstat(p) + if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) { + if !c.connection.User.HasPerm(dataprovider.PermUpload, path.Dir(uploadFilePath)) { + c.connection.Log(logger.LevelWarn, "cannot upload file: %q, permission denied", uploadFilePath) + c.sendErrorMessage(fs, common.ErrPermissionDenied) + return common.ErrPermissionDenied + } + return c.handleUploadFile(fs, p, filePath, sizeToRead, true, 0, uploadFilePath) + } + + if statErr != nil { + c.connection.Log(logger.LevelError, "error performing file stat %q: %v", p, statErr) + c.sendErrorMessage(fs, statErr) + return statErr + } + + if stat.IsDir() { + c.connection.Log(logger.LevelError, 
"attempted to open a directory for writing to: %q", p) + err = fmt.Errorf("attempted to open a directory for writing: %q", p) + c.sendErrorMessage(fs, err) + return err + } + + if !c.connection.User.HasPerm(dataprovider.PermOverwrite, uploadFilePath) { + c.connection.Log(logger.LevelWarn, "cannot overwrite file: %q, permission denied", uploadFilePath) + c.sendErrorMessage(fs, common.ErrPermissionDenied) + return common.ErrPermissionDenied + } + + if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() { + _, _, err = fs.Rename(p, filePath, 0) + if err != nil { + c.connection.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %v", + p, filePath, err) + c.sendErrorMessage(fs, err) + return err + } + } + + return c.handleUploadFile(fs, p, filePath, sizeToRead, false, stat.Size(), uploadFilePath) +} + +func (c *scpCommand) sendDownloadProtocolMessages(virtualDirPath string, stat os.FileInfo) error { + var err error + if c.sendFileTime() { + modTime := stat.ModTime().UnixNano() / 1000000000 + tCommand := fmt.Sprintf("T%d 0 %d 0\n", modTime, modTime) + err = c.sendProtocolMessage(tCommand) + if err != nil { + return err + } + err = c.readConfirmationMessage() + if err != nil { + return err + } + } + + dirName := path.Base(virtualDirPath) + if dirName == "/" || dirName == "." { + dirName = c.connection.User.Username + } + + fileMode := fmt.Sprintf("D%v 0 %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), dirName) + err = c.sendProtocolMessage(fileMode) + if err != nil { + return err + } + err = c.readConfirmationMessage() + return err +} + +// We send first all the files in the root directory and then the directories. 
+// For each directory we recursively call this method again +func (c *scpCommand) handleRecursiveDownload(fs vfs.Fs, dirPath, virtualPath string, stat os.FileInfo) error { + var err error + if c.isRecursive() { + c.connection.Log(logger.LevelDebug, "recursive download, dir path %q virtual path %q", dirPath, virtualPath) + err = c.sendDownloadProtocolMessages(virtualPath, stat) + if err != nil { + return err + } + // dirPath is a fs path, not a virtual path + lister, err := fs.ReadDir(dirPath) + if err != nil { + c.sendErrorMessage(fs, err) + return err + } + defer lister.Close() + + vdirs := c.connection.User.GetVirtualFoldersInfo(virtualPath) + + var dirs []string + for { + files, err := lister.Next(vfs.ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + c.sendErrorMessage(fs, err) + return err + } + files = c.connection.User.FilterListDir(files, fs.GetRelativePath(dirPath)) + if len(vdirs) > 0 { + files = append(files, vdirs...) + vdirs = nil + } + for _, file := range files { + filePath := fs.GetRelativePath(fs.Join(dirPath, file.Name())) + if file.Mode().IsRegular() || file.Mode()&os.ModeSymlink != 0 { + err = c.handleDownload(filePath) + if err != nil { + c.sendErrorMessage(fs, err) + return err + } + } else if file.IsDir() { + dirs = append(dirs, filePath) + } + } + if finished { + break + } + } + lister.Close() + + return c.downloadDirs(fs, dirs) + } + err = errors.New("unable to send directory for non recursive copy") + c.sendErrorMessage(nil, err) + return err +} + +func (c *scpCommand) downloadDirs(fs vfs.Fs, dirs []string) error { + for _, dir := range dirs { + if err := c.handleDownload(dir); err != nil { + c.sendErrorMessage(fs, err) + return err + } + } + if err := c.sendProtocolMessage("E\n"); err != nil { + return err + } + return c.readConfirmationMessage() +} + +func (c *scpCommand) sendDownloadFileData(fs vfs.Fs, filePath string, stat os.FileInfo, transfer *transfer) error { + var err error + if 
c.sendFileTime() { + modTime := stat.ModTime().UnixNano() / 1000000000 + tCommand := fmt.Sprintf("T%d 0 %d 0\n", modTime, modTime) + err = c.sendProtocolMessage(tCommand) + if err != nil { + return err + } + err = c.readConfirmationMessage() + if err != nil { + return err + } + } + if vfs.IsCryptOsFs(fs) { + stat = fs.(*vfs.CryptFs).ConvertFileInfo(stat) + } + + fileSize := stat.Size() + readed := int64(0) + fileMode := fmt.Sprintf("C%v %v %v\n", getFileModeAsString(stat.Mode(), stat.IsDir()), fileSize, filepath.Base(filePath)) + err = c.sendProtocolMessage(fileMode) + if err != nil { + return err + } + err = c.readConfirmationMessage() + if err != nil { + return err + } + + // we could replace this method with io.CopyN implementing "Read" method in transfer struct + buf := make([]byte, 32768) + var n int + for { + n, err = transfer.ReadAt(buf, readed) + if err == nil || err == io.EOF { + if n > 0 { + _, err = c.connection.channel.Write(buf[:n]) + } + } + readed += int64(n) + if err != nil { + break + } + } + if err != io.EOF { + c.sendErrorMessage(fs, err) + return err + } + err = c.sendConfirmationMessage() + if err != nil { + return err + } + err = c.readConfirmationMessage() + return err +} + +func (c *scpCommand) handleDownload(filePath string) error { + c.connection.UpdateLastActivity() + + if err := common.Connections.IsNewTransferAllowed(c.connection.User.Username); err != nil { + err := fmt.Errorf("denying file read due to transfer count limits") + c.connection.Log(logger.LevelInfo, "denying file read due to transfer count limits") + c.sendErrorMessage(nil, err) + return err + } + transferQuota := c.connection.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + c.connection.Log(logger.LevelInfo, "denying file read due to quota limits") + c.sendErrorMessage(nil, c.connection.GetReadQuotaExceededError()) + return c.connection.GetReadQuotaExceededError() + } + var err error + + fs, p, err := c.connection.GetFsAndResolvedPath(filePath) + if err != 
nil { + c.connection.Log(logger.LevelError, "error downloading file %q: %+v", filePath, err) + c.sendErrorMessage(nil, fmt.Errorf("unable to download file %q: %w", filePath, err)) + return err + } + + var stat os.FileInfo + if stat, err = fs.Stat(p); err != nil { + c.connection.Log(logger.LevelError, "error downloading file: %q->%q, err: %v", filePath, p, err) + c.sendErrorMessage(fs, err) + return err + } + + if stat.IsDir() { + if !c.connection.User.HasPerm(dataprovider.PermDownload, filePath) { + c.connection.Log(logger.LevelWarn, "error downloading dir: %q, permission denied", filePath) + c.sendErrorMessage(fs, common.ErrPermissionDenied) + return common.ErrPermissionDenied + } + err = c.handleRecursiveDownload(fs, p, filePath, stat) + return err + } + + if !c.connection.User.HasPerm(dataprovider.PermDownload, path.Dir(filePath)) { + c.connection.Log(logger.LevelWarn, "error downloading dir: %q, permission denied", filePath) + c.sendErrorMessage(fs, common.ErrPermissionDenied) + return common.ErrPermissionDenied + } + + if ok, policy := c.connection.User.IsFileAllowed(filePath); !ok { + c.connection.Log(logger.LevelWarn, "reading file %q is not allowed", filePath) + c.sendErrorMessage(fs, c.connection.GetErrorForDeniedFile(policy)) + return common.ErrPermissionDenied + } + + if _, err := common.ExecutePreAction(c.connection.BaseConnection, common.OperationPreDownload, p, filePath, 0, 0); err != nil { + c.connection.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", filePath, err) + c.sendErrorMessage(fs, common.ErrPermissionDenied) + return common.ErrPermissionDenied + } + + file, r, cancelFn, err := fs.Open(p, 0) + if err != nil { + c.connection.Log(logger.LevelError, "could not open file %q for reading: %v", p, err) + c.sendErrorMessage(fs, err) + return err + } + + baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, p, filePath, + common.TransferDownload, 0, 0, 0, 0, false, fs, transferQuota) + t := 
newTransfer(baseTransfer, nil, r, nil) + + err = c.sendDownloadFileData(fs, p, stat, t) + // we need to call Close anyway and return close error if any and + // if we have no previous error + if err == nil { + err = t.Close() + } else { + t.TransferError(err) + t.Close() + } + return err +} + +func (c *scpCommand) sendFileTime() bool { + return c.hasFlag("p") +} + +func (c *scpCommand) isRecursive() bool { + return c.hasFlag("r") +} + +func (c *scpCommand) hasFlag(flag string) bool { + for idx := 0; idx < len(c.args)-1; idx++ { + arg := c.args[idx] + if !strings.HasPrefix(arg, "--") && strings.HasPrefix(arg, "-") && strings.Contains(arg, flag) { + return true + } + } + return false +} + +// read the SCP confirmation message and the optional text message +// the channel will be closed on errors +func (c *scpCommand) readConfirmationMessage() error { + var msg strings.Builder + buf := make([]byte, 1) + n, err := c.connection.channel.Read(buf) + if err != nil { + c.connection.channel.Close() + return err + } + if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) { + isError := buf[0] == errMsg[0] + for { + n, err = c.connection.channel.Read(buf) + readed := buf[:n] + if err != nil || (n == 1 && readed[0] == newLine[0]) { + break + } + if n > 0 { + msg.Write(readed) + } + } + c.connection.Log(logger.LevelInfo, "scp error message received: %v is error: %v", msg.String(), isError) + err = fmt.Errorf("%v", msg.String()) + c.connection.channel.Close() + } + return err +} + +// protool messages are newline terminated +func (c *scpCommand) readProtocolMessage() (string, error) { + var command strings.Builder + var err error + buf := make([]byte, 1) + for { + var n int + n, err = c.connection.channel.Read(buf) + if err != nil { + break + } + if n > 0 { + readed := buf[:n] + if n == 1 && readed[0] == newLine[0] { + break + } + command.Write(readed) + } + } + if err != nil && !errors.Is(err, io.EOF) { + c.connection.channel.Close() + } + return command.String(), err +} + 
+// sendErrorMessage sends an error message and close the channel +// we don't check write errors here, we have to close the channel anyway +// +//nolint:errcheck +func (c *scpCommand) sendErrorMessage(fs vfs.Fs, err error) { + c.connection.channel.Write(errMsg) + if fs != nil { + c.connection.channel.Write([]byte(c.connection.GetFsError(fs, err).Error())) + } else { + c.connection.channel.Write([]byte(err.Error())) + } + c.connection.channel.Write(newLine) + c.connection.channel.Close() +} + +// send scp confirmation message and close the channel if an error happen +func (c *scpCommand) sendConfirmationMessage() error { + _, err := c.connection.channel.Write(okMsg) + if err != nil { + c.connection.channel.Close() + } + return err +} + +// sends a protocol message and close the channel on error +func (c *scpCommand) sendProtocolMessage(message string) error { + _, err := c.connection.channel.Write([]byte(message)) + if err != nil { + c.connection.Log(logger.LevelError, "error sending protocol message: %v, err: %v", message, err) + c.connection.channel.Close() + } + return err +} + +// get the next upload protocol message ignoring T command if any +func (c *scpCommand) getNextUploadProtocolMessage() (string, error) { + var command string + var err error + for { + command, err = c.readProtocolMessage() + if err != nil { + return command, err + } + if strings.HasPrefix(command, "T") { + err = c.sendConfirmationMessage() + if err != nil { + return command, err + } + } else { + break + } + } + return command, err +} + +func (c *scpCommand) createDir(fs vfs.Fs, dirPath string) error { + err := fs.Mkdir(dirPath) + if err != nil { + c.connection.Log(logger.LevelError, "error creating dir %q: %v", dirPath, err) + c.sendErrorMessage(fs, err) + return err + } + vfs.SetPathPermissions(fs, dirPath, c.connection.User.GetUID(), c.connection.User.GetGID()) + return err +} + +// parse protocol messages such as: +// D0755 0 testdir +// or: +// C0644 6 testfile +// and returns file 
size and file/directory name +func (c *scpCommand) parseUploadMessage(fs vfs.Fs, command string) (int64, string, error) { + var size int64 + var name string + var err error + if !strings.HasPrefix(command, "C") && !strings.HasPrefix(command, "D") { + err = fmt.Errorf("unknown or invalid upload message: %v args: %v user: %v", + command, c.args, c.connection.User.Username) + c.connection.Log(logger.LevelError, "error: %v", err) + c.sendErrorMessage(fs, err) + return size, name, err + } + parts := strings.SplitN(command, " ", 3) + if len(parts) == 3 { + size, err = strconv.ParseInt(parts[1], 10, 64) + if err != nil { + c.connection.Log(logger.LevelError, "error getting size from upload message: %v", err) + c.sendErrorMessage(fs, err) + return size, name, err + } + name = parts[2] + if name == "" { + err = fmt.Errorf("error getting name from upload message, cannot be empty") + c.connection.Log(logger.LevelError, "error: %v", err) + c.sendErrorMessage(fs, err) + return size, name, err + } + } else { + err = fmt.Errorf("unable to split upload message: %q", command) + c.connection.Log(logger.LevelError, "error: %v", err) + c.sendErrorMessage(fs, err) + return size, name, err + } + return size, name, err +} + +func (c *scpCommand) getFileUploadDestPath(fs vfs.Fs, scpDestPath, fileName string) string { + if !c.isRecursive() { + // if the upload is not recursive and the destination path does not end with "/" + // then scpDestPath is the wanted filename, for example: + // scp fileName.txt user@127.0.0.1:/newFileName.txt + // or + // scp fileName.txt user@127.0.0.1:/fileName.txt + if !strings.HasSuffix(scpDestPath, "/") { + // but if scpDestPath is an existing directory then we put the uploaded file + // inside that directory this is as scp command works, for example: + // scp fileName.txt user@127.0.0.1:/existing_dir + if p, err := fs.ResolvePath(scpDestPath); err == nil { + if stat, err := fs.Stat(p); err == nil { + if stat.IsDir() { + return path.Join(scpDestPath, fileName) 
+ } + } + } + return scpDestPath + } + } + // if the upload is recursive or scpDestPath has the "/" suffix then the destination + // file is relative to scpDestPath + return path.Join(scpDestPath, fileName) +} + +func getFileModeAsString(fileMode os.FileMode, isDir bool) string { + var defaultMode string + if isDir { + defaultMode = "0755" + } else { + defaultMode = "0644" + } + if fileMode == 0 { + return defaultMode + } + modeString := []byte(fileMode.String()) + nullPerm := []byte("-") + u := 0 + g := 0 + o := 0 + s := 0 + lastChar := len(modeString) - 1 + if fileMode&os.ModeSticky != 0 { + s++ + } + if fileMode&os.ModeSetuid != 0 { + s += 2 + } + if fileMode&os.ModeSetgid != 0 { + s += 4 + } + if modeString[lastChar-8] != nullPerm[0] { + u += 4 + } + if modeString[lastChar-7] != nullPerm[0] { + u += 2 + } + if modeString[lastChar-6] != nullPerm[0] { + u++ + } + if modeString[lastChar-5] != nullPerm[0] { + g += 4 + } + if modeString[lastChar-4] != nullPerm[0] { + g += 2 + } + if modeString[lastChar-3] != nullPerm[0] { + g++ + } + if modeString[lastChar-2] != nullPerm[0] { + o += 4 + } + if modeString[lastChar-1] != nullPerm[0] { + o += 2 + } + if modeString[lastChar] != nullPerm[0] { + o++ + } + return fmt.Sprintf("%v%v%v%v", s, u, g, o) +} diff --git a/internal/sftpd/server.go b/internal/sftpd/server.go new file mode 100644 index 00000000..3bee040d --- /dev/null +++ b/internal/sftpd/server.go @@ -0,0 +1,1401 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. 
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package sftpd

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"maps"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"runtime/debug"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/pkg/sftp"
	"github.com/sftpgo/sdk/plugin/notifier"
	"golang.org/x/crypto/ssh"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/metric"
	"github.com/drakkan/sftpgo/v2/internal/plugin"
	"github.com/drakkan/sftpgo/v2/internal/util"
	"github.com/drakkan/sftpgo/v2/internal/version"
	"github.com/drakkan/sftpgo/v2/internal/vfs"
)

const (
	defaultPrivateRSAKeyName          = "id_rsa"
	defaultPrivateECDSAKeyName        = "id_ecdsa"
	defaultPrivateEd25519KeyName      = "id_ed25519"
	sourceAddressCriticalOption       = "source-address"
	keyExchangeCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org"
	// keys used in ssh.Permissions.ExtraData to carry state between auth callbacks
	extraDataPartialSuccessErrKey = "partialSuccessErr"
	extraDataUserKey              = "user"
	extraDataKeyIDKey             = "keyID"
	extraDataLoginMethodKey       = "login_method"
)

var (
	supportedAlgos = ssh.SupportedAlgorithms()
	insecureAlgos  = ssh.InsecureAlgorithms()
	sftpExtensions = []string{"statvfs@openssh.com"}
	// "supported" lists accept both secure and insecure algorithms (used for
	// validating user configuration), "preferred" lists are the defaults
	supportedHostKeyAlgos = append(supportedAlgos.HostKeys, insecureAlgos.HostKeys...)
	preferredHostKeyAlgos = []string{
		ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512,
		ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
		ssh.KeyAlgoED25519,
	}
	supportedPublicKeyAlgos = append(supportedAlgos.PublicKeyAuths, insecureAlgos.PublicKeyAuths...)
	preferredPublicKeyAlgos = supportedAlgos.PublicKeyAuths
	supportedKexAlgos       = append(supportedAlgos.KeyExchanges, insecureAlgos.KeyExchanges...)
	preferredKexAlgos       = supportedAlgos.KeyExchanges
	supportedCiphers        = append(supportedAlgos.Ciphers, insecureAlgos.Ciphers...)
	preferredCiphers        = supportedAlgos.Ciphers
	supportedMACs           = append(supportedAlgos.MACs, insecureAlgos.MACs...)
	preferredMACs           = []string{
		ssh.HMACSHA256ETM, ssh.HMACSHA256,
	}

	// in-memory store of revoked user certificates, presumably loaded from
	// RevokedUserCertsFile — confirm against the revokedCertificates definition
	revokedCertManager = revokedCertificates{
		certs: map[string]bool{},
	}
)

// commandExecutor abstracts external command execution (allows test doubles).
type commandExecutor interface {
	CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error)
}

// defaultExecutor runs commands via os/exec with an empty environment.
type defaultExecutor struct{}

func (d defaultExecutor) CombinedOutput(ctx context.Context, name string, args ...string) ([]byte, error) {
	cmd := exec.CommandContext(ctx, name, args...)
	// run with an empty environment to avoid leaking the daemon's env vars
	cmd.Env = []string{}
	return cmd.CombinedOutput()
}

// Binding defines the configuration for a network listener
type Binding struct {
	// The address to listen on. A blank value means listen on all available network interfaces.
	Address string `json:"address" mapstructure:"address"`
	// The port used for serving requests
	Port int `json:"port" mapstructure:"port"`
	// Apply the proxy configuration, if any, for this binding
	ApplyProxyConfig bool `json:"apply_proxy_config" mapstructure:"apply_proxy_config"`
}

// GetAddress returns the binding address
func (b *Binding) GetAddress() string {
	return fmt.Sprintf("%s:%d", b.Address, b.Port)
}

// IsValid returns true if the binding port is > 0
func (b *Binding) IsValid() bool {
	return b.Port > 0
}

// HasProxy returns true if the proxy protocol is active for this binding
func (b *Binding) HasProxy() bool {
	return b.ApplyProxyConfig && common.Config.ProxyProtocol > 0
}

// Configuration for the SFTP server
type Configuration struct {
	// Addresses and ports to bind to
	Bindings []Binding `json:"bindings" mapstructure:"bindings"`
	// Maximum number of authentication attempts permitted per connection.
	// If set to a negative number, the number of attempts is unlimited.
	// If set to zero, the number of attempts are limited to 6.
	MaxAuthTries int `json:"max_auth_tries" mapstructure:"max_auth_tries"`
	// HostKeys define the daemon's private host keys.
	// Each host key can be defined as a path relative to the configuration directory or an absolute one.
	// If empty or missing, the daemon will search or try to generate "id_rsa" and "id_ecdsa" host keys
	// inside the configuration directory.
	HostKeys []string `json:"host_keys" mapstructure:"host_keys"`
	// HostCertificates defines public host certificates.
	// Each certificate can be defined as a path relative to the configuration directory or an absolute one.
	// Certificate's public key must match a private host key otherwise it will be silently ignored.
	HostCertificates []string `json:"host_certificates" mapstructure:"host_certificates"`
	// HostKeyAlgorithms lists the public key algorithms that the server will accept for host
	// key authentication.
	HostKeyAlgorithms []string `json:"host_key_algorithms" mapstructure:"host_key_algorithms"`
	// KexAlgorithms specifies the available KEX (Key Exchange) algorithms in
	// preference order.
	KexAlgorithms []string `json:"kex_algorithms" mapstructure:"kex_algorithms"`
	// Ciphers specifies the ciphers allowed
	Ciphers []string `json:"ciphers" mapstructure:"ciphers"`
	// MACs Specifies the available MAC (message authentication code) algorithms
	// in preference order
	MACs []string `json:"macs" mapstructure:"macs"`
	// PublicKeyAlgorithms lists the supported public key algorithms for client authentication.
	PublicKeyAlgorithms []string `json:"public_key_algorithms" mapstructure:"public_key_algorithms"`
	// TrustedUserCAKeys specifies a list of public keys paths of certificate authorities
	// that are trusted to sign user certificates for authentication.
	// The paths can be absolute or relative to the configuration directory
	TrustedUserCAKeys []string `json:"trusted_user_ca_keys" mapstructure:"trusted_user_ca_keys"`
	// Path to a file containing the revoked user certificates.
	// This file must contain a JSON list with the public key fingerprints of the revoked certificates.
	// Example content:
	// ["SHA256:bsBRHC/xgiqBJdSuvSTNpJNLTISP/G356jNMCRYC5Es","SHA256:119+8cL/HH+NLMawRsJx6CzPF1I3xC+jpM60bQHXGE8"]
	RevokedUserCertsFile string `json:"revoked_user_certs_file" mapstructure:"revoked_user_certs_file"`
	// Absolute path to the opkssh binary used for OpenPubkey SSH integration
	OPKSSHPath string `json:"opkssh_path" mapstructure:"opkssh_path"`
	// Expected SHA256 checksum of the opkssh binary. It is verified at application startup
	OPKSSHChecksum string `json:"opkssh_checksum" mapstructure:"opkssh_checksum"`
	// LoginBannerFile the contents of the specified file, if any, are sent to
	// the remote user before authentication is allowed.
	LoginBannerFile string `json:"login_banner_file" mapstructure:"login_banner_file"`
	// List of enabled SSH commands.
	// We support the following SSH commands:
	// - "scp". SCP is an experimental feature, we have our own SCP implementation since
	//    we can't rely on scp system command to proper handle permissions, quota and
	//    user's home dir restrictions.
	//    The SCP protocol is quite simple but there is no official docs about it,
	//    so we need more testing and feedbacks before enabling it by default.
	//    We may not handle some borderline cases or have sneaky bugs.
	//    Please do accurate tests yourself before enabling SCP and let us known
	//    if something does not work as expected for your use cases.
	//    SCP between two remote hosts is supported using the `-3` scp option.
	// - "md5sum", "sha1sum", "sha256sum", "sha384sum", "sha512sum". Useful to check message
	//    digests for uploaded files. These commands are implemented inside SFTPGo so they
	//    work even if the matching system commands are not available, for example on Windows.
	// - "cd", "pwd". Some mobile SFTP clients does not support the SFTP SSH_FXP_REALPATH and so
	//    they use "cd" and "pwd" SSH commands to get the initial directory.
	//    Currently `cd` do nothing and `pwd` always returns the "/" path.
	//
	// The following SSH commands are enabled by default: "md5sum", "sha1sum", "cd", "pwd".
	// "*" enables all supported SSH commands.
	EnabledSSHCommands []string `json:"enabled_ssh_commands" mapstructure:"enabled_ssh_commands"`
	// KeyboardInteractiveAuthentication specifies whether keyboard interactive authentication is allowed.
	// If no keyboard interactive hook or auth plugin is defined the default is to prompt for the user password and then the
	// one time authentication code, if defined.
	KeyboardInteractiveAuthentication bool `json:"keyboard_interactive_authentication" mapstructure:"keyboard_interactive_authentication"`
	// Absolute path to an external program or an HTTP URL to invoke for keyboard interactive authentication.
	// Leave empty to disable this authentication mode.
	KeyboardInteractiveHook string `json:"keyboard_interactive_auth_hook" mapstructure:"keyboard_interactive_auth_hook"`
	// PasswordAuthentication specifies whether password authentication is allowed.
	PasswordAuthentication bool `json:"password_authentication" mapstructure:"password_authentication"`
	// internal, non-serialized state
	certChecker      *ssh.CertChecker
	parsedUserCAKeys []ssh.PublicKey
	executor         commandExecutor
}

// authenticationError wraps an auth failure together with the login method
// and username that produced it, for defender/notification handling.
type authenticationError struct {
	err         error
	loginMethod string
	username    string
}

func (e *authenticationError) Error() string {
	return fmt.Sprintf("Authentication error: %v", e.err)
}

// Is reports if target matches
func (e *authenticationError) Is(target error) bool {
	_, ok := target.(*authenticationError)
	return ok
}

// Unwrap returns the wrapped error
func (e *authenticationError) Unwrap() error {
	return e.err
}

func (e *authenticationError) getLoginMethod() string {
	return e.loginMethod
}

func (e *authenticationError) getUsername() string {
	return e.username
}

func newAuthenticationError(err error, loginMethod, username string) *authenticationError {
	return &authenticationError{err: err, loginMethod: loginMethod, username: username}
}

// ShouldBind returns true if there is at least a valid binding
func (c *Configuration) ShouldBind() bool {
	for _, binding := range c.Bindings {
		if binding.IsValid() {
			return true
		}
	}

	return false
}

// getServerConfig builds the ssh.ServerConfig with public key and, if
// enabled, password authentication callbacks. It also records the enabled
// authentication methods in the package-level serviceStatus.
func (c *Configuration) getServerConfig() *ssh.ServerConfig {
	serverConfig := &ssh.ServerConfig{
		NoClientAuth: false,
		MaxAuthTries: c.MaxAuthTries,
		PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
			sp, err := c.validatePublicKeyCredentials(conn, pubKey)
			if err != nil {
				return nil, newAuthenticationError(fmt.Errorf("could not validate public key credentials: %w", err),
					dataprovider.SSHLoginMethodPublicKey, conn.User())
			}

			return sp, nil
		},
		VerifiedPublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey, permissions *ssh.Permissions, signatureAlgorithm string) (*ssh.Permissions, error) {
			// a partial success recorded by PublicKeyCallback means more auth steps are required
			if partialErr, ok := permissions.ExtraData[extraDataPartialSuccessErrKey]; ok {
				logger.Info(logSender, hex.EncodeToString(conn.SessionID()), "user %q authenticated with partial success, signature algorithm %q",
					conn.User(), signatureAlgorithm)
				return nil, partialErr.(error)
			}
			method := dataprovider.SSHLoginMethodPublicKey
			user := permissions.ExtraData[extraDataUserKey].(dataprovider.User)
			keyID := permissions.ExtraData[extraDataKeyIDKey].(string)
			sshPerm, err := loginUser(&user, method, fmt.Sprintf("%s (%s)", keyID, signatureAlgorithm), conn)
			if err == nil {
				// if we have a SSH user cert we need to merge certificate permissions with our ones
				// we only set Extensions, so CriticalOptions are always the ones from the certificate
				sshPerm.CriticalOptions = permissions.CriticalOptions
				if permissions.Extensions != nil {
					if sshPerm.Extensions == nil {
						sshPerm.Extensions = make(map[string]string)
					}
					maps.Copy(sshPerm.Extensions, permissions.Extensions)
				}
				if sshPerm.ExtraData == nil {
					sshPerm.ExtraData = make(map[any]any)
				}
			}
			user.Username = conn.User()
			ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
			updateLoginMetrics(&user, ipAddr, method, err)
			return sshPerm, err
		},
		ServerVersion: fmt.Sprintf("SSH-2.0-%s", version.GetServerVersion("_", false)),
	}

	if c.PasswordAuthentication {
		serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
			return c.validatePasswordCredentials(conn, password, dataprovider.LoginMethodPassword)
		}
		serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.LoginMethodPassword)
	}
	serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey)

	return serverConfig
}

// updateSupportedAuthentications dedupes the enabled auth methods and adds
// the derived multi-step methods when both component methods are available.
func (c *Configuration) updateSupportedAuthentications() {
	serviceStatus.Authentications = util.RemoveDuplicates(serviceStatus.Authentications, false)

	if slices.Contains(serviceStatus.Authentications, dataprovider.LoginMethodPassword) &&
		slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
		serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndPassword)
	}

	if slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) &&
		slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
		serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndKeyboardInt)
	}
}

// loadFromProvider merges the SFTPD security settings stored in the data
// provider into the statically configured algorithm lists. When a provider
// list is set and the static list is empty, the preferred defaults are used
// as the base before appending.
func (c *Configuration) loadFromProvider() error {
	configs, err := dataprovider.GetConfigs()
	if err != nil {
		return fmt.Errorf("unable to load config from provider: %w", err)
	}
	configs.SetNilsToEmpty()
	if len(configs.SFTPD.HostKeyAlgos) > 0 {
		if len(c.HostKeyAlgorithms) == 0 {
			c.HostKeyAlgorithms = preferredHostKeyAlgos
		}
		c.HostKeyAlgorithms = append(c.HostKeyAlgorithms, configs.SFTPD.HostKeyAlgos...)
	}
	if len(configs.SFTPD.PublicKeyAlgos) > 0 {
		if len(c.PublicKeyAlgorithms) == 0 {
			c.PublicKeyAlgorithms = preferredPublicKeyAlgos
		}
		c.PublicKeyAlgorithms = append(c.PublicKeyAlgorithms, configs.SFTPD.PublicKeyAlgos...)
	}
	if len(configs.SFTPD.KexAlgorithms) > 0 {
		if len(c.KexAlgorithms) == 0 {
			c.KexAlgorithms = preferredKexAlgos
		}
		c.KexAlgorithms = append(c.KexAlgorithms, configs.SFTPD.KexAlgorithms...)
	}
	if len(configs.SFTPD.Ciphers) > 0 {
		if len(c.Ciphers) == 0 {
			c.Ciphers = preferredCiphers
		}
		c.Ciphers = append(c.Ciphers, configs.SFTPD.Ciphers...)
	}
	if len(configs.SFTPD.MACs) > 0 {
		if len(c.MACs) == 0 {
			c.MACs = preferredMACs
		}
		c.MACs = append(c.MACs, configs.SFTPD.MACs...)
	}
	return nil
}

// Initialize the SFTP server and add a persistent listener to handle inbound SFTP connections.
func (c *Configuration) Initialize(configDir string) error {
	c.executor = defaultExecutor{}
	if err := c.loadFromProvider(); err != nil {
		return fmt.Errorf("unable to load configs from provider: %w", err)
	}
	serviceStatus = ServiceStatus{}
	serverConfig := c.getServerConfig()

	if !c.ShouldBind() {
		return common.ErrNoBinding
	}

	sftp.SetSFTPExtensions(sftpExtensions...) //nolint:errcheck // we configure valid SFTP Extensions so we cannot get an error
	sftp.MaxFilelist = 250

	if err := c.configureSecurityOptions(serverConfig); err != nil {
		return err
	}
	if err := c.checkAndLoadHostKeys(configDir, serverConfig); err != nil {
		serviceStatus.HostKeys = nil
		return err
	}
	if err := c.initializeCertChecker(configDir); err != nil {
		return err
	}
	if err := c.initializeOPKSSH(); err != nil {
		return err
	}
	c.configureKeyboardInteractiveAuth(serverConfig)
	c.configureLoginBanner(serverConfig, configDir)
	c.checkSSHCommands()

	// one goroutine per binding; the first listener error stops the service
	exitChannel := make(chan error, 1)
	serviceStatus.Bindings = nil

	for _, binding := range c.Bindings {
		if !binding.IsValid() {
			continue
		}
		serviceStatus.Bindings = append(serviceStatus.Bindings, binding)

		go func(binding Binding) {
			addr := binding.GetAddress()
			util.CheckTCP4Port(binding.Port)
			listener, err := net.Listen("tcp", addr)
			if err != nil {
				logger.Warn(logSender, "", "error starting listener on address %v: %v", addr, err)
				exitChannel <- err
				return
			}

			if binding.ApplyProxyConfig && common.Config.ProxyProtocol > 0 {
				proxyListener, err := common.Config.GetProxyListener(listener)
				if err != nil {
					logger.Warn(logSender, "", "error enabling proxy listener: %v", err)
					exitChannel <- err
					return
				}
				listener = proxyListener
			}

			exitChannel <- c.serve(listener, serverConfig)
		}(binding)
	}

	serviceStatus.IsActive = true
	serviceStatus.SSHCommands = c.EnabledSSHCommands
	c.updateSupportedAuthentications()

	return <-exitChannel
}

// serve accepts connections on the listener until an unrecoverable error
// occurs, retrying temporary accept errors with capped exponential backoff.
func (c *Configuration) serve(listener net.Listener, serverConfig *ssh.ServerConfig) error {
	logger.Info(logSender, "", "server listener registered, address: %s", listener.Addr().String())
	var tempDelay time.Duration // how long to sleep on accept failure

	for {
		conn, err := listener.Accept()
		if err != nil {
			// see https://github.com/golang/go/blob/4aa1efed4853ea067d665a952eee77c52faac774/src/net/http/server.go#L3046
			if ne, ok := err.(net.Error); ok && ne.Temporary() { //nolint:staticcheck
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if maxDelay := 1 * time.Second; tempDelay > maxDelay {
					tempDelay = maxDelay
				}
				logger.Warn(logSender, "", "accept error: %v; retrying in %v", err, tempDelay)
				time.Sleep(tempDelay)
				continue
			}
			logger.Warn(logSender, "", "unrecoverable accept error: %v", err)
			return err
		}
		tempDelay = 0

		go c.AcceptInboundConnection(conn, serverConfig)
	}
}

// configureKeyAlgos validates the configured host key and client public key
// algorithms, falling back to the preferred defaults when unset.
func (c *Configuration) configureKeyAlgos(serverConfig *ssh.ServerConfig) error {
	if len(c.HostKeyAlgorithms) == 0 {
		c.HostKeyAlgorithms = preferredHostKeyAlgos
	} else {
		c.HostKeyAlgorithms = util.RemoveDuplicates(c.HostKeyAlgorithms, true)
	}
	for _, hostKeyAlgo := range c.HostKeyAlgorithms {
		if !slices.Contains(supportedHostKeyAlgos, hostKeyAlgo) {
			return fmt.Errorf("unsupported host key algorithm %q", hostKeyAlgo)
		}
	}

	if len(c.PublicKeyAlgorithms) > 0 {
		c.PublicKeyAlgorithms = util.RemoveDuplicates(c.PublicKeyAlgorithms, true)
		for _, algo := range c.PublicKeyAlgorithms {
			if !slices.Contains(supportedPublicKeyAlgos, algo) {
				return fmt.Errorf("unsupported public key authentication algorithm %q", algo)
			}
		}
	} else {
		c.PublicKeyAlgorithms = preferredPublicKeyAlgos
	}
	serverConfig.PublicKeyAuthAlgorithms = c.PublicKeyAlgorithms
	serviceStatus.PublicKeyAlgorithms = c.PublicKeyAlgorithms

	return nil
}

// checkKeyExchangeAlgorithms drops the unsupported group18 KEX and keeps the
// libssh and standard curve25519 names in sync (configuring one enables both).
func (c *Configuration) checkKeyExchangeAlgorithms() {
	var kexs []string
	for _, k := range c.KexAlgorithms {
		if k == "diffie-hellman-group18-sha512" {
			logger.Warn(logSender, "", "KEX %q is not supported and will be ignored", k)
			continue
		}
		kexs = append(kexs, k)
		if strings.TrimSpace(k) == keyExchangeCurve25519SHA256LibSSH {
			kexs = append(kexs, ssh.KeyExchangeCurve25519)
		}
		if strings.TrimSpace(k) == ssh.KeyExchangeCurve25519 {
			kexs = append(kexs, keyExchangeCurve25519SHA256LibSSH)
		}
	}
	c.KexAlgorithms = util.RemoveDuplicates(kexs, true)
}

// configureSecurityOptions validates and applies key, KEX, cipher and MAC
// algorithm configuration to the server config, recording the effective
// values in serviceStatus.
func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig) error {
	if err := c.configureKeyAlgos(serverConfig); err != nil {
		return err
	}

	if len(c.KexAlgorithms) > 0 {
		c.checkKeyExchangeAlgorithms()
		for _, kex := range c.KexAlgorithms {
			if kex == keyExchangeCurve25519SHA256LibSSH {
				continue
			}
			if !slices.Contains(supportedKexAlgos, kex) {
				return fmt.Errorf("unsupported key-exchange algorithm %q", kex)
			}
		}
	} else {
		c.KexAlgorithms = preferredKexAlgos
		c.checkKeyExchangeAlgorithms()
	}
	serverConfig.KeyExchanges = c.KexAlgorithms
	serviceStatus.KexAlgorithms = c.KexAlgorithms

	if len(c.Ciphers) > 0 {
		c.Ciphers = util.RemoveDuplicates(c.Ciphers, true)
		for _, cipher := range c.Ciphers {
			// these CBC ciphers are accepted even if not in the supported list
			if slices.Contains([]string{"aes192-cbc", "aes256-cbc"}, cipher) {
				continue
			}
			if !slices.Contains(supportedCiphers, cipher) {
				return fmt.Errorf("unsupported cipher %q", cipher)
			}
		}
	} else {
		c.Ciphers = preferredCiphers
	}
	serverConfig.Ciphers = c.Ciphers
	serviceStatus.Ciphers = c.Ciphers

	if len(c.MACs) > 0 {
		c.MACs = util.RemoveDuplicates(c.MACs, true)
		for _, mac := range c.MACs {
			if !slices.Contains(supportedMACs, mac) {
				return fmt.Errorf("unsupported MAC algorithm %q", mac)
			}
		}
	} else {
		c.MACs = preferredMACs
	}
	serverConfig.MACs = c.MACs
	serviceStatus.MACs = c.MACs

	return nil
}

// configureLoginBanner installs a banner callback that returns the contents
// of LoginBannerFile; read failures are logged and the banner is skipped.
func (c *Configuration) configureLoginBanner(serverConfig *ssh.ServerConfig, configDir string) {
	if c.LoginBannerFile != "" {
		bannerFilePath := c.LoginBannerFile
		if !filepath.IsAbs(bannerFilePath) {
			bannerFilePath = filepath.Join(configDir, bannerFilePath)
		}
		bannerContent, err := os.ReadFile(bannerFilePath)
		if err == nil {
			banner := util.BytesToString(bannerContent)
			serverConfig.BannerCallback = func(_ ssh.ConnMetadata) string {
				return banner
			}
		} else {
			logger.WarnToConsole("unable to read SFTPD login banner file: %v", err)
			logger.Warn(logSender, "", "unable to read login banner file: %v", err)
		}
	}
}

// configureKeyboardInteractiveAuth enables keyboard interactive auth and
// validates the optional external hook (program path or HTTP URL). An
// invalid hook disables keyboard interactive authentication entirely.
func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.ServerConfig) {
	if !c.KeyboardInteractiveAuthentication {
		return
	}
	if c.KeyboardInteractiveHook != "" {
		if !strings.HasPrefix(c.KeyboardInteractiveHook, "http") {
			if !filepath.IsAbs(c.KeyboardInteractiveHook) {
				c.KeyboardInteractiveAuthentication = false
				logger.WarnToConsole("invalid keyboard interactive authentication program: %q must be an absolute path",
					c.KeyboardInteractiveHook)
				logger.Warn(logSender, "", "invalid keyboard interactive authentication program: %q must be an absolute path",
					c.KeyboardInteractiveHook)
				return
			}
			_, err := os.Stat(c.KeyboardInteractiveHook)
			if err != nil {
				c.KeyboardInteractiveAuthentication = false
				logger.WarnToConsole("invalid keyboard interactive authentication program:: %v", err)
				logger.Warn(logSender, "", "invalid keyboard interactive authentication program:: %v", err)
				return
			}
		}
	}
	serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
		return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyboardInteractive, false)
	}

	serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive)
}

// AcceptInboundConnection handles an inbound connection to the server instance and determines if the request should be served or not.
func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { //nolint:gocyclo
	defer func() {
		if r := recover(); r != nil {
			logger.Error(logSender, "", "panic in AcceptInboundConnection: %q stack trace: %v", r, string(debug.Stack()))
		}
	}()

	ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
	common.Connections.AddClientConnection(ipAddr)
	defer common.Connections.RemoveClientConnection(ipAddr)

	if !canAcceptConnection(ipAddr) {
		conn.Close()
		return
	}
	// Before beginning a handshake must be performed on the incoming net.Conn
	// we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH
	conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck

	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
	if err != nil {
		logger.Debug(logSender, "", "failed to accept an incoming connection from ip %q: %v", ipAddr, err)
		checkAuthError(ipAddr, err)
		return
	}
	// handshake completed so remove the deadline, we'll use IdleTimeout configuration from now on
	conn.SetDeadline(time.Time{}) //nolint:errcheck
	go ssh.DiscardRequests(reqs)

	defer sconn.Close()

	// the auth callbacks stored the authenticated user and login method in ExtraData
	user := sconn.Permissions.ExtraData[extraDataUserKey].(dataprovider.User)
	loginType := sconn.Permissions.ExtraData[extraDataLoginMethodKey].(string)
	connectionID := hex.EncodeToString(sconn.SessionID())

	defer user.CloseFs() //nolint:errcheck
	if err = user.CheckFsRoot(connectionID); err != nil {
		logger.Warn(logSender, connectionID, "unable to check fs root for user %q: %v", user.Username, err)
		go discardAllChannels(chans, "invalid root fs", connectionID)
		return
	}

	logger.LoginLog(user.Username, ipAddr, loginType, common.ProtocolSSH, connectionID,
		util.BytesToString(sconn.ClientVersion()), true,
		fmt.Sprintf("negotiated algorithms: %+v", sconn.Conn.(ssh.AlgorithmsConnMetadata).Algorithms()))

	dataprovider.UpdateLastLogin(&user)

	sshConnection := common.NewSSHConnection(connectionID, sconn)
	common.Connections.AddSSHConnection(sshConnection)

	defer common.Connections.RemoveSSHConnection(connectionID)

	channelCounter := int64(0)
	for newChannel := range chans {
		// If its not a session channel we just move on because its not something we
		// know how to handle at this point.
		if newChannel.ChannelType() != "session" {
			logger.Log(logger.LevelDebug, common.ProtocolSSH, connectionID, "received an unknown channel type: %v",
				newChannel.ChannelType())
			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") //nolint:errcheck
			continue
		}

		channel, requests, err := newChannel.Accept()
		if err != nil {
			logger.Log(logger.LevelWarn, common.ProtocolSSH, connectionID, "could not accept a channel: %v", err)
			continue
		}

		channelCounter++
		// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
		// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
		go func(in <-chan *ssh.Request, counter int64) {
			for req := range in {
				ok := false
				connID := fmt.Sprintf("%s_%d", connectionID, counter)

				switch req.Type {
				case "subsystem":
					if bytes.Equal(req.Payload[4:], []byte("sftp")) {
						ok = true
						sshConnection.UpdateLastActivity()
						connection := &Connection{
							BaseConnection: common.NewBaseConnection(connID, common.ProtocolSFTP, conn.LocalAddr().String(),
								conn.RemoteAddr().String(), user),
							ClientVersion: util.BytesToString(sconn.ClientVersion()),
							RemoteAddr:    conn.RemoteAddr(),
							LocalAddr:     conn.LocalAddr(),
							channel:       channel,
						}
						go c.handleSftpConnection(channel, connection)
					}
				case "exec":
					// protocol will be set later inside processSSHCommand it could be SSH or SCP
					connection := Connection{
						BaseConnection: common.NewBaseConnection(connID, "sshd_exec", conn.LocalAddr().String(),
							conn.RemoteAddr().String(), user),
						ClientVersion: util.BytesToString(sconn.ClientVersion()),
						RemoteAddr:    conn.RemoteAddr(),
						LocalAddr:     conn.LocalAddr(),
						channel:       channel,
					}
					ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands)
					if ok {
						sshConnection.UpdateLastActivity()
					}
				}
				if req.WantReply {
					req.Reply(ok, nil) //nolint:errcheck
				}
			}
		}(requests, channelCounter)
	}
}

// handleSftpConnection registers the connection and runs the SFTP request
// server on the channel until EOF (clean close) or error.
func (c *Configuration) handleSftpConnection(channel ssh.Channel, connection *Connection) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error(logSender, "", "panic in handleSftpConnection: %q stack trace: %v", r, string(debug.Stack()))
		}
	}()
	if err := common.Connections.Add(connection); err != nil {
		defer connection.CloseFS() //nolint:errcheck
		errClose := connection.Disconnect()
		logger.Info(logSender, "", "unable to add connection: %v, close err: %v", err, errClose)
		return
	}
	defer common.Connections.Remove(connection.GetID())

	// Create the server instance for the channel using the handler we created above.
	server := sftp.NewRequestServer(channel, c.createHandlers(connection),
		sftp.WithStartDirectory(connection.User.Filters.StartDirectory))

	defer server.Close()
	if err := server.Serve(); errors.Is(err, io.EOF) {
		exitStatus := sshSubsystemExitStatus{Status: uint32(0)}
		_, err = channel.SendRequest("exit-status", false, ssh.Marshal(&exitStatus))
		connection.Log(logger.LevelInfo, "connection closed, sent exit status %+v error: %v", exitStatus, err)
	} else if err != nil {
		connection.Log(logger.LevelError, "connection closed with error: %v", err)
	}
}

// createHandlers returns the SFTP handlers, all backed by the connection itself.
func (c *Configuration) createHandlers(connection *Connection) sftp.Handlers {
	return sftp.Handlers{
		FileGet:  connection,
		FilePut:  connection,
		FileCmd:  connection,
		FileList: connection,
	}
}

// canAcceptConnection checks the defender ban list, per-IP connection limits,
// rate limiting and the post-connect hook before serving the client.
func canAcceptConnection(ip string) bool {
	if common.IsBanned(ip, common.ProtocolSSH) {
		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, ip %q is banned", ip)
		return false
	}
	if err := common.Connections.IsNewConnectionAllowed(ip, common.ProtocolSSH); err != nil {
		logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection not allowed from ip %q: %v", ip, err)
		return false
	}
	_, err := common.LimitRate(common.ProtocolSSH, ip)
	if err != nil {
		return false
	}
	if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil {
		return false
	}
	return true
}

// discardAllChannels rejects every pending channel request with the given message.
func discardAllChannels(in <-chan ssh.NewChannel, message, connectionID string) {
	for req := range in {
		err := req.Reject(ssh.ConnectionFailed, message)
		logger.Debug(logSender, connectionID, "discarded channel request, message %q err: %v", message, err)
	}
}

// checkAuthError records defender events and log notifications for failed
// handshakes: public key auth failures are handled per-error, anything else
// is treated as "no auth tried" (including algorithm negotiation failures).
func checkAuthError(ip string, err error) {
	var authErrors *ssh.ServerAuthError
	if errors.As(err, &authErrors) {
		// check public key auth errors here
		for _, err := range authErrors.Errors {
			var sftpAuthErr *authenticationError
			if errors.As(err, &sftpAuthErr) {
				if sftpAuthErr.getLoginMethod() == dataprovider.SSHLoginMethodPublicKey {
					event := common.HostEventLoginFailed
					logEv := notifier.LogEventTypeLoginFailed
					if errors.Is(err, util.ErrNotFound) {
						event = common.HostEventUserNotFound
						logEv = notifier.LogEventTypeLoginNoUser
					}
					common.AddDefenderEvent(ip, common.ProtocolSSH, event)
					plugin.Handler.NotifyLogEvent(logEv, common.ProtocolSSH, sftpAuthErr.getUsername(), ip, "", err)
					return
				}
			}
		}
	} else {
		logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTried, common.ProtocolSSH, err.Error())
		metric.AddNoAuthTried()
		common.AddDefenderEvent(ip, common.ProtocolSSH, common.HostEventNoLoginTried)
		dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTried, ip, common.ProtocolSSH, err)
		logEv := notifier.LogEventTypeNoLoginTried
		var negotiationError *ssh.AlgorithmNegotiationError
		if errors.As(err, &negotiationError) {
			logEv = notifier.LogEventTypeNotNegotiated
		}
		plugin.Handler.NotifyLogEvent(logEv, common.ProtocolSSH, "", ip, "", err)
	}
}

// loginUser enforces per-user restrictions (home dir, protocol, sessions,
// login method, 2FA, source address) and returns the ssh.Permissions carrying
// the user and the effective login method on success.
func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.ConnMetadata) (*ssh.Permissions, error) {
	connectionID := ""
	if conn != nil {
		connectionID = hex.EncodeToString(conn.SessionID())
	}
	if !filepath.IsAbs(user.HomeDir) {
		logger.Warn(logSender, connectionID, "user %q has an invalid home dir: %q. Home dir must be an absolute path, login not allowed",
			user.Username, user.HomeDir)
		return nil, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir)
	}
	if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolSSH) {
		logger.Info(logSender, connectionID, "cannot login user %q, protocol SSH is not allowed", user.Username)
		return nil, fmt.Errorf("protocol SSH is not allowed for user %q", user.Username)
	}
	if user.MaxSessions > 0 {
		activeSessions := common.Connections.GetActiveSessions(user.Username)
		if activeSessions >= user.MaxSessions {
			logger.Info(logSender, "", "authentication refused for user: %q, too many open sessions: %v/%v", user.Username,
				activeSessions, user.MaxSessions)
			return nil, fmt.Errorf("too many open sessions: %v", activeSessions)
		}
	}
	if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolSSH) {
		logger.Info(logSender, connectionID, "cannot login user %q, login method %q is not allowed",
			user.Username, loginMethod)
		return nil, fmt.Errorf("login method %q is not allowed for user %q", loginMethod, user.Username)
	}
	if user.MustSetSecondFactorForProtocol(common.ProtocolSSH) {
		logger.Info(logSender, connectionID, "cannot login user %q, second factor authentication is not set",
			user.Username)
		return nil, fmt.Errorf("second factor authentication is not set for user %q", user.Username)
	}
	remoteAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
	if !user.IsLoginFromAddrAllowed(remoteAddr) {
		logger.Info(logSender, connectionID, "cannot login user %q, remote address is not allowed: %v",
			user.Username, remoteAddr)
		return nil, fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, remoteAddr)
	}

	if publicKey != "" {
		loginMethod = fmt.Sprintf("%v: %v", loginMethod, publicKey)
	}
	p := &ssh.Permissions{}
	p.ExtraData = make(map[any]any)
	p.ExtraData[extraDataUserKey] = *user
	p.ExtraData[extraDataLoginMethodKey] = loginMethod
	return p, nil
}

// checkSSHCommands filters EnabledSSHCommands down to the supported ones;
// "*" enables every supported command.
func (c *Configuration) checkSSHCommands() {
	if slices.Contains(c.EnabledSSHCommands, "*") {
		c.EnabledSSHCommands = GetSupportedSSHCommands()
		return
	}
	sshCommands := []string{}
	for _, command := range c.EnabledSSHCommands {
		command = strings.TrimSpace(command)
		if slices.Contains(supportedSSHCommands, command) {
			sshCommands = append(sshCommands, command)
		} else {
			logger.Warn(logSender, "", "unsupported ssh command: %q ignored", command)
			logger.WarnToConsole("unsupported ssh command: %q ignored", command)
		}
	}
	c.EnabledSSHCommands = sshCommands
	logger.Debug(logSender, "", "enabled SSH commands %v", c.EnabledSSHCommands)
}

// generateDefaultHostKeys creates the default RSA, ECDSA and Ed25519 host
// keys inside configDir when missing and registers them in c.HostKeys.
func (c *Configuration) generateDefaultHostKeys(configDir string) error {
	var err error
	defaultHostKeys := []string{defaultPrivateRSAKeyName, defaultPrivateECDSAKeyName, defaultPrivateEd25519KeyName}
	for _, k := range defaultHostKeys {
		autoFile := filepath.Join(configDir, k)
		if _, err = os.Stat(autoFile); errors.Is(err, fs.ErrNotExist) {
			logger.Info(logSender, "", "No host keys configured and %q does not exist; try to create a new host key", autoFile)
			logger.InfoToConsole("No host keys configured and %q does not exist; try to create a new host key", autoFile)
			switch k {
			case defaultPrivateRSAKeyName:
				err = util.GenerateRSAKeys(autoFile)
			case defaultPrivateECDSAKeyName:
				err = util.GenerateECDSAKeys(autoFile)
			default:
				err = util.GenerateEd25519Keys(autoFile)
			}
			if err != nil {
				logger.Warn(logSender, "", "error creating host key %q: %v", autoFile, err)
				logger.WarnToConsole("error creating host key %q: %v", autoFile, err)
				return err
			}
		}
		c.HostKeys = append(c.HostKeys, k)
	}

	return err
}

// checkHostKeyAutoGeneration auto-creates configured-but-missing host keys
// with a recognized default name; keys with other names are not created.
// If no host keys are configured at all the defaults are generated.
func (c *Configuration) checkHostKeyAutoGeneration(configDir string) error {
	for _, k := range c.HostKeys {
		k = strings.TrimSpace(k)
		if filepath.IsAbs(k) {
			if _, err := os.Stat(k); errors.Is(err, fs.ErrNotExist) {
				keyName := filepath.Base(k)
				switch keyName {
				case defaultPrivateRSAKeyName:
					logger.Info(logSender, "", "try to create non-existent host key %q", k)
					logger.InfoToConsole("try to create non-existent host key %q", k)
					err = util.GenerateRSAKeys(k)
					if err != nil {
						logger.Warn(logSender, "", "error creating host key %q: %v", k, err)
						logger.WarnToConsole("error creating host key %q: %v", k, err)
						return err
					}
				case defaultPrivateECDSAKeyName:
					logger.Info(logSender, "", "try to create non-existent host key %q", k)
					logger.InfoToConsole("try to create non-existent host key %q", k)
					err = util.GenerateECDSAKeys(k)
					if err != nil {
						logger.Warn(logSender, "", "error creating host key %q: %v", k, err)
						logger.WarnToConsole("error creating host key %q: %v", k, err)
						return err
					}
				case defaultPrivateEd25519KeyName:
					logger.Info(logSender, "", "try to create non-existent host key %q", k)
					logger.InfoToConsole("try to create non-existent host key %q", k)
					err = util.GenerateEd25519Keys(k)
					if err != nil {
						logger.Warn(logSender, "", "error creating host key %q: %v", k, err)
						logger.WarnToConsole("error creating host key %q: %v", k, err)
						return err
					}
				default:
					logger.Warn(logSender, "", "non-existent host key %q will not be created", k)
					logger.WarnToConsole("non-existent host key %q will not be created", k)
				}
			}
		}
	}
	if len(c.HostKeys) == 0 {
		if err := c.generateDefaultHostKeys(configDir); err != nil {
			return err
		}
	}
	return nil
}

func (c *Configuration) getHostKeyAlgorithms(keyFormat string) []string {
	var algos []string
	for _, algo := range
algorithmsForKeyFormat(keyFormat) { + if slices.Contains(c.HostKeyAlgorithms, algo) { + algos = append(algos, algo) + } + } + return algos +} + +// If no host keys are defined we try to use or generate the default ones. +func (c *Configuration) checkAndLoadHostKeys(configDir string, serverConfig *ssh.ServerConfig) error { + if err := c.checkHostKeyAutoGeneration(configDir); err != nil { + return err + } + hostCertificates, err := c.loadHostCertificates(configDir) + if err != nil { + return err + } + serviceStatus.HostKeys = nil + for _, hostKey := range c.HostKeys { + hostKey = strings.TrimSpace(hostKey) + if !util.IsFileInputValid(hostKey) { + logger.Warn(logSender, "", "unable to load invalid host key %q", hostKey) + logger.WarnToConsole("unable to load invalid host key %q", hostKey) + continue + } + if !filepath.IsAbs(hostKey) { + hostKey = filepath.Join(configDir, hostKey) + } + logger.Info(logSender, "", "Loading private host key %q", hostKey) + + privateBytes, err := os.ReadFile(hostKey) + if err != nil { + return err + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + return err + } + k := HostKey{ + Path: hostKey, + Fingerprint: ssh.FingerprintSHA256(private.PublicKey()), + Algorithms: c.getHostKeyAlgorithms(private.PublicKey().Type()), + } + mas, err := ssh.NewSignerWithAlgorithms(private.(ssh.AlgorithmSigner), k.Algorithms) + if err != nil { + return fmt.Errorf("could not create signer for key %q with algorithms %+v: %w", k.Path, k.Algorithms, err) + } + serviceStatus.HostKeys = append(serviceStatus.HostKeys, k) + logger.Info(logSender, "", "Host key %q loaded, type %q, fingerprint %q, algorithms %+v", hostKey, + private.PublicKey().Type(), k.Fingerprint, k.Algorithms) + + // Add private key to the server configuration. 
+ serverConfig.AddHostKey(mas) + for _, cert := range hostCertificates { + signer, err := ssh.NewCertSigner(cert.Certificate, mas) + if err == nil { + var algos []string + for _, algo := range algorithmsForKeyFormat(signer.PublicKey().Type()) { + if underlyingAlgo, ok := certKeyAlgoNames[algo]; ok { + if slices.Contains(mas.Algorithms(), underlyingAlgo) { + algos = append(algos, algo) + } + } + } + serviceStatus.HostKeys = append(serviceStatus.HostKeys, HostKey{ + Path: cert.Path, + Fingerprint: ssh.FingerprintSHA256(signer.PublicKey()), + Algorithms: algos, + }) + serverConfig.AddHostKey(signer) + logger.Info(logSender, "", "Host certificate loaded for host key %q, fingerprint %q, algorithms %+v", + hostKey, ssh.FingerprintSHA256(signer.PublicKey()), algos) + } + } + } + var fp []string + for idx := range serviceStatus.HostKeys { + h := &serviceStatus.HostKeys[idx] + fp = append(fp, h.Fingerprint) + } + vfs.SetSFTPFingerprints(fp) + return nil +} + +func (c *Configuration) loadHostCertificates(configDir string) ([]hostCertificate, error) { + var certs []hostCertificate + for _, certPath := range c.HostCertificates { + certPath = strings.TrimSpace(certPath) + if !util.IsFileInputValid(certPath) { + logger.Warn(logSender, "", "unable to load invalid host certificate %q", certPath) + logger.WarnToConsole("unable to load invalid host certificate %q", certPath) + continue + } + if !filepath.IsAbs(certPath) { + certPath = filepath.Join(configDir, certPath) + } + certBytes, err := os.ReadFile(certPath) + if err != nil { + return certs, fmt.Errorf("unable to load host certificate %q: %w", certPath, err) + } + parsed, _, _, _, err := ssh.ParseAuthorizedKey(certBytes) + if err != nil { + return nil, fmt.Errorf("unable to parse host certificate %q: %w", certPath, err) + } + cert, ok := parsed.(*ssh.Certificate) + if !ok { + return nil, fmt.Errorf("the file %q is not an SSH certificate", certPath) + } + if cert.CertType != ssh.HostCert { + return nil, fmt.Errorf("the file %q 
is not an host certificate", certPath) + } + certs = append(certs, hostCertificate{ + Path: certPath, + Certificate: cert, + }) + } + return certs, nil +} + +func (c *Configuration) initializeOPKSSH() error { + if c.OPKSSHPath != "" { + if len(c.parsedUserCAKeys) > 0 { + return errors.New("opkssh and certificate authorities are mutually exclusive") + } + if !util.IsFileInputValid(c.OPKSSHPath) || !filepath.IsAbs(c.OPKSSHPath) { + return fmt.Errorf("opkssh path %q is not valid, it must be an absolute path", c.OPKSSHPath) + } + if c.OPKSSHChecksum == "" { + if _, err := os.Stat(c.OPKSSHPath); err != nil { + return fmt.Errorf("error validating opkssh path %q: %w", c.OPKSSHPath, err) + } + } else { + if err := util.VerifyFileChecksum(c.OPKSSHPath, sha256.New(), c.OPKSSHChecksum, 100*1024*1024); err != nil { + return fmt.Errorf("error validating opkssh checksum: %w", err) + } + } + } + + return nil +} + +func (c *Configuration) verifyWithOPKSSH(username string, cert *ssh.Certificate) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + args := []string{"verify", username, util.BytesToString(ssh.MarshalAuthorizedKey(cert)), cert.Type()} + out, err := c.executor.CombinedOutput(ctx, c.OPKSSHPath, args...) 
+ if err != nil { + logger.Debug(logSender, "", "unable to execute opk verifier: %s", string(out)) + return fmt.Errorf("unable to execute opk verifier: %w", err) + } + pubKey, _, _, _, err := ssh.ParseAuthorizedKey(out) //nolint:dogsled + if err != nil { + logger.Debug(logSender, "", "unable to validate the opk verifier output: %s", string(out)) + return fmt.Errorf("unable to validate the opk verifier output: %w", err) + } + if !bytes.Equal(pubKey.Marshal(), cert.SignatureKey.Marshal()) { + return errors.New("unable to validate opk result") + } + return nil +} + +func (c *Configuration) initializeCertChecker(configDir string) error { + for _, keyPath := range c.TrustedUserCAKeys { + keyPath = strings.TrimSpace(keyPath) + if !util.IsFileInputValid(keyPath) { + logger.Warn(logSender, "", "unable to load invalid trusted user CA key %q", keyPath) + logger.WarnToConsole("unable to load invalid trusted user CA key %q", keyPath) + continue + } + if !filepath.IsAbs(keyPath) { + keyPath = filepath.Join(configDir, keyPath) + } + keyBytes, err := os.ReadFile(keyPath) + if err != nil { + logger.Warn(logSender, "", "error loading trusted user CA key %q: %v", keyPath, err) + logger.WarnToConsole("error loading trusted user CA key %q: %v", keyPath, err) + return err + } + parsedKey, _, _, _, err := ssh.ParseAuthorizedKey(keyBytes) + if err != nil { + logger.Warn(logSender, "", "error parsing trusted user CA key %q: %v", keyPath, err) + logger.WarnToConsole("error parsing trusted user CA key %q: %v", keyPath, err) + return err + } + c.parsedUserCAKeys = append(c.parsedUserCAKeys, parsedKey) + } + c.certChecker = &ssh.CertChecker{ + SupportedCriticalOptions: []string{ + sourceAddressCriticalOption, + }, + IsUserAuthority: func(k ssh.PublicKey) bool { + for _, key := range c.parsedUserCAKeys { + if bytes.Equal(k.Marshal(), key.Marshal()) { + return true + } + } + return false + }, + } + if c.RevokedUserCertsFile != "" { + if !util.IsFileInputValid(c.RevokedUserCertsFile) { + return 
fmt.Errorf("invalid revoked user certificate: %q", c.RevokedUserCertsFile) + } + if !filepath.IsAbs(c.RevokedUserCertsFile) { + c.RevokedUserCertsFile = filepath.Join(configDir, c.RevokedUserCertsFile) + } + } + revokedCertManager.filePath = c.RevokedUserCertsFile + return revokedCertManager.load() +} + +func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error { + err := &ssh.PartialSuccessError{} + if c.PasswordAuthentication && slices.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) { + err.Next.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + return c.validatePasswordCredentials(conn, password, dataprovider.SSHLoginMethodKeyAndPassword) + } + } + if c.KeyboardInteractiveAuthentication && slices.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) { + err.Next.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { + return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyAndKeyboardInt, true) + } + } + return err +} + +func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + var user dataprovider.User + var certPerm *ssh.Permissions + + method := dataprovider.SSHLoginMethodPublicKey + ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) + cert, ok := pubKey.(*ssh.Certificate) + var certFingerprint string + if ok { + certFingerprint = ssh.FingerprintSHA256(cert.Key) + if c.OPKSSHPath != "" { + if err := c.verifyWithOPKSSH(conn.User(), cert); err != nil { + err := fmt.Errorf("ssh: verification with OPK failed: %v", err) + user.Username = conn.User() + updateLoginMetrics(&user, ipAddr, method, err) + return nil, err + } + } else { + if cert.CertType != ssh.UserCert { + err := fmt.Errorf("ssh: cert has type %d", cert.CertType) + user.Username = conn.User() + 
updateLoginMetrics(&user, ipAddr, method, err) + return nil, err + } + if !c.certChecker.IsUserAuthority(cert.SignatureKey) { + err := errors.New("ssh: certificate signed by unrecognized authority") + user.Username = conn.User() + updateLoginMetrics(&user, ipAddr, method, err) + return nil, err + } + if len(cert.ValidPrincipals) == 0 { + err := fmt.Errorf("ssh: certificate %s has no valid principals, user: \"%s\"", certFingerprint, conn.User()) + user.Username = conn.User() + updateLoginMetrics(&user, ipAddr, method, err) + return nil, err + } + if revokedCertManager.isRevoked(certFingerprint) { + err := fmt.Errorf("ssh: certificate %s is revoked", certFingerprint) + user.Username = conn.User() + updateLoginMetrics(&user, ipAddr, method, err) + return nil, err + } + if err := c.certChecker.CheckCert(conn.User(), cert); err != nil { + user.Username = conn.User() + updateLoginMetrics(&user, ipAddr, method, err) + return nil, err + } + } + certPerm = &cert.Permissions + } + user, keyID, err := dataprovider.CheckUserAndPubKey(conn.User(), pubKey.Marshal(), ipAddr, common.ProtocolSSH, ok) + if err != nil { + user.Username = conn.User() + updateLoginMetrics(&user, ipAddr, method, err) + return nil, err + } + if ok { + keyID = fmt.Sprintf("%s: ID: %s, serial: %v, CA %s %s", certFingerprint, + cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey)) + } + if certPerm == nil { + certPerm = &ssh.Permissions{} + } + certPerm.ExtraData = make(map[any]any) + certPerm.ExtraData[extraDataKeyIDKey] = keyID + certPerm.ExtraData[extraDataUserKey] = user + if user.IsPartialAuth() { + certPerm.ExtraData[extraDataPartialSuccessErrKey] = c.getPartialSuccessError(user.GetNextAuthMethods()) + } + return certPerm, nil +} + +func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass []byte, method string) (*ssh.Permissions, error) { + var err error + var user dataprovider.User + var sshPerm *ssh.Permissions + + ipAddr := 
util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) + if user, err = dataprovider.CheckUserAndPass(conn.User(), util.BytesToString(pass), ipAddr, common.ProtocolSSH); err == nil { + sshPerm, err = loginUser(&user, method, "", conn) + } + user.Username = conn.User() + updateLoginMetrics(&user, ipAddr, method, err) + if err != nil { + return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err), method, conn.User()) + } + return sshPerm, nil +} + +func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge, + method string, isPartialAuth bool, +) (*ssh.Permissions, error) { + var err error + var user dataprovider.User + var sshPerm *ssh.Permissions + + ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String()) + if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client, + ipAddr, common.ProtocolSSH, isPartialAuth); err == nil { + sshPerm, err = loginUser(&user, method, "", conn) + } + user.Username = conn.User() + updateLoginMetrics(&user, ipAddr, method, err) + if err != nil { + return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err), method, conn.User()) + } + return sshPerm, nil +} + +func updateLoginMetrics(user *dataprovider.User, ip, method string, err error) { + metric.AddLoginAttempt(method) + if err == nil { + plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, common.ProtocolSSH, user.Username, ip, "", err) + common.DelayLogin(nil) + } else { + logger.ConnectionFailedLog(user.Username, ip, method, common.ProtocolSSH, err.Error()) + if method != dataprovider.SSHLoginMethodPublicKey { + // some clients try all available public keys for a user, we + // record failed login key auth only once for session if the + // authentication fails in checkAuthError + event := common.HostEventLoginFailed + logEv := notifier.LogEventTypeLoginFailed + if 
errors.Is(err, util.ErrNotFound) { + event = common.HostEventUserNotFound + logEv = notifier.LogEventTypeLoginNoUser + } + common.AddDefenderEvent(ip, common.ProtocolSSH, event) + plugin.Handler.NotifyLogEvent(logEv, common.ProtocolSSH, user.Username, ip, "", err) + if method != dataprovider.SSHLoginMethodPublicKey { + common.DelayLogin(err) + } + } + } + metric.AddLoginResult(method, err) + dataprovider.ExecutePostLoginHook(user, method, ip, common.ProtocolSSH, err) +} + +type revokedCertificates struct { + filePath string + mu sync.RWMutex + certs map[string]bool +} + +func (r *revokedCertificates) load() error { + if r.filePath == "" { + return nil + } + logger.Debug(logSender, "", "loading revoked user certificate file %q", r.filePath) + info, err := os.Stat(r.filePath) + if err != nil { + return fmt.Errorf("unable to load revoked user certificate file %q: %w", r.filePath, err) + } + maxSize := int64(1048576 * 5) // 5MB + if info.Size() > maxSize { + return fmt.Errorf("unable to load revoked user certificate file %q size too big: %v/%v bytes", + r.filePath, info.Size(), maxSize) + } + content, err := os.ReadFile(r.filePath) + if err != nil { + return fmt.Errorf("unable to read revoked user certificate file %q: %w", r.filePath, err) + } + var certs []string + err = json.Unmarshal(content, &certs) + if err != nil { + return fmt.Errorf("unable to parse revoked user certificate file %q: %w", r.filePath, err) + } + + r.mu.Lock() + defer r.mu.Unlock() + + r.certs = map[string]bool{} + for _, fp := range certs { + r.certs[fp] = true + } + logger.Debug(logSender, "", "revoked user certificate file %q loaded, entries: %v", r.filePath, len(r.certs)) + return nil +} + +func (r *revokedCertificates) isRevoked(fp string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + return r.certs[fp] +} + +// Reload reloads the list of revoked user certificates +func Reload() error { + return revokedCertManager.load() +} + +func algorithmsForKeyFormat(keyFormat string) []string { + 
switch keyFormat { + case ssh.KeyAlgoRSA: + return []string{ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512, ssh.KeyAlgoRSA} + case ssh.CertAlgoRSAv01: + return []string{ssh.CertAlgoRSASHA256v01, ssh.CertAlgoRSASHA512v01, ssh.CertAlgoRSAv01} + default: + return []string{keyFormat} + } +} diff --git a/internal/sftpd/sftpd.go b/internal/sftpd/sftpd.go new file mode 100644 index 00000000..3f69cdde --- /dev/null +++ b/internal/sftpd/sftpd.go @@ -0,0 +1,138 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package sftpd implements the SSH File Transfer Protocol as described in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02. 
+// It uses pkg/sftp library: +// https://github.com/pkg/sftp +package sftpd + +import ( + "strings" + "time" + + "golang.org/x/crypto/ssh" +) + +const ( + logSender = "sftpd" + handshakeTimeout = 2 * time.Minute +) + +var ( + supportedSSHCommands = []string{"scp", "md5sum", "sha1sum", "sha256sum", "sha384sum", "sha512sum", "cd", "pwd", + "sftpgo-copy", "sftpgo-remove"} + defaultSSHCommands = []string{"md5sum", "sha1sum", "sha256sum", "cd", "pwd", "scp"} + sshHashCommands = []string{"md5sum", "sha1sum", "sha256sum", "sha384sum", "sha512sum"} + serviceStatus ServiceStatus + certKeyAlgoNames = map[string]string{ + ssh.CertAlgoRSAv01: ssh.KeyAlgoRSA, + ssh.CertAlgoRSASHA256v01: ssh.KeyAlgoRSASHA256, + ssh.CertAlgoRSASHA512v01: ssh.KeyAlgoRSASHA512, + ssh.InsecureCertAlgoDSAv01: ssh.InsecureKeyAlgoDSA, //nolint:staticcheck + ssh.CertAlgoECDSA256v01: ssh.KeyAlgoECDSA256, + ssh.CertAlgoECDSA384v01: ssh.KeyAlgoECDSA384, + ssh.CertAlgoECDSA521v01: ssh.KeyAlgoECDSA521, + ssh.CertAlgoSKECDSA256v01: ssh.KeyAlgoSKECDSA256, + ssh.CertAlgoED25519v01: ssh.KeyAlgoED25519, + ssh.CertAlgoSKED25519v01: ssh.KeyAlgoSKED25519, + } +) + +type sshSubsystemExitStatus struct { + Status uint32 +} + +type sshSubsystemExecMsg struct { + Command string +} + +type hostCertificate struct { + Certificate *ssh.Certificate + Path string +} + +// HostKey defines the details for a used host key +type HostKey struct { + Path string `json:"path"` + Fingerprint string `json:"fingerprint"` + Algorithms []string `json:"algorithms"` +} + +// GetAlgosAsString returns the host key algorithms as comma separated string +func (h *HostKey) GetAlgosAsString() string { + return strings.Join(h.Algorithms, ", ") +} + +// ServiceStatus defines the service status +type ServiceStatus struct { + IsActive bool `json:"is_active"` + Bindings []Binding `json:"bindings"` + SSHCommands []string `json:"ssh_commands"` + HostKeys []HostKey `json:"host_keys"` + Authentications []string `json:"authentications"` + MACs []string 
`json:"macs"` + KexAlgorithms []string `json:"kex_algorithms"` + Ciphers []string `json:"ciphers"` + PublicKeyAlgorithms []string `json:"public_key_algorithms"` +} + +// GetSSHCommandsAsString returns enabled SSH commands as comma separated string +func (s *ServiceStatus) GetSSHCommandsAsString() string { + return strings.Join(s.SSHCommands, ", ") +} + +// GetSupportedAuthsAsString returns the supported authentications as comma separated string +func (s *ServiceStatus) GetSupportedAuthsAsString() string { + return strings.Join(s.Authentications, ", ") +} + +// GetMACsAsString returns the enabled MAC algorithms as comma separated string +func (s *ServiceStatus) GetMACsAsString() string { + return strings.Join(s.MACs, ", ") +} + +// GetKEXsAsString returns the enabled KEX algorithms as comma separated string +func (s *ServiceStatus) GetKEXsAsString() string { + return strings.Join(s.KexAlgorithms, ", ") +} + +// GetCiphersAsString returns the enabled ciphers as comma separated string +func (s *ServiceStatus) GetCiphersAsString() string { + return strings.Join(s.Ciphers, ", ") +} + +// GetPublicKeysAlgosAsString returns enabled public key authentication +// algorithms as comma separated string +func (s *ServiceStatus) GetPublicKeysAlgosAsString() string { + return strings.Join(s.PublicKeyAlgorithms, ", ") +} + +// GetStatus returns the server status +func GetStatus() ServiceStatus { + return serviceStatus +} + +// GetDefaultSSHCommands returns the SSH commands enabled as default +func GetDefaultSSHCommands() []string { + result := make([]string, len(defaultSSHCommands)) + copy(result, defaultSSHCommands) + return result +} + +// GetSupportedSSHCommands returns the supported SSH commands +func GetSupportedSSHCommands() []string { + result := make([]string, len(supportedSSHCommands)) + copy(result, supportedSSHCommands) + return result +} diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go new file mode 100644 index 00000000..44de3c31 --- /dev/null 
+++ b/internal/sftpd/sftpd_test.go @@ -0,0 +1,11777 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package sftpd_test + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "hash" + "io" + "io/fs" + "math" + "net" + "net/http" + "os" + "os/exec" + "path" + "path/filepath" + "runtime" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/jackc/pgx/v5/stdlib" + _ "github.com/mattn/go-sqlite3" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + + "github.com/pkg/sftp" + "github.com/rs/zerolog" + "github.com/sftpgo/sdk" + sdkkms "github.com/sftpgo/sdk/kms" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpdtest" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/mfa" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + logSender = 
"sftpdTesting" + sftpServerAddr = "127.0.0.1:2022" + sftpSrvAddr2222 = "127.0.0.1:2222" + defaultUsername = "test_user_sftp" + defaultPassword = "test_password" + defaultSFTPUsername = "test_sftpfs_user" + testPubKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0= nicola@p1" + testPubKey1 = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCd60+/j+y8f0tLftihWV1YN9RSahMI9btQMDIMqts/jeNbD8jgoogM3nhF7KxfcaMKURuD47KC4Ey6iAJUJ0sWkSNNxOcIYuvA+5MlspfZDsa8Ag76Fe1vyz72WeHMHMeh/hwFo2TeIeIXg480T1VI6mzfDrVp2GzUx0SS0dMsQBjftXkuVR8YOiOwMCAH2a//M1OrvV7d/NBk6kBN0WnuIBb2jKm15PAA7+jQQG7tzwk2HedNH3jeL5GH31xkSRwlBczRK0xsCQXehAlx6cT/e/s44iJcJTHfpPKoSk6UAhPJYe7Z1QnuoawY9P9jQaxpyeImBZxxUEowhjpj2avBxKdRGBVK8R7EL8tSOeLbhdyWe5Mwc1+foEbq9Zz5j5Kd+hn3Wm1UnsGCrXUUUoZp1jnlNl0NakCto+5KmqnT9cHxaY+ix2RLUWAZyVFlRq71OYux1UHJnEJPiEI1/tr4jFBSL46qhQZv/TfpkfVW8FLz0lErfqu0gQEZnNHr3Fc= nicola@p1" + testPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn +NhAAAAAwEAAQAAAYEAtN449A/nY5O6cSH/9Doa8a3ISU0WZJaHydTaCLuO+dkqtNpnV5mq +zFbKidXAI1eSwVctw9ReVOl1uK6aZF3lbXdOD8W9PXobR9KUUT2qBx5QC4ibfAqDKWymDA +PG9ylzz64hsYBqJr7VNk9kTFEUsDmWzLabLoH42Elnp8mF/lTkWIcpVp0ly/etS08gttXo +XenekJ1vRuxOYWDCEzGPU7kGc920TmM14k7IDdPoOh5+3sRUKedKeOUrVDH1f0n7QjHQsZ +cbshp8tgqzf734zu8cTqNrr+6taptdEOOij1iUL/qYGfzny/hA48tO5+UFUih5W8ftp0+E +NBIDkkGgk2MJ92I7QAXyMVsIABXco+mJT7pQi9tqlODGIQ3AOj0gcA3X/Ib8QX77Ih3TPi +XEh77/P1XiYZOgpp2cRmNH8QbqaL9u898hDvJwIPJPuj2lIltTElH7hjBf5LQfCzrLV7BD 
+10rM7sl4jr+A2q8jl1Ikp+25kainBBZSbrDummT9AAAFgDU/VLk1P1S5AAAAB3NzaC1yc2 +EAAAGBALTeOPQP52OTunEh//Q6GvGtyElNFmSWh8nU2gi7jvnZKrTaZ1eZqsxWyonVwCNX +ksFXLcPUXlTpdbiummRd5W13Tg/FvT16G0fSlFE9qgceUAuIm3wKgylspgwDxvcpc8+uIb +GAaia+1TZPZExRFLA5lsy2my6B+NhJZ6fJhf5U5FiHKVadJcv3rUtPILbV6F3p3pCdb0bs +TmFgwhMxj1O5BnPdtE5jNeJOyA3T6Doeft7EVCnnSnjlK1Qx9X9J+0Ix0LGXG7IafLYKs3 ++9+M7vHE6ja6/urWqbXRDjoo9YlC/6mBn858v4QOPLTuflBVIoeVvH7adPhDQSA5JBoJNj +CfdiO0AF8jFbCAAV3KPpiU+6UIvbapTgxiENwDo9IHAN1/yG/EF++yId0z4lxIe+/z9V4m +GToKadnEZjR/EG6mi/bvPfIQ7ycCDyT7o9pSJbUxJR+4YwX+S0Hws6y1ewQ9dKzO7JeI6/ +gNqvI5dSJKftuZGopwQWUm6w7ppk/QAAAAMBAAEAAAGAHKnC+Nq0XtGAkIFE4N18e6SAwy +0WSWaZqmCzFQM0S2AhJnweOIG/0ZZHjsRzKKauOTmppQk40dgVsejpytIek9R+aH172gxJ +2n4Cx0UwduRU5x8FFQlNc/kl722B0JWfJuB/snOZXv6LJ4o5aObIkozt2w9tVFeAqjYn2S +1UsNOfRHBXGsTYwpRDwFWP56nKo2d2wBBTHDhCy6fb2dLW1fvSi/YspueOGIlHpvlYKi2/ +CWqvs9xVrwcScMtiDoQYq0khhO0efLCxvg/o+W9CLMVM2ms4G1zoSUQKN0oYWWQJyW4+VI +YneWO8UpN0J3ElXKi7bhgAat7dBaM1g9IrAzk153DiEFZNsPxGOgL/+YdQN7zUBx/z7EkI +jyv80RV7fpUXvcq2p+qNl6UVig3VSzRrnsaJkUWu/A0u59ha7ocv6NxDIXjxpIDJme16GF +quiGVBQNnYJymS/vFEbGf6bgf7iRmMCRUMG4nqLA6fPYP9uAtch+CmDfVLZC/fIdC5AAAA +wQCDissV4zH6bfqgxJSuYNk8Vbb+19cF3b7gH1rVlB3zxpCAgcRgMHC+dP1z2NRx7UW9MR +nye6kjpkzZZ0OigLqo7TtEq8uTglD9o6W7mRXqhy5A/ySOmqPL3ernHHQhGuoNODYAHkOU +u2Rh8HXi+VLwKZcLInPOYJvcuLG4DxN8WfeVvlMHwhAOaTNNOtL4XZDHQeIPc4qHmJymmv +sV7GuyQ6yW5C10uoGdxRPd90Bh4z4h2bKfZFjvEBbSBVkqrlAAAADBAN/zNtNayd/dX7Cr +Nb4sZuzCh+CW4BH8GOePZWNCATwBbNXBVb5cR+dmuTqYm+Ekz0VxVQRA1TvKncluJOQpoa +Xj8r0xdIgqkehnfDPMKtYVor06B9Fl1jrXtXU0Vrr6QcBWruSVyK1ZxqcmcNK/+KolVepe +A6vcl/iKaG4U7su166nxLST06M2EgcSVsFJHpKn5+WAXC+X0Gx8kNjWIIb3GpiChdc0xZD +mq02xZthVJrTCVw/e7gfDoB2QRsNV8HwAAAMEAzsCghZVp+0YsYg9oOrw4tEqcbEXEMhwY +0jW8JNL8Spr1Ibp5Dw6bRSk5azARjmJtnMJhJ3oeHfF0eoISqcNuQXGndGQbVM9YzzAzc1 +NbbCNsVroqKlChT5wyPNGS+phi2bPARBno7WSDvshTZ7dAVEP2c9MJW0XwoSevwKlhgSdt +RLFFQ/5nclJSdzPBOmQouC0OBcMFSrYtMeknJ4VvueVvve5HcHFaEsaMc7ABAGaLYaBQOm +iixITGvaNZh/tjAAAACW5pY29sYUBwMQE= +-----END OPENSSH 
PRIVATE KEY-----` + // password protected private key + testPrivateKeyPwd = `-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABAvfwQQcs ++PyMsCLTNFcKiQAAAAEAAAAAEAAAAzAAAAC3NzaC1lZDI1NTE5AAAAILqltfCL7IPuIQ2q ++8w23flfgskjIlKViEwMfjJR4mrbAAAAkHp5xgG8J1XW90M/fT59ZUQht8sZzzP17rEKlX +waYKvLzDxkPK6LFIYs55W1EX1eVt/2Maq+zQ7k2SOUmhPNknsUOlPV2gytX3uIYvXF7u2F +FTBIJuzZ+UQ14wFbraunliE9yye9DajVG1kz2cz2wVgXUbee+gp5NyFVvln+TcTxXwMsWD +qwlk5iw/jQekxThg== +-----END OPENSSH PRIVATE KEY----- +` + testPubKeyPwd = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILqltfCL7IPuIQ2q+8w23flfgskjIlKViEwMfjJR4mrb" + privateKeyPwd = "password" + // test CA user key. + // % ssh-keygen -f ca_user_key + testCAUserKey = "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDF5fcwZHiyixmnE6IlOZJpZhWXoh62gN+yadAA0GJ509SAEaZVLPDP8S5RsE8mUikR3wxynVshxHeqMhrkS+RlNbhSlOXDdNg94yTrq/xF8Z/PgKRInvef74k5i7bAIytza7jERzFJ/ujTEy3537T5k5EYQJ15ZQGuvzynSdv+6o99SjI4jFplyQOZ2QcYbEAmhHm5GgQlIiEFG/RlDtLksOulKZxOY3qPzP0AyQxtZJXn/5vG40aW9LTbwxCJqWlgrkFXMqAAVCbuU5YspwhiXmKt1PsldiXw23oloa4caCKN1jzbFiGuZNXEU2Ebx7JIvjQCPaUYwLjEbkRDxDqN/vmwZqBuKYiuG9Eafx+nFSQkr7QYb5b+mT+/1IFHnmeRGn38731kBqtH7tpzC/t+soRX9p2HtJM+9MYhblO2OqTSPGTlxihWUkyiRBekpAhaiHld16TsG+A3bOJHrojGcX+5g6oGarKGLAMcykL1X+rZqT993Mo6d2Z7q43MOXE= root@p1" + // this is testPubKey signed using testCAUserKey. 
+ // % ssh-keygen -s ca_user_key -I test_user_sftp -n test_user_sftp -V always:forever -O source-address=127.0.0.1 -z 1 /tmp/test.pub + testCertValid = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgm2fil1IIoTixrA2QE9tk7Vbspj/JdEY90e3K2htxYv8AAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAQAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAAAAAAD//////////wAAACMAAAAOc291cmNlLWFkZHJlc3MAAAANAAAACTEyNy4wLjAuMQAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAZcAAAAHc3NoLXJzYQAAAAMBAAEAAAGBAMXl9zBkeLKLGacToiU5kmlmFZeiHraA37Jp0ADQYnnT1IARplUs8M/xLlGwTyZSKRHfDHKdWyHEd6oyGuRL5GU1uFKU5cN02D3jJOur/EXxn8+ApEie95/viTmLtsAjK3NruMRHMUn+6NMTLfnftPmTkRhAnXllAa6/PKdJ2/7qj31KMjiMWmXJA5nZBxhsQCaEebkaBCUiIQUb9GUO0uSw66UpnE5jeo/M/QDJDG1klef/m8bjRpb0tNvDEImpaWCuQVcyoABUJu5TliynCGJeYq3U+yV2JfDbeiWhrhxoIo3WPNsWIa5k1cRTYRvHski+NAI9pRjAuMRuREPEOo3++bBmoG4piK4b0Rp/H6cVJCSvtBhvlv6ZP7/UgUeeZ5EaffzvfWQGq0fu2nML+36yhFf2nYe0kz70xiFuU7Y6pNI8ZOXGKFZSTKJEF6SkCFqIeV3XpOwb4Dds4keuiMZxf7mDqgZqsoYsAxzKQvVf6tmpP33cyjp3Znurjcw5cQAAAZQAAAAMcnNhLXNoYTItNTEyAAABgMNenD7d1J9cF7JWgHA1DYpJ5+5knPtdXbbIgZAznsTxX7qOdptjeeYOuzhQ5Bwklh3fjewiJpGR1rBqbULP+6PAKeYqd7dNLH/upfKBfJweRf5pdXDpoknHaVuIhi4Uu6FeI4NkAzX9nqNKjFAflhJ+7GLGkLNb0UVZxgxr/t0rPmxc5iTg2ZRM+rk1Ij0S5RnGiKVsdAClqNA6h4TDzu5lJVdK5XvuNKBsKVRCvsVBOgJQTtRTLywQaqWR+HBfCiMj8X8EI7atDlJ6XIAlTLOO/f1sM8QPLjT0+t
CHZaGFzg/lKPh3/yFQ4MvddZCptMy1Ll1xvj7cz2ynhGR4PiDfikV3YzgJU/KtL5y+ZB4jU08oPRiOP612PjwZZ+MqYOVOFCKUpMpZQs5UJHME+zNKr4LEj8M0x4YFKIciC+RsrCo4ujbJHmz61ionCadU+fmngvl3C3QjmUdgULBevODeUeIpJv4yFahNxrG1SKRTAa8VVDwJ9GdDTtmXM0mrwA== nicola@p1" + // this is testPubKey signed using a CA user key different from testCAUserKey + testCertUntrustedCA = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg8oFPWpjYy/DowMmtOjWj7Dq20d2N/4Rxzr/c710tOOUAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAAAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAAAAAAD//////////wAAAAAAAACCAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFnBlcm1pdC1wb3J0LWZvcndhcmRpbmcAAAAAAAAACnBlcm1pdC1wdHkAAAAAAAAADnBlcm1pdC11c2VyLXJjAAAAAAAAAAAAAAGXAAAAB3NzaC1yc2EAAAADAQABAAABgQCqgm2gVlptULThfpRR0oCb4SAU3368ULlJaiZOUdq6b94KTfgmu4hTLs7u3a8hyZnVxrKrJ93uAVCwa/HGtgiN96CNC6JUt/QnPqTJ8LQ207RdoE9fbOe6mGwOle5z45+5JFoIi5ZZuD8JsBGodVoa92UepoMyBcNtZyl9q2GP4yT2tIYRon79dtG9AXiDYyhSgePqaObN67dn3ivMc4ZGNukK3cG07cYPic5y0wxX16wSMG3pGQDyUkAu+s4AqpnV9EWHM4PE7SYkCXE99++tUK3QALYqvGZKrLHgzmDKi6n+e14vHYUppAeGDZzwlawiY4oGP9eOW2KUfjZe2ZeL22JTFDYzH2lNV2WtUpeKRGGTSGaUblRVC9hRt6hKCT4c7qpW4kO4kPhE39JpcNPGLql7srNkw+3xXBs8xghMPtH3nOl1Rz2mxnX5tAqmPBb+KiPepnrs+pBRu7i+nCVp8az+iN87STYHy+zPtvTR+QURC8BpNraPOfXwpwM2HaMAAAGUAAAADHJzYS1zaGEyLTUxMgAAAYBnTXCL6tXUO3/Gtsm7lnH9Sulzca8FOoI4Y/4bVYhq4iUNu7Ca452m+Xr9qmCEoIyIJF0LEEcJ8jcS4rfX15e7tNNoknv7JbYXBFAbp1Y/76iqVf89FjfVcbEyH2ToAf7eyQAWzQ3gEKS8mQIkLnAwmCbo
UXC4GRodSIiOXiTt5Q6T02MVc8TxkhmlTA0uVLd5XgstySgE/oLBnL59lhJcwQmdhHL+m480+PaW55CtMuC36RTwk/tOyuWCDC5qMXnoveNB3yu45o3L/U4hoyJ0/5FyP5C8ahgydY0LoRZQG/mNzuraY4433rK+IfkQvZTyaDtcjhxE6hCD5F40aDDh88i6XaKAPikD6fqra6BN8PoPgLuRHzOJuqsMXBWM99s7qPgSnBbmXlekz/1jvvFiCh3zvAFTxFz2KyE4+SbDcCrhpxkNL7idw6r/ZsHaI/2+zhDcgSs5MgBwYLJEj6zUqVdp5XsF8YfC7yNZV5/qy68qY2+zXrC57SPifU2SCPE= nicola@p1" + // this is testPubKey signed as host certificate. + // % ssh-keygen -s ca_user_key -I test_user_sftp -h -n test_user_sftp -V always:forever -z 2 /tmp/test.pub + testHostCert = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg7O2LDpLO1jGTX3SSzEMILoAYJb9DdggyyaUMXUUg3L4AAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAgAAAAIAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAAAAAAD//////////wAAAAAAAAAAAAAAAAAAAZcAAAAHc3NoLXJzYQAAAAMBAAEAAAGBAMXl9zBkeLKLGacToiU5kmlmFZeiHraA37Jp0ADQYnnT1IARplUs8M/xLlGwTyZSKRHfDHKdWyHEd6oyGuRL5GU1uFKU5cN02D3jJOur/EXxn8+ApEie95/viTmLtsAjK3NruMRHMUn+6NMTLfnftPmTkRhAnXllAa6/PKdJ2/7qj31KMjiMWmXJA5nZBxhsQCaEebkaBCUiIQUb9GUO0uSw66UpnE5jeo/M/QDJDG1klef/m8bjRpb0tNvDEImpaWCuQVcyoABUJu5TliynCGJeYq3U+yV2JfDbeiWhrhxoIo3WPNsWIa5k1cRTYRvHski+NAI9pRjAuMRuREPEOo3++bBmoG4piK4b0Rp/H6cVJCSvtBhvlv6ZP7/UgUeeZ5EaffzvfWQGq0fu2nML+36yhFf2nYe0kz70xiFuU7Y6pNI8ZOXGKFZSTKJEF6SkCFqIeV3XpOwb4Dds4keuiMZxf7mDqgZqsoYsAxzKQvVf6tmpP33cyjp3Znurjcw5cQAAAZQAAAAMcnNhLXNoYTItNTEyAAABgHlAWMTTzNrE6pxHlkr09ZXsHgJi8U2p7eifs56DOLgklYIXVUJPEEcnzMKGdpPBnqJsvg3+PccqxgOr5L1dFuOmekQ/dGiHd1enrESiGvJOvDfm0Wsu
BjxEZkSNFWgC9Z2NltToMmRlhVBmb4ZRZtAmi9DAFlJ/BDV4t8ikXZ5oUsigwIeOeLkdPFx3C3x9KZIpuwuAIV4Nfmz75q1NMWY2K1hv682QCKwMYqOWSotz1vWunNmZ0yPRl9UwqAq+nqwO3AApnlrQ3MmHujWQ5tl65PyhfpI8oghhUtB6YrJIAuRXNI/S0+KewCpiYm7nbFBtv9lpecujxAeTibYBrFZ5VODEUm3sdQ/HMdTmkhi6xNgPDQVlvKFqBJAaqoO3tbhKTbEZ865tJMqhyxmZ4XY08wduvSVobrNr7s3rm42/FXLIpung+UOVXonHyeIv9zQ0iJ/bvqKQ1fOsTisZdcD0lz80ZGsjdgJt7yKfUNBnAyVbTXm048E3WsZslJIYCA== nicola@p1" + // this is testPubKey signed using testCAUserKey but with source address 172.16.34.45. + // % ssh-keygen -s ca_user_key -I test_user_sftp -n test_user_sftp -V always:forever -O source-address=172.16.34.45 -z 3 /tmp/test.pub + testCertOtherSourceAddress = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgZ4Su0250R4sQRNYJqJH9VTp9OyeYMAvqY5+lJRI4LzMAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAwAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAAAAAAD//////////wAAACYAAAAOc291cmNlLWFkZHJlc3MAAAAQAAAADDE3Mi4xNi4zNC40NQAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAZcAAAAHc3NoLXJzYQAAAAMBAAEAAAGBAMXl9zBkeLKLGacToiU5kmlmFZeiHraA37Jp0ADQYnnT1IARplUs8M/xLlGwTyZSKRHfDHKdWyHEd6oyGuRL5GU1uFKU5cN02D3jJOur/EXxn8+ApEie95/viTmLtsAjK3NruMRHMUn+6NMTLfnftPmTkRhAnXllAa6/PKdJ2/7qj31KMjiMWmXJA5nZBxhsQCaEebkaBCUiIQUb9GUO0uSw66UpnE5jeo/M/QDJDG1klef/m8bjRpb0tNvDEImpaWCuQVcyoABUJu5TliynCGJeYq3U+yV2JfDbeiWhrhxoIo3WPNsWIa5k1cRTY
RvHski+NAI9pRjAuMRuREPEOo3++bBmoG4piK4b0Rp/H6cVJCSvtBhvlv6ZP7/UgUeeZ5EaffzvfWQGq0fu2nML+36yhFf2nYe0kz70xiFuU7Y6pNI8ZOXGKFZSTKJEF6SkCFqIeV3XpOwb4Dds4keuiMZxf7mDqgZqsoYsAxzKQvVf6tmpP33cyjp3Znurjcw5cQAAAZQAAAAMcnNhLXNoYTItNTEyAAABgL34Q3Li8AJIxZLU+fh4i8ehUWpm31vEvlNjXVCeP70xI+5hWuEt6E1TgKw7GCL5GeD4KehX4vVcNs+A2eOdIUZfDBZIFxn88BN8xcMlDpAMJXgvNqGttiOwcspL6X3N288djUgpCI718lLRdz8nvFqcuYBhSpBm5KL4JzH5o1o8yqv75wMJsH8CJYwGhvWi0OgWOqaLRAk3IUxq3Fbgo/nX11NgrkY/dHIZCkGBFaLJ/M5mfmt/K/5hJAVgLcSxMwB/ryyGaziB9Pv7CwZ9uwnMoRcAvyr96lqgdtLt7LNY8ktugAJ7EnBWjQn4+EJAjjRK2sCaiwpdP37ckDZgmk0OWGEL1yVy8VXgl9QBd7Mb1EVl+lhRyw8jlgBXZOGqpdDrmKCdBYGtU7ujyndLXmxZEAlqhef0yCsyZPTkYH3RhjCYs8ATrEqndEpiL59Nej5uUGQURYijJfHep08AMb4rCxvIZATVm1Ocxu48rGCGolv8jZFJzSJq84HCrVRKMw== nicola@p1" + // this is testPubKey signed using testCAUserKey but expired. + // % ssh-keygen -s ca_user_key -I test_user_sftp -n test_user_sftp -V 20100101123000:20110101123000 -z 4 /tmp/test.pub + testCertExpired = "ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgU3TLP5285k20fBSsdZioI78oJUpaRXFlgx5IPg6gWg8AAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAABAAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAASAAAADnRlc3RfdXNlcl9zZnRwAAAAAEs93LgAAAAATR8QOAAAAAAAAACCAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFnBlcm1pdC1wb3J0LWZvcndhcmRpbmcAAAAAAAAACnBlcm1pdC1wdHkAAAAAAAAADnBlcm1pdC11c2VyLXJjAAAAAAAAAAAAAAGXAAAAB3NzaC1yc2EAAAADAQABAAABgQDF5fcwZHiyixmnE6IlOZJpZhWXoh62gN+yadAA0GJ509SAEaZVLPDP8S5RsE8mUikR3wxynVshxH
eqMhrkS+RlNbhSlOXDdNg94yTrq/xF8Z/PgKRInvef74k5i7bAIytza7jERzFJ/ujTEy3537T5k5EYQJ15ZQGuvzynSdv+6o99SjI4jFplyQOZ2QcYbEAmhHm5GgQlIiEFG/RlDtLksOulKZxOY3qPzP0AyQxtZJXn/5vG40aW9LTbwxCJqWlgrkFXMqAAVCbuU5YspwhiXmKt1PsldiXw23oloa4caCKN1jzbFiGuZNXEU2Ebx7JIvjQCPaUYwLjEbkRDxDqN/vmwZqBuKYiuG9Eafx+nFSQkr7QYb5b+mT+/1IFHnmeRGn38731kBqtH7tpzC/t+soRX9p2HtJM+9MYhblO2OqTSPGTlxihWUkyiRBekpAhaiHld16TsG+A3bOJHrojGcX+5g6oGarKGLAMcykL1X+rZqT993Mo6d2Z7q43MOXEAAAGUAAAADHJzYS1zaGEyLTUxMgAAAYAlH3hhj8J6xLyVpeLZjblzwDKrxp/MWiH30hQ965ExPrPRcoAZFEKVqOYdj6bp4Q19Q4Yzqdobg3aN5ym2iH0b2TlOY0mM901CAoHbNJyiLs+0KiFRoJ+30EDj/hcKusg6v8ln2yixPagAyQu3zyiWo4t1ZuO3I86xchGlptStxSdHAHPFCfpbhcnzWFZctiMqUutl82C4ROWyjOZcRzdVdWHeN5h8wnooXuvba2VkT8QPmjYYyRGuQ3Hg+ySdh8Tel4wiix1Dg5MX7Wjh4hKEx80No9UPy+0iyZMNc07lsWAtrY6NRxGM5CzB6mklscB8TzFrVSnIl9u3bquLfaCrFt/Mft5dR7Yy4jmF+zUhjia6h6giCZ91J+FZ4hV+WkBtPCvTfrGWoA1BgEB/iI2xOq/NPqJ7UXRoMXk/l0NPgRPT2JS1adegqnt4ddr6IlmPyZxaSEvXhanjKdfMlEFYO1wz7ouqpYUozQVy4KXBlzFlNwyD1hI+k4+/A6AIYeI= nicola@p1" + // this is testPubKey signed without a principal + // ssh-keygen -s ca_user_key -I test_user_sftp -V always:forever -O source-address=127.0.0.1 -z 1 /tmp/test.pub + testCertNoPrincipals = "ssh-rsa-cert-v01@openssh.com 
AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg2Bx0s8nafJtriqoBuQfbFByhdQMkjDIZhV90JZSGN8AAAAADAQABAAABgQC03jj0D+djk7pxIf/0OhrxrchJTRZklofJ1NoIu4752Sq02mdXmarMVsqJ1cAjV5LBVy3D1F5U6XW4rppkXeVtd04Pxb09ehtH0pRRPaoHHlALiJt8CoMpbKYMA8b3KXPPriGxgGomvtU2T2RMURSwOZbMtpsugfjYSWenyYX+VORYhylWnSXL961LTyC21ehd6d6QnW9G7E5hYMITMY9TuQZz3bROYzXiTsgN0+g6Hn7exFQp50p45StUMfV/SftCMdCxlxuyGny2CrN/vfjO7xxOo2uv7q1qm10Q46KPWJQv+pgZ/OfL+EDjy07n5QVSKHlbx+2nT4Q0EgOSQaCTYwn3YjtABfIxWwgAFdyj6YlPulCL22qU4MYhDcA6PSBwDdf8hvxBfvsiHdM+JcSHvv8/VeJhk6CmnZxGY0fxBupov27z3yEO8nAg8k+6PaUiW1MSUfuGMF/ktB8LOstXsEPXSszuyXiOv4DaryOXUiSn7bmRqKcEFlJusO6aZP0AAAAAAAAAAQAAAAEAAAAOdGVzdF91c2VyX3NmdHAAAAAAAAAAAAAAAAD//////////wAAACMAAAAOc291cmNlLWFkZHJlc3MAAAANAAAACTEyNy4wLjAuMQAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAZcAAAAHc3NoLXJzYQAAAAMBAAEAAAGBAMXl9zBkeLKLGacToiU5kmlmFZeiHraA37Jp0ADQYnnT1IARplUs8M/xLlGwTyZSKRHfDHKdWyHEd6oyGuRL5GU1uFKU5cN02D3jJOur/EXxn8+ApEie95/viTmLtsAjK3NruMRHMUn+6NMTLfnftPmTkRhAnXllAa6/PKdJ2/7qj31KMjiMWmXJA5nZBxhsQCaEebkaBCUiIQUb9GUO0uSw66UpnE5jeo/M/QDJDG1klef/m8bjRpb0tNvDEImpaWCuQVcyoABUJu5TliynCGJeYq3U+yV2JfDbeiWhrhxoIo3WPNsWIa5k1cRTYRvHski+NAI9pRjAuMRuREPEOo3++bBmoG4piK4b0Rp/H6cVJCSvtBhvlv6ZP7/UgUeeZ5EaffzvfWQGq0fu2nML+36yhFf2nYe0kz70xiFuU7Y6pNI8ZOXGKFZSTKJEF6SkCFqIeV3XpOwb4Dds4keuiMZxf7mDqgZqsoYsAxzKQvVf6tmpP33cyjp3Znurjcw5cQAAAZQAAAAMcnNhLXNoYTItNTEyAAABgHgax/++NA5YZXDHH180BcQtDBve8Vc+XJzqQUe8xBiqd+KJnas6He7vW62qMaAfu63i0Uycj2Djfjy5dyx1GB9wup8YuP5mXlmJTx+7UPPjwbfrZWtk8iJ7KhFAwjh0KRZD4uIvoeecK8QE9zh64k2LNVqlWbFTdoPulRC29cGcXDpMU2eToFEyWbceHOZyyifXf98ZMZbaQzWzwSZ5rFucJ1b0aeT6aAJWB+Dq7mIQWf/jCWr8kNaeCzMKJsFQkQEfmHls29ChV92sNRhngUDxll0Ir0wpPea1fFEBnUhLRTLC8GhDDbWAzsZtXqx9fjoAkb/gwsU6TGxevuOMxEABjDA9PyJiTXJI9oTUCwDIAUVVFLsCEum3o/BblngXajUGibaif5ZSKBocpP70oTeAngQYB7r1/vquQzGsGFhTN4FUXLSpLu9Zqi1z58/qa7SgKSfNp98X/4zrhltAX73ZEvg0NUMv2HwlwlqHdpF3FYolAxIn
p7c2jBTncQ2l3w== nicola@p1" + osWindows = "windows" + testFileName = "test_file_sftp.dat" + testDLFileName = "test_download_sftp.dat" +) + +var ( + configDir = filepath.Join(".", "..", "..") + allPerms = []string{dataprovider.PermAny} + homeBasePath string + scpPath string + scpForce bool + gitPath string + sshPath string + hookCmdPath string + pubKeyPath string + privateKeyPath string + trustedCAUserKey string + revokeUserCerts string + gitWrapPath string + extAuthPath string + keyIntAuthPath string + preLoginPath string + postConnectPath string + preDownloadPath string + preUploadPath string + checkPwdPath string + logFilePath string + hostKeyFPs []string +) + +func TestMain(m *testing.M) { + logFilePath = filepath.Join(configDir, "sftpgo_sftpd_test.log") + loginBannerFileName := "login_banner" + loginBannerFile := filepath.Join(configDir, loginBannerFileName) + logger.InitLogger(logFilePath, 10, 1, 28, false, false, zerolog.DebugLevel) + err := os.WriteFile(loginBannerFile, []byte("simple login banner\n"), os.ModePerm) + if err != nil { + logger.ErrorToConsole("error creating login banner: %v", err) + } + os.Setenv("SFTPGO_COMMON__UPLOAD_MODE", "2") + os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") + os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1") + os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin") + os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password") + err = config.LoadConfig(configDir, "") + if err != nil { + logger.ErrorToConsole("error loading configuration: %v", err) + os.Exit(1) + } + providerConf := config.GetProviderConf() + logger.InfoToConsole("Starting SFTPD tests, provider: %v", providerConf.Driver) + + commonConf := config.GetCommonConfig() + homeBasePath = os.TempDir() + checkSystemCommands() + var scriptArgs string + if runtime.GOOS == osWindows { + scriptArgs = "%*" + } else { + commonConf.Actions.ExecuteOn = []string{"download", "upload", "rename", "delete", "ssh_cmd", + "pre-download", "pre-upload"} + 
commonConf.Actions.Hook = hookCmdPath + scriptArgs = "$@" + } + + err = dataprovider.Initialize(providerConf, configDir, true) + if err != nil { + logger.ErrorToConsole("error initializing data provider: %v", err) + os.Exit(1) + } + + err = dataprovider.UpdateConfigs(nil, "", "", "") + if err != nil { + logger.ErrorToConsole("error resetting configs: %v", err) + os.Exit(1) + } + + err = common.Initialize(commonConf, 0) + if err != nil { + logger.WarnToConsole("error initializing common: %v", err) + os.Exit(1) + } + + httpConfig := config.GetHTTPConfig() + httpConfig.Initialize(configDir) //nolint:errcheck + kmsConfig := config.GetKMSConfig() + err = kmsConfig.Initialize() + if err != nil { + logger.ErrorToConsole("error initializing kms: %v", err) + os.Exit(1) + } + mfaConfig := config.GetMFAConfig() + err = mfaConfig.Initialize() + if err != nil { + logger.ErrorToConsole("error initializing MFA: %v", err) + os.Exit(1) + } + + sftpdConf := config.GetSFTPDConfig() + httpdConf := config.GetHTTPDConfig() + sftpdConf.Bindings = []sftpd.Binding{ + { + Port: 2022, + ApplyProxyConfig: true, + }, + } + sftpdConf.KexAlgorithms = []string{"curve25519-sha256@libssh.org", ssh.KeyExchangeECDHP256, + ssh.KeyExchangeECDHP384} + sftpdConf.Ciphers = []string{ssh.CipherChaCha20Poly1305, ssh.CipherAES128GCM, + ssh.CipherAES256CTR} + sftpdConf.LoginBannerFile = loginBannerFileName + // we need to test all supported ssh commands + sftpdConf.EnabledSSHCommands = []string{"*"} + + keyIntAuthPath = filepath.Join(homeBasePath, "keyintauth.sh") + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), os.ModePerm) + if err != nil { + logger.ErrorToConsole("error writing keyboard interactive script: %v", err) + os.Exit(1) + } + sftpdConf.KeyboardInteractiveAuthentication = true + sftpdConf.KeyboardInteractiveHook = keyIntAuthPath + + createInitialFiles(scriptArgs) + sftpdConf.TrustedUserCAKeys = append(sftpdConf.TrustedUserCAKeys, 
trustedCAUserKey) + sftpdConf.RevokedUserCertsFile = revokeUserCerts + + go func(cfg sftpd.Configuration) { + logger.Debug(logSender, "", "initializing SFTP server with config %+v", sftpdConf) + if err := cfg.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start SFTP server: %v", err) + os.Exit(1) + } + }(sftpdConf) + + go func() { + if err := httpdConf.Initialize(configDir, 0); err != nil { + logger.ErrorToConsole("could not start HTTP server: %v", err) + os.Exit(1) + } + }() + + waitTCPListening(sftpdConf.Bindings[0].GetAddress()) + waitTCPListening(httpdConf.Bindings[0].GetAddress()) + + sftpdConf.Bindings = []sftpd.Binding{ + { + Port: 2222, + ApplyProxyConfig: true, + }, + } + sftpdConf.PasswordAuthentication = false + common.Config.ProxyProtocol = 1 + go func(cfg sftpd.Configuration) { + logger.Debug(logSender, "", "initializing SFTP server with config %+v and proxy protocol %v", + sftpdConf, common.Config.ProxyProtocol) + if err := cfg.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start SFTP server with proxy protocol 1: %v", err) + os.Exit(1) + } + }(sftpdConf) + + waitTCPListening(sftpdConf.Bindings[0].GetAddress()) + + sftpdConf.Bindings = []sftpd.Binding{ + { + Port: 2226, + ApplyProxyConfig: false, + }, + } + sftpdConf.PasswordAuthentication = true + go func(cfg sftpd.Configuration) { + logger.Debug(logSender, "", "initializing SFTP server with config %+v and proxy protocol %v", + cfg, common.Config.ProxyProtocol) + if err := cfg.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start SFTP server with proxy protocol 2: %v", err) + os.Exit(1) + } + }(sftpdConf) + + waitTCPListening(sftpdConf.Bindings[0].GetAddress()) + + sftpdConf.Bindings = []sftpd.Binding{ + { + Port: 2224, + ApplyProxyConfig: true, + }, + } + sftpdConf.PasswordAuthentication = true + common.Config.ProxyProtocol = 2 + go func() { + logger.Debug(logSender, "", "initializing SFTP server with config %+v and proxy 
protocol %v", + sftpdConf, common.Config.ProxyProtocol) + if err := sftpdConf.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start SFTP server with proxy protocol 2: %v", err) + os.Exit(1) + } + }() + + waitTCPListening(sftpdConf.Bindings[0].GetAddress()) + getHostKeysFingerprints(sftpdConf.HostKeys) + startHTTPFs() + + exitCode := m.Run() + os.Remove(logFilePath) + os.Remove(loginBannerFile) + os.Remove(pubKeyPath) + os.Remove(privateKeyPath) + os.Remove(trustedCAUserKey) + os.Remove(revokeUserCerts) + os.Remove(gitWrapPath) + os.Remove(extAuthPath) + os.Remove(preLoginPath) + os.Remove(postConnectPath) + os.Remove(preDownloadPath) + os.Remove(preUploadPath) + os.Remove(keyIntAuthPath) + os.Remove(checkPwdPath) + os.Exit(exitCode) +} + +func TestInitialization(t *testing.T) { + err := config.LoadConfig(configDir, "") + assert.NoError(t, err) + sftpdConf := config.GetSFTPDConfig() + sftpdConf.Bindings = []sftpd.Binding{ + { + Port: 2022, + ApplyProxyConfig: true, + }, + { + Port: 0, + }, + } + sftpdConf.LoginBannerFile = "invalid_file" + sftpdConf.EnabledSSHCommands = append(sftpdConf.EnabledSSHCommands, "ls") + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + sftpdConf.KeyboardInteractiveAuthentication = true + sftpdConf.KeyboardInteractiveHook = "invalid_file" + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + sftpdConf.KeyboardInteractiveAuthentication = true + sftpdConf.KeyboardInteractiveHook = filepath.Join(homeBasePath, "invalid_file") + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + sftpdConf.KeyboardInteractiveAuthentication = false + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + sftpdConf.Bindings = []sftpd.Binding{ + { + Port: 4444, + ApplyProxyConfig: true, + }, + } + common.Config.ProxyProtocol = 1 + assert.True(t, sftpdConf.Bindings[0].HasProxy()) + common.Config.ProxyProtocol = 0 + sftpdConf.HostKeys = []string{"missing key"} + err = sftpdConf.Initialize(configDir) + 
assert.Error(t, err) + sftpdConf.HostKeys = nil + sftpdConf.TrustedUserCAKeys = []string{"missing ca key"} + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + sftpdConf.Bindings = nil + err = sftpdConf.Initialize(configDir) + assert.EqualError(t, err, common.ErrNoBinding.Error()) + sftpdConf = config.GetSFTPDConfig() + sftpdConf.Ciphers = []string{"not a cipher"} + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unsupported cipher") + } + sftpdConf.Ciphers = nil + sftpdConf.MACs = []string{"not a MAC"} + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unsupported MAC algorithm") + } + sftpdConf.MACs = nil + sftpdConf.KexAlgorithms = []string{"diffie-hellman-group-exchange-sha1", "not a KEX"} + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unsupported key-exchange algorithm") + } + sftpdConf.KexAlgorithms = nil + sftpdConf.PublicKeyAlgorithms = []string{"not a pub key algo"} + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unsupported public key authentication algorithm") + } + sftpdConf.PublicKeyAlgorithms = nil + sftpdConf.HostKeyAlgorithms = []string{"not a host key algo"} + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unsupported host key algorithm") + } + sftpdConf.HostKeyAlgorithms = nil + sftpdConf.HostCertificates = []string{"missing file"} + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to load host certificate") + } + sftpdConf.HostCertificates = []string{"."} + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + hostCertPath := filepath.Join(os.TempDir(), "host_cert.pub") + err = os.WriteFile(hostCertPath, []byte(testCertValid), 0600) + assert.NoError(t, err) + sftpdConf.HostKeys = 
[]string{privateKeyPath} + sftpdConf.HostCertificates = []string{hostCertPath} + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is not an host certificate") + } + err = os.WriteFile(hostCertPath, []byte(testPubKey), 0600) + assert.NoError(t, err) + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is not an SSH certificate") + } + err = os.WriteFile(hostCertPath, []byte("abc"), 0600) + assert.NoError(t, err) + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to parse host certificate") + } + err = os.WriteFile(hostCertPath, []byte(testHostCert), 0600) + assert.NoError(t, err) + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + + err = os.Remove(hostCertPath) + assert.NoError(t, err) + sftpdConf.HostKeys = nil + sftpdConf.HostCertificates = nil + sftpdConf.OPKSSHPath = "relative path" + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + sftpdConf.OPKSSHPath = filepath.Join(os.TempDir(), "missing path") + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + sftpdConf.OPKSSHChecksum = "invalid checksum" + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + sftpdConf.OPKSSHPath = "" + sftpdConf.OPKSSHChecksum = "" + sftpdConf.RevokedUserCertsFile = "." 
+ err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + sftpdConf.RevokedUserCertsFile = "a missing file" + err = sftpdConf.Initialize(configDir) + assert.ErrorIs(t, err, os.ErrNotExist) + + err = createTestFile(revokeUserCerts, 10*1024*1024) + assert.NoError(t, err) + sftpdConf.RevokedUserCertsFile = revokeUserCerts + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + + err = os.WriteFile(revokeUserCerts, []byte(`[]`), 0644) + assert.NoError(t, err) + err = sftpdConf.Initialize(configDir) + assert.Error(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = sftpdConf.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to load configs from provider") + } + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestBasicSFTPHandling(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.QuotaSize = 6553600 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + expectedQuotaSize := user.UsedQuotaSize + testFileSize + expectedQuotaFiles := user.UsedQuotaFiles + 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join("/missing_dir", testFileName), testFileSize, client) + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), user.FirstUpload) + assert.Equal(t, int64(0), user.FirstDownload) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) 
+ assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.FirstUpload, int64(0)) + assert.Equal(t, int64(0), user.FirstDownload) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + err = client.Remove(testFileName) + assert.NoError(t, err) + _, err = client.Lstat(testFileName) + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize-testFileSize, user.UsedQuotaSize) + assert.Greater(t, user.FirstUpload, int64(0)) + assert.Greater(t, user.FirstDownload, int64(0)) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + u.Username = "missing user" + _, _, err = getSftpClient(u, false) + assert.Error(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + status := sftpd.GetStatus() + assert.True(t, status.IsActive) + sshCommands := status.GetSSHCommandsAsString() + assert.NotEmpty(t, sshCommands) + sshAuths := status.GetSupportedAuthsAsString() + assert.NotEmpty(t, sshAuths) + assert.NotEmpty(t, status.HostKeys[0].GetAlgosAsString()) + assert.NotEmpty(t, status.GetMACsAsString()) + assert.NotEmpty(t, status.GetKEXsAsString()) + assert.NotEmpty(t, status.GetCiphersAsString()) + assert.NotEmpty(t, status.GetPublicKeysAlgosAsString()) +} + +func TestBasicSFTPFsHandling(t *testing.T) { + usePubKey := true + baseUser, _, err := 
httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + u := getTestSFTPUser(usePubKey) + u.QuotaSize = 6553600 + u.FsConfig.SFTPConfig.DisableCouncurrentReads = true + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + testLinkName := testFileName + ".link" + testLinkToLinkName := testLinkName + ".link" + expectedQuotaSize := testFileSize + expectedQuotaFiles := 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + err = client.Symlink(testFileName, testLinkName) + assert.NoError(t, err) + info, err := client.Lstat(testLinkName) + if assert.NoError(t, err) { + assert.True(t, info.Mode()&os.ModeSymlink != 0) + } + info, err = client.Stat(testLinkName) + if assert.NoError(t, err) { + assert.True(t, info.Mode()&os.ModeSymlink == 0) + } + val, err := client.ReadLink(testLinkName) + if assert.NoError(t, err) { + assert.Equal(t, path.Join("/", testFileName), val) + } + linkDir := "linkDir" + err = client.Mkdir(linkDir) + assert.NoError(t, err) + linkToLinkPath := path.Join(linkDir, testLinkToLinkName) + err = client.Symlink(path.Join("/", testLinkName), linkToLinkPath) + assert.NoError(t, err) + info, err = 
client.Lstat(linkToLinkPath) + if assert.NoError(t, err) { + assert.True(t, info.Mode()&os.ModeSymlink != 0) + } + info, err = client.Stat(linkToLinkPath) + if assert.NoError(t, err) { + assert.True(t, info.Mode()&os.ModeSymlink == 0) + } + val, err = client.ReadLink(linkToLinkPath) + if assert.NoError(t, err) { + assert.Equal(t, path.Join("/", testLinkName), val) + } + err = client.Remove(linkToLinkPath) + assert.NoError(t, err) + err = client.RemoveDirectory(linkDir) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.NoError(t, err) + _, err = client.Lstat(testFileName) + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + // now overwrite the symlink + err = sftpUploadFile(testFilePath, testLinkName, testFileSize, client) + assert.NoError(t, err) + contents, err := client.ReadDir("/") + if assert.NoError(t, err) { + assert.Len(t, contents, 1) + assert.Equal(t, testFileSize, contents[0].Size()) + assert.Equal(t, testLinkName, contents[0].Name()) + assert.False(t, contents[0].IsDir()) + assert.True(t, contents[0].Mode().IsRegular()) + } + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + + stat, err := client.StatVFS("/") + assert.NoError(t, err) + assert.Equal(t, uint64(u.QuotaSize/4096), stat.Blocks) + assert.Equal(t, uint64((u.QuotaSize-testFileSize)/4096), stat.Bfree) + assert.Equal(t, uint64(1), stat.Files-stat.Ffree) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = 
os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSFTPFsPasswordProtectedPrivateKey(t *testing.T) { + usePubKey := false + u := getTestUser(true) + u.PublicKeys = []string{testPubKeyPwd} + baseUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser(usePubKey) + u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(testPrivateKeyPwd) + u.FsConfig.SFTPConfig.KeyPassphrase = kms.NewPlainSecret(privateKeyPwd) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + // update the user, the key must be preserved + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSFTPFsEscapeHomeDir(t *testing.T) { + usePubKey := true + baseUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + u := getTestSFTPUser(usePubKey) + sftpPrefix := "/prefix" + u.FsConfig.SFTPConfig.Prefix = sftpPrefix + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + dirName := "dir" + linkName := "link" + err := client.Mkdir(dirName) + assert.NoError(t, err) + err = os.Symlink(baseUser.GetHomeDir(), 
filepath.Join(baseUser.GetHomeDir(), sftpPrefix, dirName, linkName)) + assert.NoError(t, err) + err = os.Symlink(filepath.Join(baseUser.GetHomeDir(), sftpPrefix, dirName, linkName), + filepath.Join(baseUser.GetHomeDir(), sftpPrefix, linkName)) + assert.NoError(t, err) + // linkName points to a link inside the home dir and this link points to a dir outside the home dir + _, err = client.ReadLink(linkName) + assert.ErrorIs(t, err, os.ErrPermission) + _, err = client.RealPath(linkName) + assert.ErrorIs(t, err, os.ErrPermission) + _, err = client.ReadDir(linkName) + assert.ErrorIs(t, err, os.ErrPermission) + _, err = client.ReadDir(path.Join(dirName, linkName)) + assert.ErrorIs(t, err, os.ErrPermission) + _, err = client.ReadDir("/") + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestReadDirLongNames(t *testing.T) { + usePubKey := true + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + numFiles := 1000 + for i := 0; i < 1000; i++ { + fPath := filepath.Join(user.GetHomeDir(), hex.EncodeToString(util.GenerateRandomBytes(127))) + err = os.WriteFile(fPath, util.GenerateRandomBytes(30), 0666) + assert.NoError(t, err) + } + + entries, err := client.ReadDir("/") + assert.NoError(t, err) + assert.Len(t, entries, numFiles) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestGroupSettingsOverride(t *testing.T) { + usePubKey := true + g := getTestGroup() + g.UserSettings.Filters.StartDirectory = "/%username%" + group, _, err := httpdtest.AddGroup(g, 
		http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUser(usePubKey)
	u.Groups = []sdk.GroupMapping{
		{
			Name: group.Name,
			Type: sdk.GroupTypePrimary,
		},
	}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		// the group start directory, with %username% expanded, must be
		// the initial working directory
		currentDir, err := client.Getwd()
		assert.NoError(t, err)
		assert.Equal(t, "/"+user.Username, currentDir)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveGroup(group, http.StatusOK)
	assert.NoError(t, err)
}

// TestStartDirectory verifies that a configured start directory, including
// special characters, is used as the initial working directory and that
// relative paths are resolved against it for both local and SFTP fs users.
func TestStartDirectory(t *testing.T) {
	usePubKey := false
	startDir := "/st@ rt/dir"
	u := getTestUser(usePubKey)
	u.Filters.StartDirectory = startDir
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	u.Filters.StartDirectory = startDir
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()

			currentDir, err := client.Getwd()
			assert.NoError(t, err)
			assert.Equal(t, startDir, currentDir)

			entries, err := client.ReadDir(".")
			assert.NoError(t, err)
			assert.Len(t, entries, 0)

			testFilePath := filepath.Join(homeBasePath, testFileName)
			testFileSize := int64(65535)
			err = createTestFile(testFilePath, testFileSize)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
			assert.NoError(t, err)
			localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
			err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client)
			assert.NoError(t, err)
			_, err = client.Stat(testFileName)
			assert.NoError(t, err)
			err = client.Rename(testFileName, testFileName+"_rename")
			assert.NoError(t, err)

			entries, err = client.ReadDir(".")
			assert.NoError(t, err)
			assert.Len(t, entries, 1)

			// ".." is resolved relative to the start directory
			currentDir, err = client.RealPath("..")
			assert.NoError(t, err)
			assert.Equal(t, path.Dir(startDir), currentDir)

			currentDir, err = client.RealPath("../..")
			assert.NoError(t, err)
			assert.Equal(t, "/", currentDir)

			// walking above the virtual root must stop at "/"
			currentDir, err = client.RealPath("../../..")
			assert.NoError(t, err)
			assert.Equal(t, "/", currentDir)

			err = client.Remove(testFileName + "_rename")
			assert.NoError(t, err)

			err = os.Remove(testFilePath)
			assert.NoError(t, err)
			err = os.Remove(localDownloadPath)
			assert.NoError(t, err)
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestLoginNonExistentUser ensures that login fails for an unknown user.
func TestLoginNonExistentUser(t *testing.T) {
	usePubKey := true
	user := getTestUser(usePubKey)
	_, _, err := getSftpClient(user, usePubKey)
	assert.Error(t, err)
}

// TestRateLimiter verifies that the SSH rate limiter rejects connections
// exceeding the configured average/burst.
func TestRateLimiter(t *testing.T) {
	oldConfig := config.GetCommonConfig()

	cfg := config.GetCommonConfig()
	cfg.RateLimitersConfig = []common.RateLimiterConfig{
		{
			Average:   1,
			Period:    1000,
			Burst:     1,
			Type:      1,
			Protocols: []string{common.ProtocolSSH},
		},
	}

	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)

	usePubKey := false
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = checkBasicSFTP(client)
		assert.NoError(t, err)
	}
	// with burst 1 an immediate second connection must be rejected
	_, _, err = getSftpClient(user, usePubKey)
	assert.Error(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err =
		os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	// restore the original common configuration
	err = common.Initialize(oldConfig, 0)
	assert.NoError(t, err)
}

// TestDefender verifies that failed logins increase the host score until the
// offending host is banned and even valid credentials are refused.
func TestDefender(t *testing.T) {
	oldConfig := config.GetCommonConfig()

	cfg := config.GetCommonConfig()
	cfg.DefenderConfig.Enabled = true
	cfg.DefenderConfig.Threshold = 3
	cfg.DefenderConfig.ScoreLimitExceeded = 2
	cfg.DefenderConfig.ScoreValid = 1

	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)

	usePubKey := false
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		err = checkBasicSFTP(client)
		assert.NoError(t, err)
	}

	// a single failed login gives score 1, no ban yet
	user.Password = "wrong_pwd"
	_, _, err = getSftpClient(user, usePubKey)
	assert.Error(t, err)
	hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		host := hosts[0]
		assert.Empty(t, host.GetBanTime())
		assert.Equal(t, 1, host.Score)
	}

	// more failed logins to reach the ban threshold
	for i := 0; i < 2; i++ {
		_, _, err = getSftpClient(user, usePubKey)
		assert.Error(t, err)
	}

	// the host is now banned: the correct password must fail too
	user.Password = defaultPassword
	_, _, err = getSftpClient(user, usePubKey)
	assert.Error(t, err)

	err = dataprovider.DeleteUser(user.Username, "", "", "")
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	err = common.Initialize(oldConfig, 0)
	assert.NoError(t, err)
}

// TestOpenReadWrite checks reading back data through the same handle opened
// with O_RDWR, both on a fresh file and after truncation.
func TestOpenReadWrite(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	u.QuotaSize = 6553600
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	u.QuotaSize = 6553600
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()
			sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
			if assert.NoError(t, err) {
				testData := []byte("sample test data")
				n, err := sftpFile.Write(testData)
				assert.NoError(t, err)
				assert.Equal(t, len(testData), n)
				// read back from offset 1 on the same handle
				buffer := make([]byte, 128)
				n, err = sftpFile.ReadAt(buffer, 1)
				assert.EqualError(t, err, io.EOF.Error())
				assert.Equal(t, len(testData)-1, n)
				assert.Equal(t, testData[1:], buffer[:n])
				err = sftpFile.Close()
				assert.NoError(t, err)
			}
			// reopen with O_TRUNC and repeat with different data
			sftpFile, err = client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
			if assert.NoError(t, err) {
				testData := []byte("new test data")
				n, err := sftpFile.Write(testData)
				assert.NoError(t, err)
				assert.Equal(t, len(testData), n)
				buffer := make([]byte, 128)
				n, err = sftpFile.ReadAt(buffer, 1)
				assert.EqualError(t, err, io.EOF.Error())
				assert.Equal(t, len(testData)-1, n)
				assert.Equal(t, testData[1:], buffer[:n])
				err = sftpFile.Close()
				assert.NoError(t, err)
			}
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestOpenReadWritePerm checks that reads are denied inside a directory
// without the read permission even on a handle opened read-write.
func TestOpenReadWritePerm(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	// we cannot read inside "/sub", rename is needed otherwise the atomic upload will fail for the sftpfs user
	u.Permissions["/sub"] = []string{dataprovider.PermUpload, dataprovider.PermListItems, dataprovider.PermRename}
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	u.Permissions["/sub"] = []string{dataprovider.PermUpload, dataprovider.PermListItems}
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()
			err = client.Mkdir("sub")
			assert.NoError(t, err)
			sftpFileName := path.Join("sub", "file.txt")
			sftpFile, err := client.OpenFile(sftpFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
			if assert.NoError(t, err) {
				testData := []byte("test data")
				n, err := sftpFile.Write(testData)
				assert.NoError(t, err)
				assert.Equal(t, len(testData), n)
				// reading must fail: "/sub" has no download permission
				buffer := make([]byte, 128)
				_, err = sftpFile.ReadAt(buffer, 1)
				if assert.Error(t, err) {
					assert.Contains(t, strings.ToLower(err.Error()), "permission denied")
				}
				err = sftpFile.Close()
				assert.NoError(t, err)
			}
			// recreate the local user so the next loop iteration (sftpfs
			// user) starts from a clean state
			if user.Username == defaultUsername {
				err = os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
				_, err = httpdtest.RemoveUser(user, http.StatusOK)
				assert.NoError(t, err)
				user.Password = defaultPassword
				user.ID = 0
				user.CreatedAt = 0
				_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
				assert.NoError(t, err, string(resp))
			}
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestConcurrency runs many concurrent logins/uploads and checks the
// connection and session counters while the transfers are in progress and
// after they complete.
func TestConcurrency(t *testing.T) {
	oldValue := common.Config.MaxPerHostConnections
	common.Config.MaxPerHostConnections = 0

	usePubKey := true
	numLogins := 50
	u := getTestUser(usePubKey)
	u.QuotaFiles = numLogins + 1
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	var wg sync.WaitGroup
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(262144)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)

	var closedConns atomic.Int32
	for i := 0; i < numLogins; i++ {
		wg.Add(1)
		go func(counter int) {
			defer wg.Done()
			defer closedConns.Add(1)

			conn, client, err := getSftpClient(user, usePubKey)
			if assert.NoError(t, err) {
				err = sftpUploadFile(testFilePath, testFileName+strconv.Itoa(counter), testFileSize, client)
				assert.NoError(t, err)
				assert.Greater(t, common.Connections.GetActiveSessions(defaultUsername), 0)
				client.Close()
				conn.Close()
			}
		}(i)
	}

	// watcher goroutine: sample the stats until every upload goroutine is done
	wg.Add(1)
	go func() {
		defer wg.Done()

		maxConns := 0
		maxSessions := 0
		for {
			servedReqs := closedConns.Load()
			if servedReqs > 0 {
				stats := common.Connections.GetStats("")
				if len(stats) > maxConns {
					maxConns = len(stats)
				}
				activeSessions := common.Connections.GetActiveSessions(defaultUsername)
				if activeSessions > maxSessions {
					maxSessions = activeSessions
				}
			}
			if servedReqs >= int32(numLogins) {
				break
			}
			time.Sleep(1 * time.Millisecond)
		}
		assert.Greater(t, maxConns, 0)
		assert.Greater(t, maxSessions, 0)
	}()

	wg.Wait()

	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		files, err := client.ReadDir(".")
		assert.NoError(t, err)
		assert.Len(t, files, numLogins)
		client.Close()
		conn.Close()
	}

	// all sessions and connections must eventually be released
	assert.Eventually(t, func() bool {
		return common.Connections.GetActiveSessions(defaultUsername) == 0
	}, 1*time.Second, 50*time.Millisecond)

	assert.Eventually(t, func() bool {
		return len(common.Connections.GetStats("")) == 0
	}, 1*time.Second, 50*time.Millisecond)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())

	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	common.Config.MaxPerHostConnections = oldValue
}

// TestProxyProtocol checks logins through the listener that expects the proxy
// protocol header and through one that rejects it.
func TestProxyProtocol(t *testing.T) {
	usePubKey := true
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	// remove the home dir to test auto creation
	err = os.RemoveAll(user.HomeDir)
	assert.NoError(t, err)
	conn, client, err :=
		getSftpClientWithAddr(user, usePubKey, sftpSrvAddr2222)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		assert.NoError(t, checkBasicSFTP(client))
	}
	// this listener must refuse the connection
	_, _, err = getSftpClientWithAddr(user, usePubKey, "127.0.0.1:2224")
	assert.Error(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestRealPath verifies realpath resolution for paths above the root, for
// symlinks inside the home dir and for symlinks escaping the home dir.
func TestRealPath(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			p, err := client.RealPath("../..")
			assert.NoError(t, err)
			assert.Equal(t, "/", p)
			p, err = client.RealPath("../test")
			assert.NoError(t, err)
			assert.Equal(t, "/test", p)
			subdir := "testsubdir"
			err = client.Mkdir(subdir)
			assert.NoError(t, err)
			linkName := testFileName + "_link"
			err = client.Symlink(path.Join("/", testFileName), path.Join(subdir, linkName))
			assert.NoError(t, err)
			// a link to a not yet existing target still resolves
			p, err = client.RealPath(path.Join(subdir, linkName))
			assert.NoError(t, err)
			assert.Equal(t, path.Join("/", testFileName), p)
			// an existing path
			sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
			if assert.NoError(t, err) {
				testData := []byte("hello world")
				n, err := sftpFile.WriteAt(testData, 0)
				assert.NoError(t, err)
				assert.Equal(t, len(testData), n)
			}
			p, err = client.RealPath(path.Join(subdir, linkName))
			assert.NoError(t, err)
			assert.Equal(t, path.Join("/", testFileName), p)
			// now a link outside the home dir
			err = os.Symlink(filepath.Clean(os.TempDir()), filepath.Join(localUser.GetHomeDir(), subdir, "temp"))
			assert.NoError(t, err)
			_, err = client.RealPath(path.Join(subdir, "temp"))
			assert.ErrorIs(t, err, os.ErrPermission)

			conn.Close()
			client.Close()
			err = os.Remove(filepath.Join(localUser.GetHomeDir(), subdir, "temp"))
			assert.NoError(t, err)
			if user.Username == localUser.Username {
				err = os.RemoveAll(filepath.Join(localUser.GetHomeDir(), subdir))
				assert.NoError(t, err)
			}
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestBufferedSFTP checks an SFTP filesystem with a write buffer configured:
// resume, truncate and sequential reads on open handles are unsupported in
// this mode and must fail with SSH_FX_OP_UNSUPPORTED, while plain
// upload/download must round-trip unchanged.
func TestBufferedSFTP(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	u.FsConfig.SFTPConfig.BufferSize = 2
	u.HomeDir = filepath.Join(os.TempDir(), u.Username)
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(sftpUser, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		appendDataSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		initialHash, err := computeHashForFile(sha256.New(), testFilePath)
		assert.NoError(t, err)

		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = appendToTestFile(testFilePath, appendDataSize)
		assert.NoError(t, err)
		// resume is not supported in buffered mode
		err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, false, client)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
		}
		localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
		err = sftpDownloadFile(testFileName, localDownloadPath,
			testFileSize, client)
		assert.NoError(t, err)
		// the downloaded content must match the originally uploaded data
		downloadedFileHash, err := computeHashForFile(sha256.New(), localDownloadPath)
		assert.NoError(t, err)
		assert.Equal(t, initialHash, downloadedFileHash)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(localDownloadPath)
		assert.NoError(t, err)

		sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
		if assert.NoError(t, err) {
			testData := []byte("sample test sftp data")
			n, err := sftpFile.Write(testData)
			assert.NoError(t, err)
			assert.Equal(t, len(testData), n)
			// truncate and read on an open buffered handle are unsupported
			err = sftpFile.Truncate(0)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
			}
			err = sftpFile.Truncate(4)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
			}
			buffer := make([]byte, 128)
			_, err = sftpFile.Read(buffer)
			if assert.Error(t, err) {
				assert.Contains(t, err.Error(), "SSH_FX_OP_UNSUPPORTED")
			}
			err = sftpFile.Close()
			assert.NoError(t, err)
			info, err := client.Stat(testFileName)
			if assert.NoError(t, err) {
				assert.Equal(t, int64(len(testData)), info.Size())
			}
		}
		// test WriteAt
		sftpFile, err = client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
		if assert.NoError(t, err) {
			testData := []byte("hello world")
			n, err := sftpFile.WriteAt(testData[:6], 0)
			assert.NoError(t, err)
			assert.Equal(t, 6, n)
			n, err = sftpFile.WriteAt(testData[6:], 6)
			assert.NoError(t, err)
			assert.Equal(t, 5, n)
			err = sftpFile.Close()
			assert.NoError(t, err)
			info, err := client.Stat(testFileName)
			if assert.NoError(t, err) {
				assert.Equal(t, int64(len(testData)), info.Size())
			}
		}
		// test ReadAt
		sftpFile, err = client.OpenFile(testFileName, os.O_RDONLY)
		if assert.NoError(t, err) {
			buffer := make([]byte, 128)
			n, err := sftpFile.ReadAt(buffer, 6)
			assert.ErrorIs(t, err, io.EOF)
			assert.Equal(t, 5, n)
			assert.Equal(t, []byte("world"), buffer[:n])
			err = sftpFile.Close()
			assert.NoError(t, err)
		}
	}

	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(sftpUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestUploadResume verifies resuming an interrupted/append upload for local,
// SFTP filesystem and buffered local filesystem users, and that resuming with
// an invalid offset fails.
func TestUploadResume(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestUser(usePubKey)
	u.FsConfig.OSConfig = sdk.OSFsConfig{
		WriteBufferSize: 1,
		ReadBufferSize:  1,
	}
	u.Username += "_buffered"
	u.HomeDir += "_with_buf"
	bufferedUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	for _, user := range []dataprovider.User{localUser, sftpUser, bufferedUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()
			testFilePath := filepath.Join(homeBasePath, testFileName)
			testFileSize := int64(65535)
			appendDataSize := int64(65535)
			err = createTestFile(testFilePath, testFileSize)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
			assert.NoError(t, err)
			err = appendToTestFile(testFilePath, appendDataSize)
			assert.NoError(t, err)
			err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, false, client)
			assert.NoError(t, err)
			localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
			err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize+appendDataSize, client)
			assert.NoError(t, err)
			initialHash, err := computeHashForFile(sha256.New(), testFilePath)
			assert.NoError(t, err)
			downloadedFileHash, err := computeHashForFile(sha256.New(), localDownloadPath)
			assert.NoError(t, err)
			assert.Equal(t, initialHash, downloadedFileHash)
			err = sftpUploadResumeFile(testFilePath, testFileName, testFileSize+appendDataSize, true, client)
			assert.Error(t, err, "resume uploading file with invalid offset must fail")
			err = os.Remove(testFilePath)
			assert.NoError(t, err)
			err = os.Remove(localDownloadPath)
			assert.NoError(t, err)
			// recreate the local user so the next loop iteration starts clean
			if user.Username == defaultUsername {
				err = os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
				_, err = httpdtest.RemoveUser(user, http.StatusOK)
				assert.NoError(t, err)
				user.Password = defaultPassword
				user.ID = 0
				user.CreatedAt = 0
				_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
				assert.NoError(t, err, string(resp))
			}
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(bufferedUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(bufferedUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestDirCommands exercises mkdir, rename, posix-rename, rmdir and directory
// listing, including the error paths for missing/recursive paths.
func TestDirCommands(t *testing.T) {
	usePubKey := false
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	// remove the home dir to test auto creation
	err = os.RemoveAll(user.HomeDir)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.Mkdir("test1")
		assert.NoError(t, err)
		err = client.Rename("test1", "test")
		assert.NoError(t, err)
		// rename a missing file
		err = client.Rename("test1", "test2")
		assert.Error(t, err)
		_, err = client.Lstat("/test1")
		assert.Error(t, err, "stat for renamed dir must not succeed")
		err = client.PosixRename("test", "test1")
		assert.NoError(t, err)
		err = client.Remove("test1")
		assert.NoError(t, err)
		err = client.Mkdir("/test/test1")
		assert.Error(t, err, "recursive mkdir must fail")
		err = client.Mkdir("/test")
		assert.NoError(t, err)
		err = client.Mkdir("/test/test1")
		assert.NoError(t, err)
		_, err = client.ReadDir("/this/dir/does/not/exist")
		assert.Error(t, err, "reading a missing dir must fail")
		err = client.RemoveDirectory("/test/test1")
		assert.NoError(t, err)
		err = client.RemoveDirectory("/test")
		assert.NoError(t, err)
		_, err = client.Lstat("/test")
		assert.Error(t, err, "stat for deleted dir must not succeed")
		_, err = client.Stat("/test")
		assert.Error(t, err, "stat for deleted dir must not succeed")
		err = client.RemoveDirectory("/test")
		assert.Error(t, err, "remove missing path must fail")
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestRemove checks file and directory removal, including the error paths:
// non-empty dir, rmdir on a file, double remove.
func TestRemove(t *testing.T) {
	usePubKey := true
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		err = client.Mkdir("test")
		assert.NoError(t, err)
		err = client.Mkdir("/test/test1")
		assert.NoError(t, err)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join("/test", testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = client.Remove("/test")
		assert.Error(t, err, "remove non empty dir must fail")
		err = client.RemoveDirectory(path.Join("/test", testFileName))
		assert.Error(t, err, "remove a file with rmdir must fail")
		err = client.Remove(path.Join("/test", testFileName))
		assert.NoError(t, err)
		err = client.Remove(path.Join("/test", testFileName))
		assert.Error(t,
			err, "remove missing file must fail")
		err = client.Remove("/test/test1")
		assert.NoError(t, err)
		err = client.Remove("/test")
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestLink checks symlink creation/readlink and that creating a duplicate
// symlink or a hard link fails.
func TestLink(t *testing.T) {
	usePubKey := false
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileSize := int64(65535)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = client.Symlink(testFileName, testFileName+".link")
		assert.NoError(t, err)
		linkName, err := client.ReadLink(testFileName + ".link")
		assert.NoError(t, err)
		assert.Equal(t, path.Join("/", testFileName), linkName)
		err = client.Symlink(testFileName, testFileName+".link")
		assert.Error(t, err, "creating a symlink to an existing one must fail")
		err = client.Link(testFileName, testFileName+".hlink")
		assert.Error(t, err, "hard link is not supported and must fail")
		err = client.Remove(testFileName + ".link")
		assert.NoError(t, err)
		err = client.Remove(testFileName)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestStat exercises stat/lstat/readlink/chmod/truncate on files and
// symlinks for both local and SFTP filesystem users.
func TestStat(t *testing.T) {
	usePubKey := false
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()
			testFilePath := filepath.Join(homeBasePath, testFileName)
			testFileSize := int64(65535)
			err = createTestFile(testFilePath, testFileSize)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
			assert.NoError(t, err)
			_, err := client.Lstat(testFileName)
			assert.NoError(t, err)
			_, err = client.Stat(testFileName)
			assert.NoError(t, err)
			// stat a missing path we should get an fs.ErrNotExist error
			_, err = client.Stat("missing path")
			assert.True(t, errors.Is(err, fs.ErrNotExist))
			_, err = client.Lstat("missing path")
			assert.True(t, errors.Is(err, fs.ErrNotExist))
			// mode 0666 and 0444 works on Windows too
			newPerm := os.FileMode(0666)
			err = client.Chmod(testFileName, newPerm)
			assert.NoError(t, err)
			newFi, err := client.Lstat(testFileName)
			assert.NoError(t, err)
			assert.Equal(t, newPerm, newFi.Mode().Perm())
			newPerm = os.FileMode(0444)
			err = client.Chmod(testFileName, newPerm)
			assert.NoError(t, err)
			newFi, err = client.Lstat(testFileName)
			if assert.NoError(t, err) {
				assert.Equal(t, newPerm, newFi.Mode().Perm())
			}
			_, err = client.ReadLink(testFileName)
			assert.Error(t, err, "readlink on a file must fail")
			symlinkName := testFileName + ".sym"
			err = client.Symlink(testFileName, symlinkName)
			assert.NoError(t, err)
			info, err := client.Lstat(symlinkName)
			if assert.NoError(t, err) {
				assert.True(t, info.Mode()&os.ModeSymlink != 0)
			}
			info, err = client.Stat(symlinkName)
			if assert.NoError(t, err) {
				assert.False(t, info.Mode()&os.ModeSymlink != 0)
			}
			linkName, err := client.ReadLink(symlinkName)
			assert.NoError(t, err)
			assert.Equal(t, path.Join("/", testFileName), linkName)
			// restore write permission before truncating
			newPerm = os.FileMode(0666)
			err = client.Chmod(testFileName, newPerm)
			assert.NoError(t, err)
			err = client.Truncate(testFileName, 100)
			assert.NoError(t, err)
			fi, err := client.Stat(testFileName)
			if assert.NoError(t, err) {
				assert.Equal(t, int64(100), fi.Size())
			}
			// truncate through an open handle
			f, err := client.OpenFile(testFileName, os.O_WRONLY)
			if assert.NoError(t, err) {
				err = f.Truncate(5)
				assert.NoError(t, err)
				err = f.Close()
				assert.NoError(t, err)
			}
			// chmod through an open handle
			f, err = client.OpenFile(testFileName, os.O_WRONLY)
			newPerm = os.FileMode(0444)
			if assert.NoError(t, err) {
				err = f.Chmod(newPerm)
				assert.NoError(t, err)
				err = f.Close()
				assert.NoError(t, err)
			}
			newFi, err = client.Lstat(testFileName)
			if assert.NoError(t, err) {
				assert.Equal(t, newPerm, newFi.Mode().Perm())
			}
			newPerm = os.FileMode(0666)
			err = client.Chmod(testFileName, newPerm)
			assert.NoError(t, err)
			err = os.Remove(testFilePath)
			assert.NoError(t, err)
			// recreate the local user so the next loop iteration starts clean
			if user.Username == defaultUsername {
				err = os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
				_, err = httpdtest.RemoveUser(user, http.StatusOK)
				assert.NoError(t, err)
				user.Password = defaultPassword
				user.ID = 0
				user.CreatedAt = 0
				_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
				assert.NoError(t, err, string(resp))
			}
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestStatChownChmod checks chown/chmod on existing and missing files.
func TestStatChownChmod(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("chown is not supported on Windows, chmod is partially supported")
	}
	usePubKey := true
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()
			testFilePath := filepath.Join(homeBasePath, testFileName)
			testFileSize := int64(65535)
			err = createTestFile(testFilePath, testFileSize)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
			assert.NoError(t, err)
			err = client.Chown(testFileName, os.Getuid(), os.Getgid())
			assert.NoError(t, err)
			newPerm := os.FileMode(0600)
			err = client.Chmod(testFileName, newPerm)
			assert.NoError(t, err)
			newFi, err := client.Lstat(testFileName)
			assert.NoError(t, err)
			assert.Equal(t, newPerm, newFi.Mode().Perm())
			err = client.Remove(testFileName)
			assert.NoError(t, err)
			// chmod/chown on a removed file must fail with ErrNotExist
			err = client.Chmod(testFileName, newPerm)
			assert.EqualError(t, err, os.ErrNotExist.Error())
			err = client.Chown(testFileName, os.Getuid(), os.Getgid())
			assert.EqualError(t, err, os.ErrNotExist.Error())
			err = os.Remove(testFilePath)
			assert.NoError(t, err)
			// recreate the local user so the next loop iteration starts clean
			if user.Username == defaultUsername {
				err = os.RemoveAll(user.GetHomeDir())
				assert.NoError(t, err)
				_, err = httpdtest.RemoveUser(user, http.StatusOK)
				assert.NoError(t, err)
				user.Password = defaultPassword
				user.ID = 0
				user.CreatedAt = 0
				_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
				assert.NoError(t, err, string(resp))
			}
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestSFTPFsLoginWrongFingerprint checks that login succeeds while at least
// one configured host key fingerprint matches and fails when none do.
func TestSFTPFsLoginWrongFingerprint(t *testing.T) {
	usePubKey := true
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)

	conn, client, err := getSftpClient(sftpUser, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer
client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + + sftpUser.FsConfig.SFTPConfig.Fingerprints = append(sftpUser.FsConfig.SFTPConfig.Fingerprints, "wrong") + _, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(sftpUser, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + + out, err := runSSHCommand("md5sum", sftpUser, usePubKey) + assert.NoError(t, err) + assert.Contains(t, string(out), "d41d8cd98f00b204e9800998ecf8427e") + + sftpUser.FsConfig.SFTPConfig.Fingerprints = []string{"wrong"} + _, _, err = httpdtest.UpdateUser(sftpUser, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(sftpUser, usePubKey) + if !assert.Error(t, err) { + defer conn.Close() + defer client.Close() + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestChtimes(t *testing.T) { + usePubKey := false + localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + testDir := "test" //nolint:goconst + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + acmodTime := time.Now() + err = client.Chtimes(testFileName, acmodTime, 
acmodTime) + assert.NoError(t, err) + newFi, err := client.Lstat(testFileName) + assert.NoError(t, err) + diff := math.Abs(newFi.ModTime().Sub(acmodTime).Seconds()) + assert.LessOrEqual(t, diff, float64(1)) + err = client.Chtimes("invalidFile", acmodTime, acmodTime) + assert.EqualError(t, err, os.ErrNotExist.Error()) + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = client.Chtimes(testDir, acmodTime, acmodTime) + assert.NoError(t, err) + newFi, err = client.Lstat(testDir) + assert.NoError(t, err) + diff = math.Abs(newFi.ModTime().Sub(acmodTime).Seconds()) + assert.LessOrEqual(t, diff, float64(1)) + err = os.Remove(testFilePath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +// basic tests to verify virtual chroot, should be improved to cover more cases ... 
+func TestEscapeHomeDir(t *testing.T) { + usePubKey := true + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + dirOutsideHome := filepath.Join(homeBasePath, defaultUsername+"1", "dir") + err = os.MkdirAll(dirOutsideHome, os.ModePerm) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + testDir := "testDir" //nolint:goconst + linkPath := filepath.Join(homeBasePath, defaultUsername, testDir) + err = os.Symlink(homeBasePath, linkPath) + assert.NoError(t, err) + _, err = client.ReadDir(testDir) + assert.Error(t, err, "reading a symbolic link outside home dir should not succeeded") + err = os.Remove(linkPath) + assert.NoError(t, err) + err = os.Symlink(dirOutsideHome, linkPath) + assert.NoError(t, err) + _, err := client.ReadDir(testDir) + assert.Error(t, err, "reading a symbolic link outside home dir should not succeeded") + err = client.Chmod(path.Join(testDir, "sub", "dir"), os.ModePerm) + assert.ErrorIs(t, err, os.ErrPermission) + assert.Error(t, err, "setstat on a file outside home dir must fail") + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + remoteDestPath := path.Join("..", "..", testFileName) + err = sftpUploadFile(testFilePath, remoteDestPath, testFileSize, client) + assert.NoError(t, err) + _, err = client.Lstat(testFileName) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.NoError(t, err) + linkPath = filepath.Join(homeBasePath, defaultUsername, testFileName) + err = os.Symlink(homeBasePath, linkPath) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, testFilePath, 0, client) + assert.Error(t, err, "download file outside home dir must fail") + err = sftpUploadFile(testFilePath, remoteDestPath, testFileSize, 
client) + assert.Error(t, err, "overwrite a file outside home dir must fail") + err = client.Chmod(remoteDestPath, 0644) + assert.Error(t, err, "setstat on a file outside home dir must fail") + err = os.Remove(linkPath) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(filepath.Join(homeBasePath, defaultUsername+"1")) + assert.NoError(t, err) +} + +func TestEscapeSFTPFsPrefix(t *testing.T) { + usePubKey := false + localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + u := getTestSFTPUser(usePubKey) + sftpPrefix := "/prefix" + outPrefix1 := "/pre" + outPrefix2 := sftpPrefix + "1" + out1 := "out1" + out2 := "out2" + u.FsConfig.SFTPConfig.Prefix = sftpPrefix + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(localUser, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = client.Mkdir(sftpPrefix) + assert.NoError(t, err) + err = client.Mkdir(outPrefix1) + assert.NoError(t, err) + err = client.Mkdir(outPrefix2) + assert.NoError(t, err) + err = client.Symlink(outPrefix1, path.Join(sftpPrefix, out1)) + assert.NoError(t, err) + err = client.Symlink(outPrefix2, path.Join(sftpPrefix, out2)) + assert.NoError(t, err) + } + + conn, client, err = getSftpClient(sftpUser, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + contents, err := client.ReadDir("/") + assert.NoError(t, err) + assert.Len(t, contents, 2) + _, err = client.ReadDir(out1) + assert.Error(t, err) + _, err = client.ReadDir(out2) + assert.Error(t, err) + err = client.Mkdir(path.Join(out1, "subout1")) + assert.Error(t, err) + err = client.Mkdir(path.Join(out2, "subout2")) + assert.Error(t, err) + } + + _, err = 
httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestGetMimeTypeSFTPFs(t *testing.T) { + usePubKey := false + localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(localUser, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC) + if assert.NoError(t, err) { + testData := []byte("some UTF-8 text so we should get a text/plain mime type") + n, err := sftpFile.Write(testData) + assert.NoError(t, err) + assert.Equal(t, len(testData), n) + err = sftpFile.Close() + assert.NoError(t, err) + } + } + + sftpUser.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) + sftpUser.FsConfig.SFTPConfig.PrivateKey = kms.NewEmptySecret() + fs, err := sftpUser.GetFilesystem("connID") + if assert.NoError(t, err) { + assert.True(t, vfs.IsSFTPFs(fs)) + mime, err := fs.GetMimeType(testFileName) + assert.NoError(t, err) + assert.Equal(t, "text/plain; charset=utf-8", mime) + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestHomeSpecialChars(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.HomeDir = filepath.Join(homeBasePath, "abc açà#&%lk") + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + 
assert.NoError(t, checkBasicSFTP(client)) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + files, err := client.ReadDir(".") + assert.NoError(t, err) + assert.Equal(t, 1, len(files)) + err = client.Remove(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLogin(t *testing.T) { + u := getTestUser(false) + u.PublicKeys = []string{testPubKey} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, false) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.LastLogin, int64(0), "last login must be updated after a successful login: %v", user.LastLogin) + } + conn, client, err = getSftpClient(user, true) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user.Password = "invalid password" + conn, client, err = getSftpClient(user, false) + if !assert.Error(t, err, "login with invalid password must fail") { + client.Close() + conn.Close() + } + // testPubKey1 is not authorized + user.PublicKeys = []string{testPubKey1} + user.Password = "" + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, true) + if !assert.Error(t, err, "login with invalid public key must fail") { + defer conn.Close() + defer client.Close() + } + // login a user with multiple public 
keys, only the second one is valid + user.PublicKeys = []string{testPubKey1, testPubKey} + user.Password = "" + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, true) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginUserCert(t *testing.T) { + u := getTestUser(true) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + // try login using a cert signed from a trusted CA + signer, err := getSignerForUserCert([]byte(testCertValid)) + assert.NoError(t, err) + conn, client, err := getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + // revoke the certificate + certs := []string{"SHA256:OkxVB1ImSJ2XeI8nA2Wg+6zJVlxdevD1FYBSEJjFEN4"} + data, err := json.Marshal(certs) + assert.NoError(t, err) + err = os.WriteFile(revokeUserCerts, data, 0644) + assert.NoError(t, err) + err = sftpd.Reload() + assert.NoError(t, err) + conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + // if we remove the revoked certificate login should work again + certs = []string{"SHA256:bsBRHC/xgiqBJdSuvSTNpJNLTISP/G356jNMCRYC5Es, SHA256:1kxVB1ImSJ2XeI8nA2Wg+6zJVlxdevD1FYBSEJjFEN4"} + data, err = json.Marshal(certs) + assert.NoError(t, err) + err = os.WriteFile(revokeUserCerts, data, 0644) + assert.NoError(t, err) + err = sftpd.Reload() + assert.NoError(t, err) + conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if assert.NoError(t, err) { + defer conn.Close() + defer 
client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + // try login using a cert signed from an untrusted CA + signer, err = getSignerForUserCert([]byte(testCertUntrustedCA)) + assert.NoError(t, err) + conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + // try login using an host certificate instead of an user certificate + signer, err = getSignerForUserCert([]byte(testHostCert)) + assert.NoError(t, err) + conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + // try login using a user certificate with an authorized source address different from localhost + signer, err = getSignerForUserCert([]byte(testCertOtherSourceAddress)) + assert.NoError(t, err) + conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + // try login using an expired certificate + signer, err = getSignerForUserCert([]byte(testCertExpired)) + assert.NoError(t, err) + conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + // try login using a certificate with no principals + signer, err = getSignerForUserCert([]byte(testCertNoPrincipals)) + assert.NoError(t, err) + conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + // the user does not exist + signer, err = getSignerForUserCert([]byte(testCertValid)) + assert.NoError(t, err) + conn, client, err = getCustomAuthSftpClient(user, 
[]ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + + // now login with a username not in the set of valid principals for the given certificate + u.Username += "1" + user, _, err = httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + signer, err = getSignerForUserCert([]byte(testCertValid)) + assert.NoError(t, err) + conn, client, err = getCustomAuthSftpClient(user, []ssh.AuthMethod{ssh.PublicKeys(signer)}, "") + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + + err = os.WriteFile(revokeUserCerts, []byte(`[]`), 0644) + assert.NoError(t, err) + err = sftpd.Reload() + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestMultiStepLoginKeyAndPwd(t *testing.T) { + u := getTestUser(true) + u.Password = defaultPassword + u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{ + dataprovider.SSHLoginMethodKeyAndKeyboardInt, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, + }...) 
+ user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, true) + if !assert.Error(t, err, "login with public key is disallowed and must fail") { + client.Close() + conn.Close() + } + conn, client, err = getSftpClient(user, true) + if !assert.Error(t, err, "login with password is disallowed and must fail") { + client.Close() + conn.Close() + } + signer, _ := ssh.ParsePrivateKey([]byte(testPrivateKey)) + authMethods := []ssh.AuthMethod{ + ssh.PublicKeys(signer), + ssh.Password(defaultPassword), + } + conn, client, err = getCustomAuthSftpClient(user, authMethods, "") + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + conn, client, err = getCustomAuthSftpClient(user, authMethods, sftpSrvAddr2222) + if !assert.Error(t, err, "password auth is disabled on port 2222, multi-step auth must fail") { + client.Close() + conn.Close() + } + authMethods = []ssh.AuthMethod{ + ssh.Password(defaultPassword), + ssh.PublicKeys(signer), + } + _, _, err = getCustomAuthSftpClient(user, authMethods, "") + assert.Error(t, err, "multi step auth login with wrong order must fail") + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestMultiStepLoginKeyAndKeyInt(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + u := getTestUser(true) + u.Password = defaultPassword + u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{ + dataprovider.SSHLoginMethodKeyAndPassword, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, + }...) 
+ user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), os.ModePerm) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, true) + if !assert.Error(t, err, "login with public key is disallowed and must fail") { + client.Close() + conn.Close() + } + + signer, _ := ssh.ParsePrivateKey([]byte(testPrivateKey)) + authMethods := []ssh.AuthMethod{ + ssh.PublicKeys(signer), + ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) { + return []string{"1", "2"}, nil + }), + } + conn, client, err = getCustomAuthSftpClient(user, authMethods, "") + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + conn, client, err = getCustomAuthSftpClient(user, authMethods, sftpSrvAddr2222) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + authMethods = []ssh.AuthMethod{ + ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) { + return []string{"1", "2"}, nil + }), + ssh.PublicKeys(signer), + } + _, _, err = getCustomAuthSftpClient(user, authMethods, "") + assert.Error(t, err, "multi step auth login with wrong order must fail") + + authMethods = []ssh.AuthMethod{ + ssh.PublicKeys(signer), + ssh.Password(defaultPassword), + } + _, _, err = getCustomAuthSftpClient(user, authMethods, "") + assert.Error(t, err, "multi step auth login with wrong method must fail") + + user.Filters.DeniedLoginMethods = nil + user.Filters.DeniedLoginMethods = append(user.Filters.DeniedLoginMethods, []string{ + dataprovider.SSHLoginMethodKeyAndKeyboardInt, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, + }...) 
+ user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + _, _, err = getCustomAuthSftpClient(user, authMethods, sftpSrvAddr2222) + assert.Error(t, err) + conn, client, err = getCustomAuthSftpClient(user, authMethods, "") + if assert.NoError(t, err) { + assert.NoError(t, checkBasicSFTP(client)) + client.Close() + conn.Close() + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestMultiStepLoginCertAndPwd(t *testing.T) { + u := getTestUser(true) + u.Password = defaultPassword + u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{ + dataprovider.SSHLoginMethodKeyAndKeyboardInt, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, + }...) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + signer, err := getSignerForUserCert([]byte(testCertValid)) + assert.NoError(t, err) + authMethods := []ssh.AuthMethod{ + ssh.PublicKeys(signer), + ssh.Password(defaultPassword), + } + conn, client, err := getCustomAuthSftpClient(user, authMethods, "") + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + signer, err = getSignerForUserCert([]byte(testCertOtherSourceAddress)) + assert.NoError(t, err) + authMethods = []ssh.AuthMethod{ + ssh.PublicKeys(signer), + ssh.Password(defaultPassword), + } + conn, client, err = getCustomAuthSftpClient(user, authMethods, "") + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginUserStatus(t *testing.T) { + usePubKey := true + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + conn, 
client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.LastLogin, int64(0), "last login must be updated after a successful login: %v", user.LastLogin) + } + user.Status = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if !assert.Error(t, err, "login for a disabled user must fail") { + client.Close() + conn.Close() + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginUserExpiration(t *testing.T) { + usePubKey := true + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.LastLogin, int64(0), "last login must be updated after a successful login: %v", user.LastLogin) + } + user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now()) - 120000 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if !assert.Error(t, err, "login for an expired user must fail") { + client.Close() + conn.Close() + } + user.ExpirationDate = util.GetTimeAsMsSinceEpoch(time.Now()) + 120000 + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, 
checkBasicSFTP(client)) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginWithDatabaseCredentials(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.FsConfig.Provider = sdk.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "testbucket" + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account", "private_key": " ", "client_email": "example@iam.gserviceaccount.com" }`) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.GCSConfig.Credentials.GetStatus()) + assert.NotEmpty(t, user.FsConfig.GCSConfig.Credentials.GetPayload()) + assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetAdditionalData()) + assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetKey()) + + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginInvalidFs(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.FsConfig.Provider = sdk.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "test" + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user, usePubKey) + if !assert.Error(t, err, "login must fail, the user has an invalid filesystem config") { + client.Close() + conn.Close() + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestDeniedProtocols(t *testing.T) { + u := getTestUser(true) + u.Filters.DeniedProtocols = 
[]string{common.ProtocolSSH} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, true) + if !assert.Error(t, err, "SSH protocol is disabled, authentication must fail") { + client.Close() + conn.Close() + } + user.Filters.DeniedProtocols = []string{common.ProtocolFTP, common.ProtocolWebDAV} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, true) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestDeniedLoginMethods(t *testing.T) { + u := getTestUser(true) + u.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.LoginMethodPassword} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, true) + if !assert.Error(t, err, "public key login is disabled, authentication must fail") { + client.Close() + conn.Close() + } + user.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.LoginMethodPassword} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, true) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user.Password = defaultPassword + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + conn, client, err = getSftpClient(user, false) + if !assert.Error(t, err, "password login is disabled, authentication must fail") { + client.Close() + conn.Close() + } + user.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodKeyboardInteractive, 
dataprovider.SSHLoginMethodPublicKey} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, false) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginWithIPFilters(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Filters.DeniedIP = []string{"192.167.0.0/24", "172.18.0.0/16"} + u.Filters.AllowedIP = []string{} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.LastLogin, int64(0), "last login must be updated after a successful login: %v", user.LastLogin) + } + user.Filters.AllowedIP = []string{"127.0.0.0/8"} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user.Filters.AllowedIP = []string{"172.19.0.0/16"} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if !assert.Error(t, err, "login from an not allowed IP must fail") { + client.Close() + conn.Close() + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginEmptyPassword(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Password = "" + 
user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user.Password = "empty" + _, _, err = getSftpClient(user, usePubKey) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginAnonymousUser(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Password = "" + u.Filters.IsAnonymous = true + _, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.Error(t, err) + user, _, err := httpdtest.GetUserByUsername(u.Username, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.IsAnonymous) + assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"]) + assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols) + assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword, + dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate, + dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods) + _, _, err = getSftpClient(user, usePubKey) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestAnonymousGroupInheritance(t *testing.T) { + g := getTestGroup() + g.UserSettings.Filters.IsAnonymous = true + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + usePubKey := false + u := getTestUser(usePubKey) + u.Groups = []sdk.GroupMapping{ + { + Name: group.Name, + Type: sdk.GroupTypePrimary, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + _, _, err = getSftpClient(user, usePubKey) + assert.Error(t, err) + + _, err = 
httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) +} + +func TestLoginAfterUserUpdateEmptyPwd(t *testing.T) { + usePubKey := false + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + user.Password = "" + // password should remain unchanged + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginKeyboardInteractiveAuth(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + user, _, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), os.ModePerm) + assert.NoError(t, err) + conn, client, err := getKeyboardInteractiveSftpClient(user, []string{"1", "2"}) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user.Status = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getKeyboardInteractiveSftpClient(user, []string{"1", "2"}) + if !assert.Error(t, err, "keyboard interactive auth must fail the user is disabled") { + client.Close() + conn.Close() + } + user.Status = 1 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, -1), os.ModePerm) + 
assert.NoError(t, err) + conn, client, err = getKeyboardInteractiveSftpClient(user, []string{"1", "2"}) + if !assert.Error(t, err, "keyboard interactive auth must fail the script returned -1") { + client.Close() + conn.Close() + } + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, true, 1), os.ModePerm) + assert.NoError(t, err) + conn, client, err = getKeyboardInteractiveSftpClient(user, []string{"1", "2"}) + if !assert.Error(t, err, "keyboard interactive auth must fail the script returned bad json") { + client.Close() + conn.Close() + } + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 5, true, 1), os.ModePerm) + assert.NoError(t, err) + conn, client, err = getKeyboardInteractiveSftpClient(user, []string{"1", "2"}) + if !assert.Error(t, err, "keyboard interactive auth must fail the script returned bad json") { + client.Close() + conn.Close() + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestInteractiveLoginWithPasscode(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + user, _, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated) + assert.NoError(t, err) + // test password check + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptForBuiltinChecks(false, 1), os.ModePerm) + assert.NoError(t, err) + conn, client, err := getKeyboardInteractiveSftpClient(user, []string{defaultPassword}) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + // wrong password + _, _, err = getKeyboardInteractiveSftpClient(user, []string{"wrong_password"}) + assert.Error(t, err) + // correct password but the script returns an error + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptForBuiltinChecks(false, 0), os.ModePerm) + 
assert.NoError(t, err) + _, _, err = getKeyboardInteractiveSftpClient(user, []string{"wrong_password"}) + assert.Error(t, err) + // add multi-factor authentication + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Password = defaultPassword + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolSSH}, + } + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + + passcode, err := totp.GenerateCodeCustom(key.Secret(), time.Now(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) + assert.NoError(t, err) + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptForBuiltinChecks(true, 1), os.ModePerm) + assert.NoError(t, err) + + passwordAsked := false + passcodeAsked := false + authMethods := []ssh.AuthMethod{ + ssh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) { + var answers []string + if strings.HasPrefix(questions[0], "Password") { + answers = append(answers, defaultPassword) + passwordAsked = true + } else { + answers = append(answers, passcode) + passcodeAsked = true + } + return answers, nil + }), + } + conn, client, err = getCustomAuthSftpClient(user, authMethods, "") + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + assert.True(t, passwordAsked) + assert.True(t, passcodeAsked) + // the same passcode cannot be reused + _, _, err = getCustomAuthSftpClient(user, authMethods, "") + assert.Error(t, err) + // correct passcode but the script returns an error + configName, key, _, err = mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Password = defaultPassword + user.Filters.TOTPConfig = 
dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolSSH}, + } + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + passcode, err = totp.GenerateCodeCustom(key.Secret(), time.Now(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) + assert.NoError(t, err) + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptForBuiltinChecks(true, 0), os.ModePerm) + assert.NoError(t, err) + passwordAsked = false + passcodeAsked = false + _, _, err = getCustomAuthSftpClient(user, authMethods, "") + assert.Error(t, err) + authMethods = []ssh.AuthMethod{ + ssh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) { + var answers []string + if strings.HasPrefix(questions[0], "Password") { + answers = append(answers, defaultPassword) + passwordAsked = true + } else { + answers = append(answers, passcode) + passcodeAsked = true + } + return answers, nil + }), + } + _, _, err = getCustomAuthSftpClient(user, authMethods, "") + assert.Error(t, err) + assert.True(t, passwordAsked) + assert.True(t, passcodeAsked) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestMustChangePasswordRequirement(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Filters.RequirePasswordChange = true + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + // public key auth works even if the user must change password + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + // password auth does not work + _, _, err = getSftpClient(user, false) + assert.Error(t, err) + // change password + err = 
dataprovider.UpdateUserPassword(user.Username, defaultPassword, "", "", "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, false) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSecondFactorRequirement(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Filters.TwoFactorAuthProtocols = []string{common.ProtocolSSH} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + _, _, err = getSftpClient(user, usePubKey) + assert.Error(t, err) + + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Password = defaultPassword + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolSSH}, + } + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestNamingRules(t *testing.T) { + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.NamingRules = 7 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + usePubKey := true + u := getTestUser(usePubKey) + u.Username = "useR@user.com " + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + 
assert.Equal(t, "user@user.com", user.Username) + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + u.Password = defaultPassword + _, _, err = httpdtest.UpdateUser(u, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(u, false) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + _, err = httpdtest.RemoveUser(u, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(u.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestPreLoginScript(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := true + u := getTestUser(usePubKey) + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + folderMountPath := "/vpath" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: folderMountPath, + }) + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + providerConf.PreLoginHook = preLoginPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: 
kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err = httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + assert.NoError(t, checkBasicSFTP(client)) + testFilePath := filepath.Join(homeBasePath, testFileName) + testData := []byte("test data") + err = os.WriteFile(testFilePath, testData, 0666) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(folderMountPath, testFileName), int64(len(testData)), client) + assert.NoError(t, err) + info, err := os.Stat(filepath.Join(mappedPath, testFileName)) + if assert.NoError(t, err) { + assert.Greater(t, info.Size(), int64(len(testData))) + } + } + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) + assert.NoError(t, err) + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err, "pre-login script returned a non json response, login must fail") { + client.Close() + conn.Close() + } + // now disable the the hook + user.Filters.Hooks.PreLoginDisabled = true + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + user.Filters.Hooks.PreLoginDisabled = false + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + user.Status = 0 + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm) + assert.NoError(t, err) + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err, "pre-login script returned a disabled user, login must fail") { + client.Close() + conn.Close() + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, 
err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(preLoginPath) + assert.NoError(t, err) +} + +func TestPreLoginUserCreation(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/list"] = []string{"list", "download"} + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + providerConf.PreLoginHook = preLoginPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + _, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound) + assert.NoError(t, err) + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Permissions, 2) + assert.Empty(t, user.Description) + u.Description = "some desc" + delete(u.Permissions, "/list") + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + // The user should be updated and list permission removed + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer 
client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Permissions, 1) + assert.NotEmpty(t, user.Description) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(preLoginPath) + assert.NoError(t, err) +} + +func TestPreLoginHookPreserveMFAConfig(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := false + u := getTestUser(usePubKey) + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + providerConf.PreLoginHook = preLoginPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + // add multi-factor authentication + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Filters.RecoveryCodes, 0) + assert.False(t, user.Filters.TOTPConfig.Enabled) + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Password = defaultPassword + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: 
kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolSSH}, + } + for i := 0; i < 12; i++ { + user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, dataprovider.RecoveryCode{ + Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))), + }) + } + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Filters.RecoveryCodes, 12) + assert.True(t, user.Filters.TOTPConfig.Enabled) + assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName) + assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) + + err = os.WriteFile(extAuthPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Filters.RecoveryCodes, 12) + assert.True(t, user.Filters.TOTPConfig.Enabled) + assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName) + assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf 
= config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(preLoginPath) + assert.NoError(t, err) +} + +func TestPreDownloadHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oldExecuteOn := common.Config.Actions.ExecuteOn + oldHook := common.Config.Actions.Hook + + common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} + common.Config.Actions.Hook = preDownloadPath + + usePubKey := true + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + testFileSize := int64(131072) + testFilePath := filepath.Join(homeBasePath, testFileName) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + } + + remoteSCPDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + err = scpDownload(localDownloadPath, remoteSCPDownPath, false, false) + assert.NoError(t, err) + + err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = client.Remove(testFileName) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, localDownloadPath, 
testFileSize, client) + assert.ErrorIs(t, err, os.ErrPermission) + } + + err = scpDownload(localDownloadPath, remoteSCPDownPath, false, false) + assert.Error(t, err) + + common.Config.Actions.Hook = "http://127.0.0.1:8080/web/admin/login" + + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = client.Remove(testFileName) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + } + err = scpDownload(localDownloadPath, remoteSCPDownPath, false, false) + assert.NoError(t, err) + + common.Config.Actions.Hook = "http://127.0.0.1:8080/" + + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = client.Remove(testFileName) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.ErrorIs(t, err, os.ErrPermission) + } + err = scpDownload(localDownloadPath, remoteSCPDownPath, false, false) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.Actions.ExecuteOn = oldExecuteOn + common.Config.Actions.Hook = oldHook +} + +func TestPreUploadHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oldExecuteOn := common.Config.Actions.ExecuteOn + oldHook := common.Config.Actions.Hook + + common.Config.Actions.ExecuteOn = []string{common.OperationPreUpload} + common.Config.Actions.Hook = preUploadPath 
+ + usePubKey := true + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(preUploadPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + testFileSize := int64(131072) + testFilePath := filepath.Join(homeBasePath, testFileName) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + } + + remoteSCPUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + err = scpUpload(testFilePath, remoteSCPUpPath, true, false) + assert.NoError(t, err) + + err = os.WriteFile(preUploadPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.ErrorIs(t, err, os.ErrPermission) + err = sftpUploadFile(testFilePath, testFileName+"1", testFileSize, client) + assert.ErrorIs(t, err, os.ErrPermission) + } + err = scpUpload(testFilePath, remoteSCPUpPath, true, false) + assert.Error(t, err) + + common.Config.Actions.Hook = "http://127.0.0.1:8080/web/client/login" + + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, 
localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + } + err = scpUpload(testFilePath, remoteSCPUpPath, true, false) + assert.NoError(t, err) + + common.Config.Actions.Hook = "http://127.0.0.1:8080/web" + + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.ErrorIs(t, err, os.ErrPermission) + err = sftpUploadFile(testFilePath, testFileName+"1", testFileSize, client) + assert.ErrorIs(t, err, os.ErrPermission) + } + + err = scpUpload(testFilePath, remoteSCPUpPath, true, false) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.Actions.ExecuteOn = oldExecuteOn + common.Config.Actions.Hook = oldHook +} + +func TestPostConnectHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + common.Config.PostConnectHook = postConnectPath + + usePubKey := true + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(postConnectPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + err = os.WriteFile(postConnectPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + + common.Config.PostConnectHook = 
"http://127.0.0.1:8080/healthz" + + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + + common.Config.PostConnectHook = "http://127.0.0.1:8080/notfound" + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.PostConnectHook = "" +} + +func TestCheckPwdHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := false + u := getTestUser(usePubKey) + u.QuotaFiles = 1000 + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(checkPwdPath, getCheckPwdScriptsContents(2, defaultPassword), os.ModePerm) + assert.NoError(t, err) + providerConf.CheckPasswordHook = checkPwdPath + providerConf.CheckPasswordScope = 1 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + err = checkBasicSFTP(client) + assert.NoError(t, err) + client.Close() + conn.Close() + } + + err = os.WriteFile(checkPwdPath, getCheckPwdScriptsContents(0, defaultPassword), os.ModePerm) + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + + // now disable the the hook + user.Filters.Hooks.CheckPasswordDisabled = true + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, 
err) { + err = checkBasicSFTP(client) + assert.NoError(t, err) + client.Close() + conn.Close() + } + + // enable the hook again + user.Filters.Hooks.CheckPasswordDisabled = false + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + err = os.WriteFile(checkPwdPath, getCheckPwdScriptsContents(1, ""), os.ModePerm) + assert.NoError(t, err) + user.Password = defaultPassword + "1" + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + err = checkBasicSFTP(client) + assert.NoError(t, err) + client.Close() + conn.Close() + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + providerConf.CheckPasswordScope = 6 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + user, _, err = httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user.Password = defaultPassword + "1" + conn, client, err = getSftpClient(user, usePubKey) + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(checkPwdPath) + assert.NoError(t, err) +} + +func TestLoginExternalAuthPwdAndPubKey(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := true + u := getTestUser(usePubKey) + u.QuotaFiles = 1000 + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, 
false, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + testFileSize := int64(65535) + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + u.Username = defaultUsername + "1" + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err, "external auth login with invalid user must fail") { + client.Close() + conn.Close() + } + usePubKey = false + u = getTestUser(usePubKey) + u.PublicKeys = []string{} + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) + assert.NoError(t, err) + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, len(user.PublicKeys)) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + assert.Equal(t, 1, user.UsedQuotaFiles) + + u.Status = 0 + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) + assert.NoError(t, err) + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err) { + client.Close() + conn.Close() + } + // now disable the the hook + user.Filters.Hooks.ExternalAuthDisabled = true + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer 
conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestExternalAuthMultiStepLoginKeyAndPwd(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + u := getTestUser(true) + u.Password = defaultPassword + u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, []string{ + dataprovider.SSHLoginMethodKeyAndKeyboardInt, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, + }...) 
+ + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + assert.NoError(t, err) + authMethods := []ssh.AuthMethod{ + ssh.PublicKeys(signer), + ssh.Password(defaultPassword), + } + conn, client, err := getCustomAuthSftpClient(u, authMethods, "") + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + // wrong sequence should fail + authMethods = []ssh.AuthMethod{ + ssh.Password(defaultPassword), + ssh.PublicKeys(signer), + } + _, _, err = getCustomAuthSftpClient(u, authMethods, "") + assert.Error(t, err) + + // public key only auth must fail + _, _, err = getSftpClient(u, true) + assert.Error(t, err) + // password only auth must fail + _, _, err = getSftpClient(u, false) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(u, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(u.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestExternalAuthEmptyResponse(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := false + u := getTestUser(usePubKey) + u.QuotaFiles = 1000 + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + 
assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + testFileSize := int64(65535) + // the user will be created + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, len(user.PublicKeys)) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + assert.Equal(t, 1, user.UsedQuotaFiles) + // now modify the user + user.MaxSessions = 10 + user.QuotaFiles = 100 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, true, ""), os.ModePerm) + assert.NoError(t, err) + + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + assert.NoError(t, err) + } + + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 10, user.MaxSessions) + assert.Equal(t, 100, user.QuotaFiles) + + // the auth script accepts any password and returns an empty response, the + // user password must be updated + u.Password = defaultUsername + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() 
+ err = checkBasicSFTP(client) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestExternalAuthDifferentUsername(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := false + extAuthUsername := "common_user" + u := getTestUser(usePubKey) + u.QuotaFiles = 1000 + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, extAuthUsername), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + // the user logins using "defaultUsername" and the external auth returns "extAuthUsername" + testFileSize := int64(65535) + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + + // logins again to test that used quota is preserved + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = checkBasicSFTP(client) + 
assert.NoError(t, err) + } + + _, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound) + assert.NoError(t, err) + + user, _, err := httpdtest.GetUserByUsername(extAuthUsername, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, len(user.PublicKeys)) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + assert.Equal(t, 1, user.UsedQuotaFiles) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestLoginExternalAuth(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + mappedPath := filepath.Join(os.TempDir(), "vdir1") + folderName := filepath.Base(mappedPath) + extAuthScopes := []int{1, 2} + for _, authScope := range extAuthScopes { + var usePubKey bool + if authScope == 1 { + usePubKey = false + } else { + usePubKey = true + } + u := getTestUser(usePubKey) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: "/vpath", + QuotaFiles: 1 + authScope, + QuotaSize: 10 + int64(authScope), + }) + + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = authScope + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + f := vfs.BaseVirtualFolder{ + Name: 
folderName, + MappedPath: mappedPath, + } + _, _, err = httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + if !usePubKey { + dbUser, err := dataprovider.UserExists(defaultUsername, "") + assert.NoError(t, err) + found, match := dataprovider.CheckCachedUserPassword(defaultUsername, defaultPassword, dbUser.Password) + assert.True(t, found) + assert.True(t, match) + } + u.Username = defaultUsername + "1" + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err, "external auth login with invalid user must fail") { + client.Close() + conn.Close() + } + usePubKey = !usePubKey + u = getTestUser(usePubKey) + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err, "external auth login with valid user but invalid auth scope must fail") { + client.Close() + conn.Close() + } + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, user.VirtualFolders, 1) { + folder := user.VirtualFolders[0] + assert.Equal(t, folderName, folder.Name) + assert.Equal(t, mappedPath, folder.MappedPath) + assert.Equal(t, 1+authScope, folder.QuotaFiles) + assert.Equal(t, 10+int64(authScope), folder.QuotaSize) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) + } +} + +func 
TestLoginExternalAuthCache(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + u := getTestUser(false) + u.Filters.ExternalAuthCacheTime = 120 + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 1 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + conn, client, err := getSftpClient(u, false) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + lastLogin := user.LastLogin + assert.Greater(t, lastLogin, int64(0)) + assert.Equal(t, u.Filters.ExternalAuthCacheTime, user.Filters.ExternalAuthCacheTime) + // the auth should be now cached so update the hook to return an error + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, true, false, ""), os.ModePerm) + assert.NoError(t, err) + conn, client, err = getSftpClient(u, false) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, lastLogin, user.LastLogin) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = 
os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestLoginExternalAuthInteractive(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := false + u := getTestUser(usePubKey) + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 4 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + err = os.WriteFile(keyIntAuthPath, getKeyboardInteractiveScriptContent([]string{"1", "2"}, 0, false, 1), os.ModePerm) + assert.NoError(t, err) + conn, client, err := getKeyboardInteractiveSftpClient(u, []string{"1", "2"}) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + u.Username = defaultUsername + "1" + conn, client, err = getKeyboardInteractiveSftpClient(u, []string{"1", "2"}) + if !assert.Error(t, err, "external auth login with invalid user must fail") { + client.Close() + conn.Close() + } + usePubKey = true + u = getTestUser(usePubKey) + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err, "external auth login with valid user but invalid auth scope must fail") { + client.Close() + conn.Close() + } + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + 
assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestLoginExternalAuthErrors(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := true + u := getTestUser(usePubKey) + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, true, false, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + conn, client, err := getSftpClient(u, usePubKey) + if !assert.Error(t, err, "login must fail, external auth returns a non json response") { + client.Close() + conn.Close() + } + + usePubKey = false + u = getTestUser(usePubKey) + conn, client, err = getSftpClient(u, usePubKey) + if !assert.Error(t, err, "login must fail, external auth returns a non json response") { + client.Close() + conn.Close() + } + _, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestExternalAuthReturningAnonymousUser(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := false + u := getTestUser(usePubKey) + u.Filters.IsAnonymous = true + u.Password = "" + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, 
getExtAuthScriptContent(u, false, false, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + _, _, err = getSftpClient(u, usePubKey) + assert.Error(t, err) + + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.IsAnonymous) + assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"]) + assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols) + assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword, + dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate, + dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods) + + // test again, the user now exists + _, _, err = getSftpClient(u, usePubKey) + assert.Error(t, err) + updatedUser, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + user.UpdatedAt = updatedUser.UpdatedAt + user.LastPasswordChange = updatedUser.LastPasswordChange + assert.Equal(t, user, updatedUser) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestExternalAuthPreserveMFAConfig(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + 
usePubKey := false + u := getTestUser(usePubKey) + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, false, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + conn, client, err := getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + // add multi-factor authentication + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Filters.RecoveryCodes, 0) + assert.False(t, user.Filters.TOTPConfig.Enabled) + configName, key, _, err := mfa.GenerateTOTPSecret(mfa.GetAvailableTOTPConfigNames()[0], user.Username) + assert.NoError(t, err) + user.Password = defaultPassword + user.Filters.TOTPConfig = dataprovider.UserTOTPConfig{ + Enabled: true, + ConfigName: configName, + Secret: kms.NewPlainSecret(key.Secret()), + Protocols: []string{common.ProtocolSSH}, + } + for i := 0; i < 12; i++ { + user.Filters.RecoveryCodes = append(user.Filters.RecoveryCodes, dataprovider.RecoveryCode{ + Secret: kms.NewPlainSecret(fmt.Sprintf("RC-%v", strings.ToUpper(util.GenerateUniqueID()))), + }) + } + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + // login again and check that the MFA configs are preserved + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Filters.RecoveryCodes, 12) + assert.True(t, 
user.Filters.TOTPConfig.Enabled) + assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName) + assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) + + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, false, true, ""), os.ModePerm) + assert.NoError(t, err) + + conn, client, err = getSftpClient(u, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + + user, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.Len(t, user.Filters.RecoveryCodes, 12) + assert.True(t, user.Filters.TOTPConfig.Enabled) + assert.Equal(t, configName, user.Filters.TOTPConfig.ConfigName) + assert.Equal(t, []string{common.ProtocolSSH}, user.Filters.TOTPConfig.Protocols) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.Filters.TOTPConfig.Secret.GetStatus()) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestQuotaDisabledError(t *testing.T) { + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + providerConf.TrackQuota = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + usePubKey := false + u := getTestUser(usePubKey) + u.QuotaFiles = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, 
usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName+"1", testFileSize, client) + assert.NoError(t, err) + err = client.Rename(testFileName+"1", testFileName+".rename") //nolint:goconst + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +//nolint:dupl +func TestMaxConnections(t *testing.T) { + oldValue := common.Config.MaxTotalConnections + common.Config.MaxTotalConnections = 1 + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + usePubKey := true + user := getTestUser(usePubKey) + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + user.Password = "" + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + assert.NoError(t, checkBasicSFTP(client)) + s, c, err := getSftpClient(user, usePubKey) + if !assert.Error(t, err, "max total connections exceeded, new login should not succeed") { + c.Close() + s.Close() + } + err = client.Close() + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) + } + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) 
+ + common.Config.MaxTotalConnections = oldValue +} + +//nolint:dupl +func TestMaxPerHostConnections(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 1 + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + usePubKey := true + user := getTestUser(usePubKey) + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + user.Password = "" + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + assert.NoError(t, checkBasicSFTP(client)) + s, c, err := getSftpClient(user, usePubKey) + if !assert.Error(t, err, "max per host connections exceeded, new login should not succeed") { + c.Close() + s.Close() + } + err = client.Close() + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) + } + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.MaxPerHostConnections = oldValue +} + +func TestMaxTransfers(t *testing.T) { + oldValue := common.Config.MaxPerHostConnections + common.Config.MaxPerHostConnections = 2 + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + usePubKey := true + user := getTestUser(usePubKey) + err := dataprovider.AddUser(&user, "", "", "") + assert.NoError(t, err) + user.Password = "" + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + assert.NoError(t, checkBasicSFTP(client)) + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + + f1, err := client.Create("file1") + assert.NoError(t, err) + f2, err := 
client.Create("file2") + assert.NoError(t, err) + _, err = f1.Write([]byte(" ")) + assert.NoError(t, err) + _, err = f2.Write([]byte(" ")) + assert.NoError(t, err) + + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.ErrorContains(t, err, sftp.ErrSSHFxPermissionDenied.Error()) + + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.Error(t, err) + + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.ErrorContains(t, err, sftp.ErrSSHFxPermissionDenied.Error()) + + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + err = scpDownload(localDownloadPath, remoteDownPath, false, false) + assert.Error(t, err) + + err = f1.Close() + assert.NoError(t, err) + err = f2.Close() + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = client.Close() + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) + } + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + return common.Connections.GetTotalTransfers() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + common.Config.MaxPerHostConnections = oldValue +} + +func TestMaxSessions(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Username += "1" + u.MaxSessions = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + s, c, err := getSftpClient(user, usePubKey) + if !assert.Error(t, err, "max sessions 
exceeded, new login should not succeed") { + c.Close() + s.Close() + } + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSupportedExtensions(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + v, ok := client.HasExtension("statvfs@openssh.com") + assert.Equal(t, "2", v) + assert.True(t, ok) + _, ok = client.HasExtension("hardlink@openssh.com") + assert.False(t, ok) + _, ok = client.HasExtension("posix-rename@openssh.com") + assert.False(t, ok) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestQuotaFileReplace(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.QuotaFiles = 1000 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser(usePubKey) + u.QuotaFiles = 1000 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFileSize := int64(65535) + testFilePath := filepath.Join(homeBasePath, testFileName) + for _, user := range []dataprovider.User{localUser, sftpUser} { + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { //nolint:dupl + defer conn.Close() + defer client.Close() + expectedQuotaSize := testFileSize + expectedQuotaFiles := 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + // now replace the same file, the quota must not change + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) 
+ assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + // now create a symlink, replace it with a file and check the quota + // replacing a symlink is like uploading a new file + err = client.Symlink(testFileName, testFileName+".link") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + expectedQuotaFiles++ + expectedQuotaSize += testFileSize + err = sftpUploadFile(testFilePath, testFileName+".link", testFileSize, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + } + // now set a quota size restriction and upload the same file, upload should fail for space limit exceeded + user.QuotaSize = testFileSize*2 - 1 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.Error(t, err, "quota size exceeded, file upload must fail") + err = client.Remove(testFileName) + assert.NoError(t, err) + } + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + user.QuotaSize = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } 
+ _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestQuotaRename(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.QuotaFiles = 1000 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser(usePubKey) + u.QuotaFiles = 1000 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFileSize := int64(65535) + testFileSize1 := int64(65537) + testFileName1 := "test_file1.dat" //nolint:goconst + testFilePath := filepath.Join(homeBasePath, testFileName) + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + for _, user := range []dataprovider.User{localUser, sftpUser} { + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Rename(testFileName, testFileName+".rename") + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + err = client.Rename(testFileName1, testFileName+".rename") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + 
err = client.Symlink(testFileName+".rename", testFileName+".symlink") //nolint:goconst + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + // overwrite a symlink + err = client.Rename(testFileName, testFileName+".symlink") + assert.NoError(t, err) + err = client.Mkdir("testdir") + assert.NoError(t, err) + err = client.Rename("testdir", "testdir1") + assert.NoError(t, err) + err = client.Mkdir("testdir") + assert.NoError(t, err) + err = client.Rename("testdir", "testdir1") + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + testDir := "tdir" + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(testDir, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, path.Join(testDir, testFileName1), testFileSize1, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 4, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2+testFileSize1*2, user.UsedQuotaSize) + err = client.Rename(testDir, testDir+"1") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 4, user.UsedQuotaFiles) + assert.Equal(t, testFileSize*2+testFileSize1*2, user.UsedQuotaSize) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + _, err = httpdtest.RemoveUser(sftpUser, 
http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestQuotaScan(t *testing.T) { + usePubKey := false + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + testFileSize := int64(65535) + expectedQuotaSize := user.UsedQuotaSize + testFileSize + expectedQuotaFiles := user.UsedQuotaFiles + 1 + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + // create user with the same home dir, so there is at least an untracked file + user, _, err = httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + _, err = httpdtest.StartQuotaScan(user, http.StatusAccepted) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + scans, _, err := httpdtest.GetQuotaScans(http.StatusOK) + if err == nil { + return len(scans) == 0 + } + return false + }, 1*time.Second, 50*time.Millisecond) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestMultipleQuotaScans(t *testing.T) { + res := 
common.QuotaScans.AddUserQuotaScan(defaultUsername, "") + assert.True(t, res) + res = common.QuotaScans.AddUserQuotaScan(defaultUsername, "") + assert.False(t, res, "add quota must fail if another scan is already active") + assert.True(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername)) + activeScans := common.QuotaScans.GetUsersQuotaScans("") + assert.Equal(t, 0, len(activeScans)) + assert.False(t, common.QuotaScans.RemoveUserQuotaScan(defaultUsername)) +} + +func TestQuotaLimits(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.QuotaFiles = 1 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser(usePubKey) + u.QuotaFiles = 1 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFileSize := int64(65535) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + testFileSize1 := int64(131072) + testFileName1 := "test_file1.dat" + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + testFileSize2 := int64(32768) + testFileName2 := "test_file2.dat" //nolint:goconst + testFilePath2 := filepath.Join(homeBasePath, testFileName2) + err = createTestFile(testFilePath2, testFileSize2) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + // test quota files + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath, testFileName+".quota", testFileSize, client) //nolint:goconst + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName+".quota.1", testFileSize, client) + if assert.Error(t, err, "user is over quota files, upload must fail") { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), 
common.ErrQuotaExceeded.Error()) + } + // rename should work + err = client.Rename(testFileName+".quota", testFileName) + assert.NoError(t, err) + } + // test quota size + user.QuotaSize = testFileSize - 1 + user.QuotaFiles = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath, testFileName+".quota.1", testFileSize, client) + if assert.Error(t, err, "user is over quota size, upload must fail") { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + err = client.Rename(testFileName, testFileName+".quota") + assert.NoError(t, err) + err = client.Rename(testFileName+".quota", testFileName) + assert.NoError(t, err) + } + // now test quota limits while uploading the current file, we have 1 bytes remaining + user.QuotaSize = testFileSize + 1 + user.QuotaFiles = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "SSH_FX_FAILURE") + if user.Username == localUser.Username { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + } + _, err = client.Stat(testFileName1) + assert.Error(t, err) + _, err = client.Lstat(testFileName1) + assert.Error(t, err) + // overwriting an existing file will work if the resulting size is lesser or equal than the current one + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath2, testFileName, testFileSize2, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, 
testFileName, testFileSize1, client) + assert.Error(t, err) + _, err := client.Stat(testFileName) + assert.Error(t, err) + } + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + err = os.Remove(testFilePath2) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestTransferQuotaLimits(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.DownloadDataTransfer = 1 + u.UploadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(550000) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + // error while download is active + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + // error before starting the download + err = 
sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrReadQuotaExceeded.Error()) + } + // error while upload is active + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + // error before starting the upload + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + } + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.UsedDownloadDataTransfer, int64(1024*1024)) + assert.Greater(t, user.UsedUploadDataTransfer, int64(1024*1024)) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUploadMaxSize(t *testing.T) { + testFileSize := int64(65535) + usePubKey := false + u := getTestUser(usePubKey) + u.Filters.MaxUploadFileSize = testFileSize + 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + testFileSize1 := int64(131072) + testFileName1 := "test_file1.dat" + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) + assert.Error(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, 
client) + assert.NoError(t, err) + // now test overwrite an existing file with a size bigger than the allowed one + err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName1), testFileSize1) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) + assert.Error(t, err) + } + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestBandwidthAndConnections(t *testing.T) { + usePubKey := false + testFileSize := int64(524288) + u := getTestUser(usePubKey) + u.UploadBandwidth = 120 + u.DownloadBandwidth = 100 + wantedUploadElapsed := 1000 * (testFileSize / 1024) / u.UploadBandwidth + wantedDownloadElapsed := 1000 * (testFileSize / 1024) / u.DownloadBandwidth + // 100 ms tolerance + wantedUploadElapsed -= 100 + wantedDownloadElapsed -= 100 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + startTime := time.Now() + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + elapsed := time.Since(startTime).Nanoseconds() / 1000000 + assert.GreaterOrEqual(t, elapsed, wantedUploadElapsed, "upload bandwidth throttling not respected") + startTime = time.Now() + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + c := sftpDownloadNonBlocking(testFileName, localDownloadPath, testFileSize, client) + waitForActiveTransfers(t) + // wait some additional arbitrary time to wait for transfer activity to happen + // it is need to reach all the code in CheckIdleConnections + 
time.Sleep(100 * time.Millisecond) + err = <-c + assert.NoError(t, err) + elapsed = time.Since(startTime).Nanoseconds() / 1000000 + assert.GreaterOrEqual(t, elapsed, wantedDownloadElapsed, "download bandwidth throttling not respected") + // test disconnection + c = sftpUploadNonBlocking(testFilePath, testFileName+"_partial", testFileSize, client) + waitForActiveTransfers(t) + time.Sleep(100 * time.Millisecond) + + for _, stat := range common.Connections.GetStats("") { + common.Connections.Close(stat.ConnectionID, "") + } + err = <-c + assert.Error(t, err, "connection closed while uploading: the upload must fail") + assert.Eventually(t, func() bool { + return len(common.Connections.GetStats("")) == 0 + }, 10*time.Second, 200*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPatternsFilters(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFileSize := int64(131072) + testFilePath := filepath.Join(homeBasePath, testFileName) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName+".zip", testFileSize, client) + assert.NoError(t, err) + } + user.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/", + AllowedPatterns: []string{"*.zIp"}, + DeniedPatterns: []string{}, + }, + } + 
user.Filters.DisableFsChecks = true + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.Error(t, err) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err) + err = client.Rename(testFileName, testFileName+"1") + assert.Error(t, err) + err = client.Remove(testFileName) + assert.Error(t, err) + err = sftpDownloadFile(testFileName+".zip", localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = client.Mkdir("dir.zip") + assert.NoError(t, err) + err = client.Rename("dir.zip", "dir1.zip") + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestVirtualFolders(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + mappedPath := filepath.Join(os.TempDir(), "vdir") + folderName := filepath.Base(mappedPath) + vdirPath := "/vdir/subdir" + testDir := "/userDir" + testDir1 := "/userDir1" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + }) + u.Permissions[testDir] = []string{dataprovider.PermCreateDirs} + u.Permissions[testDir1] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload, dataprovider.PermRename} + u.Permissions[path.Join(testDir1, "subdir")] = []string{dataprovider.PermRename} + + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, 
http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + // check virtual folder auto creation + _, err = os.Stat(mappedPath) + assert.NoError(t, err) + testFileSize := int64(131072) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpUploadFile(testFilePath, path.Join(vdirPath, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = client.Rename(vdirPath, "new_name") + assert.Error(t, err, "renaming a virtual folder must fail") + err = client.RemoveDirectory(vdirPath) + assert.Error(t, err, "removing a virtual folder must fail") + err = client.Mkdir(vdirPath) + assert.Error(t, err, "creating a virtual folder must fail") + err = client.Symlink(path.Join(vdirPath, testFileName), vdirPath) + assert.Error(t, err, "symlink to a virtual folder must fail") + err = client.Rename("/vdir", "/vdir1") + assert.Error(t, err, "renaming a directory with a virtual folder inside must fail") + err = client.RemoveDirectory("/vdir") + assert.Error(t, err, "removing a directory with a virtual folder inside must fail") + err = client.Mkdir("vdir1") + assert.NoError(t, err) + // rename empty dir /vdir1, we have permission on / + err = client.Rename("vdir1", "vdir2") + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join("vdir2", testFileName), testFileSize, client) + assert.NoError(t, err) + // we don't have rename permission in testDir and vdir2 contains a file + err = client.Rename("vdir2", testDir) + assert.Error(t, err) + err = client.Rename("vdir2", testDir1) + assert.NoError(t, err) + err = client.Rename(testDir1, "vdir2") + assert.NoError(t, err) + 
err = client.MkdirAll(path.Join("vdir2", "subdir")) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join("vdir2", "subdir", testFileName), testFileSize, client) + assert.NoError(t, err) + err = client.Rename("vdir2", testDir1) + assert.NoError(t, err) + err = client.Rename(testDir1, "vdir2") + assert.NoError(t, err) + err = client.MkdirAll(path.Join("vdir2", "subdir", "subdir")) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join("vdir2", "subdir", "subdir", testFileName), testFileSize, client) + assert.NoError(t, err) + err = client.Rename("vdir2", testDir1) + assert.NoError(t, err) + err = client.Rename(testDir1, "vdir3") + assert.NoError(t, err) + err = client.Remove(path.Join("vdir3", "subdir", "subdir", testFileName)) + assert.NoError(t, err) + err = client.RemoveDirectory(path.Join("vdir3", "subdir", "subdir")) + assert.NoError(t, err) + err = client.Rename("vdir3", testDir1) + assert.NoError(t, err) + err = client.Rename(testDir1, "vdir2") + assert.NoError(t, err) + err = client.Symlink(path.Join("vdir2", "subdir", testFileName), path.Join("vdir2", "subdir", "alink")) + assert.NoError(t, err) + err = client.Rename("vdir2", testDir1) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestVirtualFoldersQuotaLimit(t *testing.T) { + usePubKey := false + u1 := getTestUser(usePubKey) + u1.QuotaFiles = 1 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" //nolint:goconst + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := 
filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" //nolint:goconst + u1.VirtualFolders = append(u1.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + QuotaFiles: -1, + QuotaSize: -1, + }) + u1.VirtualFolders = append(u1.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + QuotaFiles: 1, + QuotaSize: 0, + }) + testFileSize := int64(131072) + testFilePath := filepath.Join(homeBasePath, testFileName) + err := createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + u2 := getTestUser(usePubKey) + u2.QuotaSize = testFileSize + 1 + u2.VirtualFolders = append(u2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + QuotaFiles: -1, + QuotaSize: -1, + }) + u2.VirtualFolders = append(u2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + QuotaFiles: 0, + QuotaSize: testFileSize + 1, + }) + users := []dataprovider.User{u1, u2} + for _, u := range users { + err = os.MkdirAll(mappedPath1, os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(mappedPath2, os.ModePerm) + assert.NoError(t, err) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) 
+ err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.Error(t, err) + _, err = client.Stat(testFileName) + assert.Error(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName+"1"), testFileSize, client) + assert.Error(t, err) + _, err = client.Stat(path.Join(vdirPath1, testFileName+"1")) + assert.Error(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName+"1"), testFileSize, client) + assert.Error(t, err) + _, err = client.Stat(path.Join(vdirPath2, testFileName+"1")) + assert.Error(t, err) + err = client.Remove(path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) + assert.Error(t, err) + // now test renames + err = client.Rename(testFileName, path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testFileName+".rename")) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testFileName+".rename")) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath2, testFileName+".rename"), testFileName+".rename") + assert.Error(t, err) + err = client.Rename(path.Join(vdirPath2, testFileName+".rename"), path.Join(vdirPath1, testFileName)) + assert.Error(t, err) + err = client.Rename(path.Join(vdirPath1, testFileName+".rename"), path.Join(vdirPath2, testFileName)) + assert.Error(t, err) + err = client.Rename(path.Join(vdirPath1, testFileName+".rename"), testFileName) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: 
folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) + } + err = os.Remove(testFilePath) + assert.NoError(t, err) +} + +func TestSFTPLoopSimple(t *testing.T) { + usePubKey := false + user1 := getTestSFTPUser(usePubKey) + user2 := getTestSFTPUser(usePubKey) + user1.Username += "1" + user2.Username += "2" + user1.FsConfig.Provider = sdk.SFTPFilesystemProvider + user2.FsConfig.Provider = sdk.SFTPFilesystemProvider + user1.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user2.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + } + user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user1.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + } + user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + _, _, err = getSftpClient(user1, usePubKey) + assert.Error(t, err) + _, _, err = getSftpClient(user2, usePubKey) + assert.Error(t, err) + + user1.FsConfig.SFTPConfig.Username = user1.Username + user1.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) + + _, _, err = httpdtest.UpdateUser(user1, http.StatusOK, "") + assert.NoError(t, err) + _, _, err = getSftpClient(user1, usePubKey) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user2, http.StatusOK) + assert.NoError(t, err) + err = 
os.RemoveAll(user2.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSFTPLoopVirtualFolders(t *testing.T) { + usePubKey := false + sftpFloderName := "sftp" + user1 := getTestUser(usePubKey) + user2 := getTestSFTPUser(usePubKey) + user3 := getTestSFTPUser(usePubKey) + user1.Username += "1" + user2.Username += "2" + user3.Username += "3" + + // user1 is a local account with a virtual SFTP folder to user2 + // user2 has user1 as SFTP fs + user1.VirtualFolders = append(user1.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: sftpFloderName, + }, + VirtualPath: "/vdir", + }) + + user2.FsConfig.Provider = sdk.SFTPFilesystemProvider + user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user1.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + } + user3.FsConfig.Provider = sdk.SFTPFilesystemProvider + user3.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user1.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + } + f := vfs.BaseVirtualFolder{ + Name: sftpFloderName, + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user2.Username, + EqualityCheckMode: 1, + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + + user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated) + assert.NoError(t, err, string(resp)) + user3, resp, err = httpdtest.AddUser(user3, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + // login will work but /vdir will not be accessible + conn, client, err := getSftpClient(user1, usePubKey) + if 
assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + _, err = client.ReadDir("/vdir") + assert.Error(t, err) + } + // now make user2 a local account with an SFTP virtual folder to user1. + // So we have: + // user1 -> local account with the SFTP virtual folder /vdir to user2 + // user2 -> local account with the SFTP virtual folder /vdir2 to user3 + // user3 -> sftp user with user1 as fs + user2.FsConfig.Provider = sdk.LocalFilesystemProvider + user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{} + user2.VirtualFolders = append(user2.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: sftpFloderName, + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user3.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + }, + VirtualPath: "/vdir2", + }) + user2, _, err = httpdtest.UpdateUser(user2, http.StatusOK, "") + assert.NoError(t, err) + + // login will work but /vdir will not be accessible + conn, client, err = getSftpClient(user1, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + _, err = client.ReadDir("/vdir") + assert.Error(t, err) + } + + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user2, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user2.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user3, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user3.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: sftpFloderName}, http.StatusOK) + assert.NoError(t, err) +} + +func TestNestedVirtualFolders(t *testing.T) { + usePubKey := 
true + baseUser, resp, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err, string(resp)) + u := getTestSFTPUser(usePubKey) + u.QuotaFiles = 1000 + mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") + folderNameCrypt := filepath.Base(mappedPathCrypt) + vdirCryptPath := "/vdir/crypt" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + }, + VirtualPath: vdirCryptPath, + QuotaFiles: 100, + }) + mappedPath := filepath.Join(os.TempDir(), "local") + folderName := filepath.Base(mappedPath) + vdirPath := "/vdir/local" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + mappedPathNested := filepath.Join(os.TempDir(), "nested") + folderNameNested := filepath.Base(mappedPathNested) + vdirNestedPath := "/vdir/crypt/nested" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameNested, + }, + VirtualPath: vdirNestedPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + f1 := vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + MappedPath: mappedPathCrypt, + } + _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + f3 := vfs.BaseVirtualFolder{ + Name: folderNameNested, + MappedPath: mappedPathNested, + } + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + user, resp, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err, string(resp)) + conn, 
client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + expectedQuotaSize := int64(0) + expectedQuotaFiles := 0 + fileSize := int64(32765) + err = writeSFTPFile(testFileName, fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 38764 + err = writeSFTPFile(path.Join("/vdir", testFileName), fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 18769 + err = writeSFTPFile(path.Join(vdirPath, testFileName), fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 27658 + err = writeSFTPFile(path.Join(vdirNestedPath, testFileName), fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 39765 + err = writeSFTPFile(path.Join(vdirCryptPath, testFileName), fileSize, client) + assert.NoError(t, err) + + userGet, _, err := httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, userGet.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, userGet.UsedQuotaSize) + + folderGet, _, err := httpdtest.GetFolderByName(folderNameCrypt, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, folderGet.UsedQuotaSize, fileSize) + assert.Equal(t, 1, folderGet.UsedQuotaFiles) + + folderGet, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), folderGet.UsedQuotaSize) + assert.Equal(t, 0, folderGet.UsedQuotaFiles) + + folderGet, _, err = httpdtest.GetFolderByName(folderNameNested, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), folderGet.UsedQuotaSize) + assert.Equal(t, 0, folderGet.UsedQuotaFiles) + + files, err := client.ReadDir("/") + if assert.NoError(t, err) { + assert.Len(t, files, 2) + } + info, err := client.Stat("vdir") + if assert.NoError(t, err) 
{ + assert.True(t, info.IsDir()) + } + files, err = client.ReadDir("/vdir") + if assert.NoError(t, err) { + assert.Len(t, files, 3) + } + files, err = client.ReadDir(vdirCryptPath) + if assert.NoError(t, err) { + assert.Len(t, files, 2) + } + info, err = client.Stat(vdirNestedPath) + if assert.NoError(t, err) { + assert.True(t, info.IsDir()) + } + // finally add some files directly using os method and then check quota + fName := "testfile" + fileSize = 123456 + err = createTestFile(filepath.Join(baseUser.HomeDir, fName), fileSize) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 8765 + err = createTestFile(filepath.Join(mappedPath, fName), fileSize) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + fileSize = 98751 + err = createTestFile(filepath.Join(mappedPathNested, fName), fileSize) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + err = createTestFile(filepath.Join(mappedPathCrypt, fName), fileSize) + assert.NoError(t, err) + _, err = httpdtest.StartQuotaScan(user, http.StatusAccepted) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + scans, _, err := httpdtest.GetQuotaScans(http.StatusOK) + if err == nil { + return len(scans) == 0 + } + return false + }, 1*time.Second, 50*time.Millisecond) + + userGet, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, userGet.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, userGet.UsedQuotaSize) + + // the crypt folder is not included within user quota so we need to do a separate scan + _, err = httpdtest.StartFolderQuotaScan(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusAccepted) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + scans, _, err := httpdtest.GetFoldersQuotaScans(http.StatusOK) + if err == nil { + return len(scans) == 0 + } + return false + }, 1*time.Second, 50*time.Millisecond) + 
folderGet, _, err = httpdtest.GetFolderByName(folderNameCrypt, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, folderGet.UsedQuotaSize, int64(39765+98751)) + assert.Equal(t, 2, folderGet.UsedQuotaFiles) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(baseUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(baseUser.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathCrypt) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathNested) + assert.NoError(t, err) +} + +func TestBufferedUser(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.QuotaFiles = 1000 + u.FsConfig.OSConfig = sdk.OSFsConfig{ + WriteBufferSize: 2, + ReadBufferSize: 1, + } + vdirPath := "/crypted" + mappedPath := filepath.Join(os.TempDir(), util.GenerateUniqueID()) + folderName := filepath.Base(mappedPath) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + OSFsConfig: sdk.OSFsConfig{ + WriteBufferSize: 3, + ReadBufferSize: 2, + }, + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, 
err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + expectedQuotaSize := int64(0) + expectedQuotaFiles := 0 + fileSize := int64(32768) + err = writeSFTPFile(testFileName, fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + err = writeSFTPFile(path.Join(vdirPath, testFileName), fileSize, client) + assert.NoError(t, err) + expectedQuotaSize += fileSize + expectedQuotaFiles++ + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Greater(t, user.UsedQuotaSize, expectedQuotaSize) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, fileSize, client) + assert.NoError(t, err) + err = sftpDownloadFile(path.Join(vdirPath, testFileName), localDownloadPath, fileSize, client) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.NoError(t, err) + err = client.Remove(path.Join(vdirPath, testFileName)) + assert.NoError(t, err) + + data := []byte("test data") + f, err := client.OpenFile(testFileName, os.O_WRONLY|os.O_CREATE) + if assert.NoError(t, err) { + n, err := f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(2) + assert.NoError(t, err) + expectedQuotaSize := int64(2) + expectedQuotaFiles := 0 + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = f.Seek(expectedQuotaSize, io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + 
err = f.Truncate(5) + assert.NoError(t, err) + expectedQuotaSize = int64(5) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = f.Seek(expectedQuotaSize, io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Close() + assert.NoError(t, err) + expectedQuotaSize = int64(5) + int64(len(data)) + expectedQuotaFiles = 1 + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + } + // now truncate by path + err = client.Truncate(testFileName, 5) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestTruncateQuotaLimits(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.QuotaSize = 20 + mappedPath := filepath.Join(os.TempDir(), "mapped") + folderName := filepath.Base(mappedPath) + err := os.MkdirAll(mappedPath, os.ModePerm) + assert.NoError(t, err) + vdirPath := "/vmapped" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + QuotaFiles: 10, + }) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err = 
httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser(usePubKey) + u.QuotaSize = 20 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + data := []byte("test data") + f, err := client.OpenFile(testFileName, os.O_WRONLY|os.O_CREATE) + if assert.NoError(t, err) { + n, err := f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(2) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(2) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = f.Seek(expectedQuotaSize, io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(5) + assert.NoError(t, err) + expectedQuotaSize = int64(5) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + _, err = f.Seek(expectedQuotaSize, io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + expectedQuotaSize = int64(5) + int64(len(data)) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + } + // now truncate by path + 
err = client.Truncate(testFileName, 5) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + // now open an existing file without truncating it, quota should not change + f, err = client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + } + // open the file truncating it + f, err = client.OpenFile(testFileName, os.O_WRONLY|os.O_TRUNC) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + } + // now test max write size + f, err = client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + n, err := f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(11) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(11), user.UsedQuotaSize) + _, err = f.Seek(int64(11), io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(5) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + _, err = f.Seek(int64(5), io.SeekStart) + assert.NoError(t, err) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = 
f.Truncate(12) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(12), user.UsedQuotaSize) + _, err = f.Seek(int64(12), io.SeekStart) + assert.NoError(t, err) + _, err = f.Write(data) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + err = f.Close() + assert.Error(t, err) + // the file is deleted + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + } + + if user.Username == defaultUsername { + // basic test inside a virtual folder + vfileName := path.Join(vdirPath, testFileName) + f, err = client.OpenFile(vfileName, os.O_WRONLY|os.O_CREATE) + if assert.NoError(t, err) { + n, err := f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(2) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(2) + fold, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + fold, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + } + err = client.Truncate(vfileName, 1) + assert.NoError(t, err) + fold, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(1), fold.UsedQuotaSize) + assert.Equal(t, 1, fold.UsedQuotaFiles) + // cleanup + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + 
user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + user.QuotaSize = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) { + usePubKey := true + testFileSize := int64(131072) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize1 := int64(65537) + testFileName1 := "test_file1.dat" + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + err := createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + u := getTestUser(usePubKey) + u.QuotaFiles = 0 + u.QuotaSize = 0 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vdir2" + mappedPath3 := filepath.Join(os.TempDir(), "vdir3") + folderName3 := filepath.Base(mappedPath3) + vdirPath3 := "/vdir3" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + QuotaFiles: 2, + QuotaSize: 0, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + QuotaFiles: 0, + QuotaSize: testFileSize + testFileSize1 + 1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ 
+ BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName3, + }, + VirtualPath: vdirPath3, + QuotaFiles: 2, + QuotaSize: testFileSize * 2, + }) + err = os.MkdirAll(mappedPath1, os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(mappedPath2, os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(mappedPath3, os.ModePerm) + assert.NoError(t, err) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + f3 := vfs.BaseVirtualFolder{ + Name: folderName3, + MappedPath: mappedPath3, + } + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath3, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath3, testFileName+"1"), testFileSize, client) + assert.NoError(t, 
err) + err = client.Rename(testFileName, path.Join(vdirPath1, testFileName+".rename")) + assert.Error(t, err) + // we overwrite an existing file and we have unlimited size + err = client.Rename(testFileName, path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + // we have no space and we try to overwrite a bigger file with a smaller one, this should succeed + err = client.Rename(testFileName1, path.Join(vdirPath2, testFileName)) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + // we have no space and we try to overwrite a smaller file with a bigger one, this should fail + err = client.Rename(testFileName, path.Join(vdirPath2, testFileName1)) + assert.Error(t, err) + fi, err := client.Stat(path.Join(vdirPath1, testFileName1)) + if assert.NoError(t, err) { + assert.Equal(t, testFileSize1, fi.Size()) + } + // we are over quota inside vdir3: files 2/2 and size 262144/262144 + err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName1+".rename")) + assert.Error(t, err) + // we overwrite an existing file and we have enough space + err = client.Rename(path.Join(vdirPath1, testFileName1), path.Join(vdirPath3, testFileName)) + assert.NoError(t, err) + testFileName2 := "test_file2.dat" + testFilePath2 := filepath.Join(homeBasePath, testFileName2) + err = createTestFile(testFilePath2, testFileSize+testFileSize1) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath2, testFileName2, testFileSize+testFileSize1, client) + assert.NoError(t, err) + // we overwrite an existing file and we don't have enough space + err = client.Rename(testFileName2, path.Join(vdirPath3, testFileName)) + assert.Error(t, err) + err = os.Remove(testFilePath2) + assert.NoError(t, err) + // now remove a file from vdir3, create a dir with 2 files and try to rename it in 
vdir3 + // this will fail since the rename will result in 3 files inside vdir3 and quota limits only + // allow 2 total files there + err = client.Remove(path.Join(vdirPath3, testFileName+"1")) + assert.NoError(t, err) + aDir := "a dir" + err = client.Mkdir(aDir) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, path.Join(aDir, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, path.Join(aDir, testFileName1+"1"), testFileSize1, client) + assert.NoError(t, err) + err = client.Rename(aDir, path.Join(vdirPath3, aDir)) + assert.Error(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName3}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath3) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) +} + +func TestVirtualFoldersQuotaValues(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + vdirPath1 := "/vdir1" //nolint:goconst + folderName1 := filepath.Base(mappedPath1) + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + vdirPath2 := "/vdir2" //nolint:goconst + folderName2 := filepath.Base(mappedPath2) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's 
one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + err := os.MkdirAll(mappedPath1, os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(mappedPath2, os.ModePerm) + assert.NoError(t, err) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileSize := int64(131072) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + // we copy the same file two times to test quota update on file overwrite + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + expectedQuotaFiles := 2 + expectedQuotaSize := testFileSize * 2 + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 
expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + + err = client.Remove(path.Join(vdirPath1, testFileName)) + assert.NoError(t, err) + err = client.Remove(path.Join(vdirPath2, testFileName)) + assert.NoError(t, err) + + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestQuotaRenameInsideSameVirtualFolder(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + vdirPath1 := "/vdir1" + folderName1 := filepath.Base(mappedPath1) + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + vdirPath2 := "/vdir2" + folderName2 := filepath.Base(mappedPath2) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + 
BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + // quota is unlimited and excluded from user's one + QuotaFiles: 0, + QuotaSize: 0, + }) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err := httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + err = os.MkdirAll(mappedPath1, os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(mappedPath2, os.ModePerm) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileName1 := "test_file1.dat" + testFileSize := int64(131072) + testFileSize1 := int64(65535) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + dir1 := "dir1" //nolint:goconst + dir2 := "dir2" //nolint:goconst + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir1)) + assert.NoError(t, err) + err = client.Mkdir(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = 
sftpUploadFile(testFilePath1, path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, dir1, testFileName), testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + // initial files: + // - vdir1/dir1/testFileName + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + // + // rename a file inside vdir1 it is included inside user quota, so we have: + // - vdir1/dir1/testFileName.rename + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName + // - vdir2/dir2/testFileName1 + err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(vdirPath1, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // rename a file inside vdir2, it isn't included inside user quota, so we have: + // - vdir1/dir1/testFileName.rename + // - vdir1/dir2/testFileName1 + 
// - vdir2/dir1/testFileName.rename + // - vdir2/dir2/testFileName1 + err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(vdirPath2, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 2, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // rename a file inside vdir2 overwriting an existing, we now have: + // - vdir1/dir1/testFileName.rename + // - vdir1/dir2/testFileName1 + // - vdir2/dir1/testFileName.rename (initial testFileName1) + err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 2, user.UsedQuotaFiles) + assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + // rename a file inside vdir1 overwriting an existing, we now have: + // - vdir1/dir1/testFileName.rename (initial testFileName1) + // - vdir2/dir1/testFileName.rename (initial testFileName1) + err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath1, dir1, 
testFileName+".rename")) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + // rename a directory inside the same virtual folder, quota should not change + err = client.RemoveDirectory(path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.RemoveDirectory(path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath1, dir1), path.Join(vdirPath1, dir2)) + assert.NoError(t, err) + err = client.Rename(path.Join(vdirPath2, dir1), path.Join(vdirPath2, dir2)) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize1, user.UsedQuotaSize) + f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize1, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: 
// TestQuotaRenameBetweenVirtualFolder verifies quota accounting when files and
// directories are renamed between two virtual folders with different quota
// modes: vdir1 (QuotaFiles/QuotaSize = -1, counted inside the user's quota)
// and vdir2 (QuotaFiles/QuotaSize = 0, tracked separately from the user).
// After each rename the user's and both folders' used quota are re-read via
// the HTTP API and checked against the expected file layout (documented in
// the inline comments before each step).
func TestQuotaRenameBetweenVirtualFolder(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileName1 := "test_file1.dat"
		testFileSize := int64(131072)
		testFileSize1 := int64(65535)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		dir1 := "dir1"
		dir2 := "dir2"
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// initial files:
		// - vdir1/dir1/testFileName
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		//
		// rename a file from vdir1 to vdir2, vdir1 is included inside user quota, so we have:
		// - vdir1/dir1/testFileName
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		// - vdir2/dir1/testFileName1.rename
		err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(vdirPath2, dir1, testFileName1+".rename"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 3, f.UsedQuotaFiles)
		// rename a file from vdir2 to vdir1, vdir2 is not included inside user quota, so we have:
		// - vdir1/dir1/testFileName
		// - vdir1/dir2/testFileName.rename
		// - vdir2/dir2/testFileName1
		// - vdir2/dir1/testFileName1.rename
		err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(vdirPath1, dir2, testFileName+".rename"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1*2, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)
		// rename a file from vdir1 to vdir2 overwriting an existing file, vdir1 is included inside user quota, so we have:
		// - vdir1/dir2/testFileName.rename
		// - vdir2/dir2/testFileName1 (is the initial testFileName)
		// - vdir2/dir1/testFileName1.rename
		err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(vdirPath2, dir2, testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1+testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)
		// rename a file from vdir2 to vdir1 overwriting an existing file, vdir2 is not included inside user quota, so we have:
		// - vdir1/dir2/testFileName.rename (is the initial testFileName1)
		// - vdir2/dir2/testFileName1 (is the initial testFileName)
		err = client.Rename(path.Join(vdirPath2, dir1, testFileName1+".rename"), path.Join(vdirPath1, dir2, testFileName+".rename"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)

		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir2, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName+"1.dupl"), testFileSize1, client)
		assert.NoError(t, err)
		// remove the now-empty dir1 subfolders so the whole dir2 trees can be moved below
		err = client.RemoveDirectory(path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		err = client.RemoveDirectory(path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		// - vdir1/dir2/testFileName.rename (initial testFileName1)
		// - vdir1/dir2/testFileName
		// - vdir2/dir2/testFileName1 (initial testFileName)
		// - vdir2/dir2/testFileName (initial testFileName1)
		// - vdir2/dir2/testFileName1.dupl
		// rename directories between the two virtual folders
		err = client.Rename(path.Join(vdirPath2, dir2), path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 5, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1*3+testFileSize*2, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// now move on vpath2
		err = client.Rename(path.Join(vdirPath1, dir2), path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1*2+testFileSize, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
	}
	// cleanup: remove the user, both folders and every on-disk directory created above
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}
// TestQuotaRenameFromVirtualFolder verifies quota accounting when files and
// directories are renamed OUT of virtual folders into the user's home
// directory. vdir1's quota is counted inside the user's quota
// (QuotaFiles/QuotaSize = -1) while vdir2's is tracked separately
// (QuotaFiles/QuotaSize = 0), so moves from vdir2 must increase the user's
// used quota while moves from vdir1 must leave it unchanged. The inline
// comments before each step document the expected file layout.
func TestQuotaRenameFromVirtualFolder(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileName1 := "test_file1.dat"
		testFileSize := int64(131072)
		testFileSize1 := int64(65535)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		dir1 := "dir1"
		dir2 := "dir2"
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir2, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// initial files:
		// - vdir1/dir1/testFileName
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		//
		// rename a file from vdir1 to the user home dir, vdir1 is included in user quota so we have:
		// - testFileName
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		// - vdir2/dir2/testFileName1
		err = client.Rename(path.Join(vdirPath1, dir1, testFileName), path.Join(testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)
		// rename a file from vdir2 to the user home dir, vdir2 is not included in user quota so we have:
		// - testFileName
		// - testFileName1
		// - vdir1/dir2/testFileName1
		// - vdir2/dir1/testFileName
		err = client.Rename(path.Join(vdirPath2, dir2, testFileName1), path.Join(testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a file from vdir1 to the user home dir overwriting an existing file, vdir1 is included in user quota so we have:
		// - testFileName (initial testFileName1)
		// - testFileName1
		// - vdir2/dir1/testFileName
		err = client.Rename(path.Join(vdirPath1, dir2, testFileName1), path.Join(testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a file from vdir2 to the user home dir overwriting an existing file, vdir2 is not included in user quota so we have:
		// - testFileName (initial testFileName1)
		// - testFileName1 (initial testFileName)
		err = client.Rename(path.Join(vdirPath2, dir1, testFileName), path.Join(testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// dir rename
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// - testFileName (initial testFileName1)
		// - testFileName1 (initial testFileName)
		// - vdir1/dir1/testFileName
		// - vdir1/dir1/testFileName1
		// - dir1/testFileName
		// - dir1/testFileName1
		err = client.Rename(path.Join(vdirPath2, dir1), dir1)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 6, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// - testFileName (initial testFileName1)
		// - testFileName1 (initial testFileName)
		// - dir2/testFileName
		// - dir2/testFileName1
		// - dir1/testFileName
		// - dir1/testFileName1
		err = client.Rename(path.Join(vdirPath1, dir1), dir2)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 6, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*3+testFileSize1*3, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
	}
	// cleanup: remove the user, both folders and every on-disk directory created above
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}
// TestQuotaRenameToVirtualFolder verifies quota accounting when files and
// directories are renamed FROM the user's home directory INTO virtual
// folders. vdir1's quota is counted inside the user's quota
// (QuotaFiles/QuotaSize = -1) while vdir2's is tracked separately
// (QuotaFiles/QuotaSize = 0), so a move into vdir2 must decrease the user's
// used quota and charge the folder instead, while a move into vdir1 leaves
// the user's quota unchanged. The inline comments before each step document
// the expected file layout.
func TestQuotaRenameToVirtualFolder(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileName1 := "test_file1.dat"
		testFileSize := int64(131072)
		testFileSize1 := int64(65535)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		dir1 := "dir1"
		dir2 := "dir2"
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, dir2))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, dir2))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		// initial files:
		// - testFileName
		// - testFileName1
		//
		// rename a file from user home dir to vdir1, vdir1 is included in user quota so we have:
		// - testFileName
		// - /vdir1/dir1/testFileName1
		err = client.Rename(testFileName1, path.Join(vdirPath1, dir1, testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		// rename a file from user home dir to vdir2, vdir2 is not included in user quota so we have:
		// - /vdir2/dir1/testFileName
		// - /vdir1/dir1/testFileName1
		err = client.Rename(testFileName, path.Join(vdirPath2, dir1, testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// upload two new files to the user home dir so we have:
		// - testFileName
		// - testFileName1
		// - /vdir1/dir1/testFileName1
		// - /vdir2/dir1/testFileName
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1+testFileSize1, user.UsedQuotaSize)
		// rename a file from user home dir to vdir1 overwriting an existing file, vdir1 is included in user quota so we have:
		// - testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName
		err = client.Rename(testFileName, path.Join(vdirPath1, dir1, testFileName1))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 2, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// rename a file from user home dir to vdir2 overwriting an existing file, vdir2 is not included in user quota so we have:
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		err = client.Rename(testFileName1, path.Join(vdirPath2, dir1, testFileName))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)

		err = client.Mkdir(dir1)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// - /dir1/testFileName
		// - /dir1/testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		// - /vdir1/adir/testFileName
		// - /vdir1/adir/testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		err = client.Rename(dir1, path.Join(vdirPath1, "adir"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 1, f.UsedQuotaFiles)
		err = client.Mkdir(dir1)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(dir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(dir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// - /vdir1/adir/testFileName
		// - /vdir1/adir/testFileName1
		// - /vdir1/dir1/testFileName1 (initial testFileName)
		// - /vdir2/dir1/testFileName (initial testFileName1)
		// - /vdir2/adir/testFileName
		// - /vdir2/adir/testFileName1
		err = client.Rename(dir1, path.Join(vdirPath2, "adir"))
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 3, user.UsedQuotaFiles)
		assert.Equal(t, testFileSize*2+testFileSize1, user.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize1*2+testFileSize, f.UsedQuotaSize)
		assert.Equal(t, 3, f.UsedQuotaFiles)

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
	}
	// cleanup: remove the user, both folders and every on-disk directory created above
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}
// TestVirtualFoldersLink verifies symlink handling with virtual folders:
// symlinks whose source and target live inside the same filesystem (the
// user's home dir, or a single virtual folder) are accepted, while symlinks
// crossing a virtual folder boundary (home <-> vdir, or vdir1 <-> vdir2)
// are rejected.
func TestVirtualFoldersLink(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		// quota is included in the user's one
		QuotaFiles: -1,
		QuotaSize:  -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		// quota is unlimited and excluded from user's one
		QuotaFiles: 0,
		QuotaSize:  0,
	})
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(131072)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testDir := "adir"
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		// the same file is uploaded to the home dir and inside both virtual folders
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, testDir))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		// symlinks inside the same filesystem (home dir or a single virtual folder) must succeed
		err = client.Symlink(testFileName, testFileName+".link")
		assert.NoError(t, err)
		err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testFileName+".link"))
		assert.NoError(t, err)
		err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath1, testDir, testFileName+".link"))
		assert.NoError(t, err)
		err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testFileName+".link"))
		assert.NoError(t, err)
		err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath2, testDir, testFileName+".link"))
		assert.NoError(t, err)
		// symlinks crossing a virtual folder boundary must be rejected
		err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath1, testFileName+".link1")) //nolint:goconst
		assert.Error(t, err)
		err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath1, testDir, testFileName+".link1"))
		assert.Error(t, err)
		err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath2, testFileName+".link1"))
		assert.Error(t, err)
		err = client.Symlink(path.Join("/", testFileName), path.Join(vdirPath2, testDir, testFileName+".link1"))
		assert.Error(t, err)
		err = client.Symlink(path.Join(vdirPath1, testFileName), testFileName+".link1")
		assert.Error(t, err)
		err = client.Symlink(path.Join(vdirPath2, testFileName), testFileName+".link1")
		assert.Error(t, err)
		err = client.Symlink(path.Join(vdirPath1, testFileName), path.Join(vdirPath2, testDir, testFileName+".link1"))
		assert.Error(t, err)
		err = client.Symlink(path.Join(vdirPath2, testFileName), path.Join(vdirPath1, testFileName+".link1"))
		assert.Error(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	// cleanup: remove the user, both folders and every on-disk directory created above
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}
expectedQuotaFiles := 1 + folder, _, err := httpdtest.AddFolder(vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + }, http.StatusCreated) + assert.NoError(t, err) + _, err = httpdtest.StartFolderQuotaScan(folder, http.StatusAccepted) + assert.NoError(t, err) + assert.Eventually(t, func() bool { + scans, _, err := httpdtest.GetFoldersQuotaScans(http.StatusOK) + if err == nil { + return len(scans) == 0 + } + return false + }, 1*time.Second, 50*time.Millisecond) + folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, folder.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, folder.UsedQuotaSize) + _, err = httpdtest.RemoveFolder(folder, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestVFolderMultipleQuotaScan(t *testing.T) { + folderName := "folder_name" + res := common.QuotaScans.AddVFolderQuotaScan(folderName) + assert.True(t, res) + res = common.QuotaScans.AddVFolderQuotaScan(folderName) + assert.False(t, res) + res = common.QuotaScans.RemoveVFolderQuotaScan(folderName) + assert.True(t, res) + activeScans := common.QuotaScans.GetVFoldersQuotaScans() + assert.Len(t, activeScans, 0) + res = common.QuotaScans.RemoveVFolderQuotaScan(folderName) + assert.False(t, res) +} + +func TestVFolderQuotaSize(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + testFileSize := int64(131072) + u.QuotaFiles = 1 + u.QuotaSize = testFileSize + 1 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vpath1" + mappedPath2 := filepath.Join(os.TempDir(), "vdir2") + folderName2 := filepath.Base(mappedPath2) + vdirPath2 := "/vpath2" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName1, + }, + VirtualPath: vdirPath1, + // quota is included in the user's one + QuotaFiles: -1, + 
QuotaSize: -1, + }) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + }, + VirtualPath: vdirPath2, + QuotaFiles: 1, + QuotaSize: testFileSize * 2, + }) + err := os.MkdirAll(mappedPath1, os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(mappedPath2, os.ModePerm) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + f1 := vfs.BaseVirtualFolder{ + Name: folderName1, + MappedPath: mappedPath1, + } + _, _, err = httpdtest.AddFolder(f1, http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + // vdir1 is included in the user quota so upload must fail + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) + assert.Error(t, err) + // upload to vdir2 must work, it has its own quota + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName), testFileSize, client) + assert.NoError(t, err) + // now vdir2 is over quota + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName+".quota"), testFileSize, client) + assert.Error(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + // remove a file + err = client.Remove(testFileName) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, 
http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + // upload to vdir1 must work now + err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testFileName), testFileSize, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + + f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), f.UsedQuotaSize) + assert.Equal(t, 0, f.UsedQuotaFiles) + f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, f.UsedQuotaSize) + assert.Equal(t, 1, f.UsedQuotaFiles) + } + // now create another user with the same shared folder but a different quota limit + u.Username = defaultUsername + "1" + u.VirtualFolders = nil + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName2, + MappedPath: mappedPath2, + }, + VirtualPath: vdirPath2, + QuotaFiles: 10, + QuotaSize: testFileSize*2 + 1, + }) + user1, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err = getSftpClient(user1, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName+".quota"), testFileSize, client) + assert.NoError(t, err) + // the folder is now over quota for size but not for files + err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testFileName+".quota1"), testFileSize, client) + assert.Error(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: 
folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestMissingFile(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile("missing_file", localDownloadPath, 0, client) + assert.Error(t, err, "download missing file must fail") + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestOpenError(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + usePubKey := false + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = os.Chmod(user.GetHomeDir(), 0001) + assert.NoError(t, err) + _, err = client.ReadDir(".") + assert.Error(t, err, "read dir must fail if we have no filesystem read permissions") + err = os.Chmod(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + testFileSize := int64(65535) + testFilePath := filepath.Join(user.GetHomeDir(), testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + _, err = client.Stat(testFileName) + 
assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = os.Chmod(testFilePath, 0001) + assert.NoError(t, err) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err, "file download must fail if we have no filesystem read permissions") + err = sftpUploadFile(localDownloadPath, testFileName, testFileSize, client) + assert.Error(t, err, "upload must fail if we have no filesystem write permissions") + testDir := "test" + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = createTestFile(filepath.Join(user.GetHomeDir(), testDir, testFileName), testFileSize) + assert.NoError(t, err) + err = os.Chmod(user.GetHomeDir(), 0000) + assert.NoError(t, err) + _, err = client.Lstat(testFileName) + assert.Error(t, err, "file stat must fail if we have no filesystem read permissions") + err = sftpUploadFile(localDownloadPath, path.Join(testDir, testFileName), testFileSize, client) + assert.ErrorIs(t, err, os.ErrPermission) + _, err = client.ReadLink(testFileName) + assert.ErrorIs(t, err, os.ErrPermission) + err = client.Remove(testFileName) + assert.ErrorIs(t, err, os.ErrPermission) + err = os.Chmod(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir), 0000) + assert.NoError(t, err) + err = client.Rename(testFileName, path.Join(testDir, testFileName)) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir), os.ModePerm) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestOverwriteDirWithFile(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + user, _, err := 
httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileSize := int64(65535) + testDirName := "test_dir" //nolint:goconst + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = client.Mkdir(testDirName) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testDirName, testFileSize, client) + assert.Error(t, err, "copying a file over an existing dir must fail") + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Rename(testFileName, testDirName) + assert.Error(t, err, "rename a file over an existing dir must fail") + err = client.RemoveDirectory(testDirName) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestHashedPasswords(t *testing.T) { + usePubKey := false + plainPwd := "password" + pwdMapping := make(map[string]string) + pwdMapping["$argon2id$v=19$m=65536,t=3,p=2$xtcO/oRkC8O2Tn+mryl2mw$O7bn24f2kuSGRMi9s5Cm61Wqd810px1jDsAasrGWkzQ"] = plainPwd + pwdMapping["$pbkdf2-sha1$150000$DveVjgYUD05R$X6ydQZdyMeOvpgND2nqGR/0GGic="] = plainPwd + pwdMapping["$pbkdf2-sha256$150000$E86a9YMX3zC7$R5J62hsSq+pYw00hLLPKBbcGXmq7fj5+/M0IFoYtZbo="] = plainPwd + pwdMapping["$pbkdf2-sha512$150000$dsu7T5R3IaVQ$1hFXPO1ntRBcoWkSLKw+s4sAP09Xtu4Ya7CyxFq64jM9zdUg8eRJVr3NcR2vQgb0W9HHvZaILHsL4Q/Vr6arCg=="] = plainPwd + pwdMapping["$1$b5caebda$VODr/nyhGWgZaY8sJ4x05."] = plainPwd + pwdMapping["$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK"] = "secret" + 
pwdMapping["$6$459ead56b72e44bc$uog86fUxscjt28BZxqFBE2pp2QD8P/1e98MNF75Z9xJfQvOckZnQ/1YJqiq1XeytPuDieHZvDAMoP7352ELkO1"] = "secret" + pwdMapping["$5$h4Aalt0fJdGX8sgv$Rd2ew0fvgzUN.DzAVlKa9QL4q/DZWo4SsKpB9.3AyZ/"] = plainPwd + pwdMapping["$apr1$OBWLeSme$WoJbB736e7kKxMBIAqilb1"] = plainPwd + pwdMapping["{MD5}5f4dcc3b5aa765d61d8327deb882cf99"] = plainPwd + pwdMapping["{SHA256}5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8"] = plainPwd + pwdMapping["{SHA512}b109f3bbbc244eb82441917ed06d618b9008dd09b3befd1b5e07394c706a8bb980b1d7785e5976ec049b46df5f1326af5a2ea6d103fd07c95385ffab0cacbc86"] = plainPwd + for pwd, clearPwd := range pwdMapping { + u := getTestUser(usePubKey) + u.Password = pwd + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user.Password = "" + userGetInitial, _, err := httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + user, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + assert.Equal(t, pwd, user.Password) + user.Password = clearPwd + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err, "unable to login with password %q", pwd) { + assert.NoError(t, checkBasicSFTP(client)) + conn.Close() + client.Close() + } + user.Password = pwd + conn, client, err = getSftpClient(user, usePubKey) + if !assert.Error(t, err, "login with wrong password must fail") { + client.Close() + conn.Close() + } + // the password must converted to bcrypt and we should still be able to login + user, err = dataprovider.UserExists(user.Username, "") + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(user.Password, "$2a$")) + // update the user to invalidate the cached password and force a new check + user.Password = "" + userGet, _, err := httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + userGetInitial.LastLogin = userGet.LastLogin + userGetInitial.UpdatedAt = userGet.UpdatedAt + assert.Equal(t, userGetInitial, userGet) + // 
login should still work + user.Password = clearPwd + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err, "unable to login with password %q", pwd) { + assert.NoError(t, checkBasicSFTP(client)) + conn.Close() + client.Close() + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + } +} + +func TestPasswordsHashPbkdf2Sha256_389DS(t *testing.T) { + pbkdf389dsPwd := "{PBKDF2_SHA256}AAAIAMZIKG4ie44zJY4HOXI+upFR74PzWLUQV63jg+zzkbEjCK3N4qW583WF7EdcpeoOMQ4HY3aWEXB6lnXhXJixbJkU4vVSJkL6YCbU3TrD0qn1uUUVSkaIgAOtmZENitwbhYhiWfEzGyAtFqkFd75P5xhWJEog9XhQKYrR0f7S3WGGZq03JRcLJ460xpU97bE/sWRn7sshgkWzLuyrs0I+XRKmK7FJeaA9zd+1m44Y3IVmZ2YLdKATzjRHAIgpBC6i1TWOcpKJT1+feP1C9hrxH8vU9baw9thNiO8jSHaZlwb//KpJFe0ahVnG/1ubiG8cO0+CCqDqXVJR6Vr4QZxHP+4pwooW+4TP/L+HFdyA1y6z4gKfqYnBsmb3sD1R1TbxfH4btTdvgZAnBk9CmR3QASkFXxeTYsrmNd5+9IAHc6dm" + pbkdf389dsPwd = pbkdf389dsPwd[15:] + hashBytes, err := base64.StdEncoding.DecodeString(pbkdf389dsPwd) + assert.NoError(t, err) + iterBytes := hashBytes[0:4] + var iterations int32 + err = binary.Read(bytes.NewBuffer(iterBytes), binary.BigEndian, &iterations) + assert.NoError(t, err) + salt := hashBytes[4:68] + targetKey := hashBytes[68:] + key := base64.StdEncoding.EncodeToString(targetKey) + pbkdf2Pwd := fmt.Sprintf("$pbkdf2-b64salt-sha256$%v$%v$%v", iterations, base64.StdEncoding.EncodeToString(salt), key) + pbkdf2ClearPwd := "password" + usePubKey := false + u := getTestUser(usePubKey) + u.Password = pbkdf2Pwd + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + user.Password = pbkdf2ClearPwd + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + assert.NoError(t, checkBasicSFTP(client)) + } + user.Password = pbkdf2Pwd + conn, client, err = getSftpClient(user, usePubKey) + if !assert.Error(t, err, "login with wrong password must fail") { 
+ client.Close() + conn.Close() + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermList(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, + dataprovider.PermChown, dataprovider.PermChtimes} + u.Permissions["/sub"] = []string{dataprovider.PermCreateSymlinks, dataprovider.PermListItems} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + _, err = client.ReadDir(".") + assert.ErrorIs(t, err, os.ErrPermission, "read remote dir without permission should not succeed") + _, err = client.Stat("test_file") + assert.ErrorIs(t, err, os.ErrPermission, "stat remote file without permission should not succeed") + _, err = client.Lstat("test_file") + assert.ErrorIs(t, err, os.ErrPermission, "lstat remote file without permission should not succeed") + _, err = client.ReadLink("test_link") + assert.ErrorIs(t, err, os.ErrPermission, "read remote link without permission on source dir should not succeed") + _, err = client.RealPath(".") + assert.ErrorIs(t, err, os.ErrPermission, "real path without permission should not succeed") + f, err := client.Create(testFileName) + if assert.NoError(t, err) { + _, err = f.Write([]byte("content")) + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + } + err = client.Mkdir("sub") + assert.NoError(t, err) + err = client.Symlink(path.Join("/", testFileName), path.Join("/sub", testFileName)) + assert.NoError(t, err) + _, err = client.ReadLink(path.Join("/sub", testFileName)) + assert.Error(t, err, "read 
remote link without permission on targe dir should not succeed") + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermDownload(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload, dataprovider.PermDelete, dataprovider.PermRename, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, + dataprovider.PermChown, dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err, "file download without permission should not succeed") + err = client.Remove(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermUpload(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermDelete, dataprovider.PermRename, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, + 
dataprovider.PermChown, dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.Error(t, err, "file upload without permission should not succeed") + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermOverwrite(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermChmod, + dataprovider.PermChown, dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.Error(t, err, "file overwrite without permission should not succeed") + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + 
assert.NoError(t, err) +} + +func TestPermDelete(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermRename, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, + dataprovider.PermChown, dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Remove(testFileName) + assert.Error(t, err, "delete without permission should not succeed") + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +//nolint:dupl +func TestPermRename(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, + dataprovider.PermChown, dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = 
sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Rename(testFileName, testFileName+".rename") + assert.True(t, errors.Is(err, fs.ErrPermission)) + _, err = client.Stat(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +//nolint:dupl +func TestPermRenameOverwrite(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermChmod, dataprovider.PermRename, + dataprovider.PermChown, dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Rename(testFileName, testFileName+".rename") + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Rename(testFileName, testFileName+".rename") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Remove(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermCreateDirs(t *testing.T) { + usePubKey := false + u := 
getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermRename, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, dataprovider.PermChmod, + dataprovider.PermChown, dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = client.Mkdir("testdir") + assert.Error(t, err, "mkdir without permission should not succeed") + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +//nolint:dupl +func TestPermSymlink(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermOverwrite, dataprovider.PermChmod, dataprovider.PermChown, + dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Symlink(testFilePath, testFilePath+".symlink") + assert.Error(t, err, "symlink without permission should not succeed") + err = client.Remove(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, 
err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermChmod(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, + dataprovider.PermChown, dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Chmod(testFileName, os.ModePerm) + assert.Error(t, err, "chmod without permission should not succeed") + err = client.Remove(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +//nolint:dupl +func TestPermChown(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, + dataprovider.PermChmod, dataprovider.PermChtimes} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, 
testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Chown(testFileName, os.Getuid(), os.Getgid()) + assert.Error(t, err, "chown without permission should not succeed") + err = client.Remove(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +//nolint:dupl +func TestPermChtimes(t *testing.T) { + usePubKey := false + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermDelete, + dataprovider.PermRename, dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermOverwrite, + dataprovider.PermChmod, dataprovider.PermChown} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Chtimes(testFileName, time.Now(), time.Now()) + assert.Error(t, err, "chtimes without permission should not succeed") + err = client.Remove(testFileName) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSubDirsUploads(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) 
+ u.Permissions["/"] = []string{dataprovider.PermAny} + u.Permissions["/subdir"] = []string{dataprovider.PermChtimes, dataprovider.PermDownload, dataprovider.PermOverwrite} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = client.Mkdir("subdir") + assert.NoError(t, err) + testFileNameSub := "/subdir/test_file_dat" + testSubFile := filepath.Join(user.GetHomeDir(), "subdir", "file.dat") + testDir := "testdir" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = createTestFile(testSubFile, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileNameSub, testFileSize, client) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Symlink(testFileName, testFileNameSub+".link") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Symlink(testFileName, testFileName+".link") + assert.NoError(t, err) + err = client.Rename(testFileName, testFileNameSub+".rename") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Rename(testFileName, testFileName+".rename") + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + // rename overwriting an existing file + err = client.Rename(testFileName, testFileName+".rename") + assert.NoError(t, err) + // now try to overwrite a directory + err = client.Mkdir(testDir) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Rename(testFileName, testDir) + assert.Error(t, err) + err = client.Remove(testFileName) + assert.NoError(t, err) + err = 
client.Remove(testDir) + assert.NoError(t, err) + err = client.Remove(path.Join("/subdir", "file.dat")) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Remove(testFileName + ".rename") + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSubDirsOverwrite(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermAny} + u.Permissions["/subdir"] = []string{dataprovider.PermOverwrite, dataprovider.PermListItems} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFileName := "/subdir/test_file.dat" //nolint:goconst + testFilePath := filepath.Join(homeBasePath, "test_file.dat") + testFileSFTPPath := filepath.Join(u.GetHomeDir(), "subdir", "test_file.dat") + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = createTestFile(testFileSFTPPath, 16384) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName+".new", testFileSize, client) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSubDirsDownloads(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermAny} + u.Permissions["/subdir"] = []string{dataprovider.PermChmod, dataprovider.PermUpload, dataprovider.PermListItems} + user, _, err := 
httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = client.Mkdir("subdir") + assert.NoError(t, err) + testFileName := "/subdir/test_file.dat" + testFilePath := filepath.Join(homeBasePath, "test_file.dat") + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = sftpDownloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Chtimes(testFileName, time.Now(), time.Now()) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Rename(testFileName, testFileName+".rename") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Symlink(testFileName, testFileName+".link") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Remove(testFileName) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermsSubDirsSetstat(t *testing.T) { + // for setstat we check the parent dir permission if the requested path is a dir + // otherwise the path permission + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermCreateDirs} + u.Permissions["/subdir"] = []string{dataprovider.PermAny} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + 
assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = client.Mkdir("subdir") + assert.NoError(t, err) + testFileName := "/subdir/test_file.dat" + testFilePath := filepath.Join(homeBasePath, "test_file.dat") + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = client.Chtimes("/subdir/", time.Now(), time.Now()) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Chtimes("subdir/", time.Now(), time.Now()) + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Chtimes(testFileName, time.Now(), time.Now()) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestOpenUnhandledChannel(t *testing.T) { + u := getTestUser(false) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)}, + Timeout: 5 * time.Second, + } + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if assert.NoError(t, err) { + _, _, err = conn.OpenChannel("unhandled", nil) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unknown channel type") + } + err = conn.Close() + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestAlgorithmNotNegotiated(t *testing.T) { + u := getTestUser(false) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + config := &ssh.ClientConfig{ + 
Config: ssh.Config{ + Ciphers: []string{ssh.InsecureCipherRC4}, + }, + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Auth: []ssh.AuthMethod{ssh.Password(defaultPassword)}, + Timeout: 5 * time.Second, + } + _, err = ssh.Dial("tcp", sftpServerAddr, config) + if assert.Error(t, err) { + negotiationErr := &ssh.AlgorithmNegotiationError{} + assert.ErrorAs(t, err, &negotiationErr) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestPermsSubDirsCommands(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermAny} + u.Permissions["/subdir"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs} + u.Permissions["/subdir/otherdir"] = []string{dataprovider.PermListItems, dataprovider.PermDownload} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = client.Mkdir("subdir") + assert.NoError(t, err) + acmodTime := time.Now() + err = client.Chtimes("/subdir", acmodTime, acmodTime) + assert.NoError(t, err) + _, err = client.Stat("/subdir") + assert.NoError(t, err) + _, err = client.ReadDir("/") + assert.NoError(t, err) + _, err = client.ReadDir("/subdir") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.RemoveDirectory("/subdir/dir") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Mkdir("/subdir/otherdir/dir") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Mkdir("/otherdir") + assert.NoError(t, err) + err = client.Mkdir("/subdir/otherdir") + assert.NoError(t, err) + err = client.Rename("/otherdir", "/subdir/otherdir/adir") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Symlink("/otherdir", "/subdir/otherdir") + 
assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Symlink("/otherdir", "/otherdir_link") + assert.NoError(t, err) + err = client.Rename("/otherdir", "/otherdir1") + assert.NoError(t, err) + err = client.RemoveDirectory("/otherdir1") + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestRootDirCommands(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermAny} + u.Permissions["/subdir"] = []string{dataprovider.PermDownload, dataprovider.PermUpload} + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermAny} + u.Permissions["/subdir"] = []string{dataprovider.PermDownload, dataprovider.PermUpload} + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + err = client.Rename("/", "rootdir") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.Symlink("/", "rootdir") + assert.True(t, errors.Is(err, fs.ErrPermission)) + err = client.RemoveDirectory("/") + assert.True(t, errors.Is(err, fs.ErrPermission)) + } + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) 
+ assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestRelativePaths(t *testing.T) { + user := getTestUser(true) + var path, rel string + filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), "", nil)} + keyPrefix := strings.TrimPrefix(user.GetHomeDir(), "/") + "/" + s3config := vfs.S3FsConfig{ + BaseS3FsConfig: sdk.BaseS3FsConfig{ + KeyPrefix: keyPrefix, + }, + } + s3fs, _ := vfs.NewS3Fs("", user.GetHomeDir(), "", s3config) + gcsConfig := vfs.GCSFsConfig{ + BaseGCSFsConfig: sdk.BaseGCSFsConfig{ + KeyPrefix: keyPrefix, + }, + } + gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), "", gcsConfig) + sftpconfig := vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: defaultUsername, + Prefix: keyPrefix, + }, + Password: kms.NewPlainSecret(defaultPassword), + } + sftpfs, _ := vfs.NewSFTPFs("", "", os.TempDir(), []string{user.Username}, sftpconfig) + if runtime.GOOS != osWindows { + filesystems = append(filesystems, s3fs, gcsfs, sftpfs) + } + rootPath := "/" + for _, fs := range filesystems { + path = filepath.Join(user.HomeDir, "/") + rel = fs.GetRelativePath(path) + assert.Equal(t, rootPath, rel) + path = filepath.Join(user.HomeDir, "//") + rel = fs.GetRelativePath(path) + assert.Equal(t, rootPath, rel) + path = filepath.Join(user.HomeDir, "../..") + rel = fs.GetRelativePath(path) + assert.Equal(t, rootPath, rel) + path = filepath.Join(user.HomeDir, "../../../../../") + rel = fs.GetRelativePath(path) + assert.Equal(t, rootPath, rel) + path = filepath.Join(user.HomeDir, "/..") + rel = fs.GetRelativePath(path) + assert.Equal(t, rootPath, rel) + path = filepath.Join(user.HomeDir, "/../../../..") + rel = fs.GetRelativePath(path) + assert.Equal(t, rootPath, rel) + path = filepath.Join(user.HomeDir, "") + rel = fs.GetRelativePath(path) + assert.Equal(t, rootPath, rel) + path = 
filepath.Join(user.HomeDir, ".") + rel = fs.GetRelativePath(path) + assert.Equal(t, rootPath, rel) + path = filepath.Join(user.HomeDir, "somedir") + rel = fs.GetRelativePath(path) + assert.Equal(t, "/somedir", rel) + path = filepath.Join(user.HomeDir, "/somedir/subdir") + rel = fs.GetRelativePath(path) + assert.Equal(t, "/somedir/subdir", rel) + } +} + +func TestResolvePaths(t *testing.T) { + user := getTestUser(true) + var path, resolved string + var err error + filesystems := []vfs.Fs{vfs.NewOsFs("", user.GetHomeDir(), "", nil)} + keyPrefix := strings.TrimPrefix(user.GetHomeDir(), "/") + "/" + s3config := vfs.S3FsConfig{ + BaseS3FsConfig: sdk.BaseS3FsConfig{ + KeyPrefix: keyPrefix, + Bucket: "bucket", + Region: "us-east-1", + }, + } + err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + s3fs, err := vfs.NewS3Fs("", user.GetHomeDir(), "", s3config) + assert.NoError(t, err) + gcsConfig := vfs.GCSFsConfig{ + BaseGCSFsConfig: sdk.BaseGCSFsConfig{ + KeyPrefix: keyPrefix, + }, + } + gcsfs, _ := vfs.NewGCSFs("", user.GetHomeDir(), "", gcsConfig) + if runtime.GOOS != osWindows { + filesystems = append(filesystems, s3fs, gcsfs) + } + for _, fs := range filesystems { + path = "/" + resolved, _ = fs.ResolvePath(filepath.ToSlash(path)) + assert.Equal(t, fs.Join(user.GetHomeDir(), "/"), resolved) + path = "." 
+ resolved, _ = fs.ResolvePath(filepath.ToSlash(path)) + assert.Equal(t, fs.Join(user.GetHomeDir(), "/"), resolved) + path = "test/sub" + resolved, _ = fs.ResolvePath(filepath.ToSlash(path)) + assert.Equal(t, fs.Join(user.GetHomeDir(), "/test/sub"), resolved) + path = "../test/sub" + resolved, err = fs.ResolvePath(filepath.ToSlash(path)) + assert.NoError(t, err) + assert.Equal(t, fs.Join(user.GetHomeDir(), "/test/sub"), resolved) + path = "../../../test/../sub" + resolved, err = fs.ResolvePath(filepath.ToSlash(path)) + assert.NoError(t, err) + assert.Equal(t, fs.Join(user.GetHomeDir(), "/sub"), resolved) + } + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestVirtualRelativePaths(t *testing.T) { + user := getTestUser(true) + mappedPath := filepath.Join(os.TempDir(), "mdir") + vdirPath := "/vdir" //nolint:goconst + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath, + }, + VirtualPath: vdirPath, + }) + err := os.MkdirAll(mappedPath, os.ModePerm) + assert.NoError(t, err) + fsRoot := vfs.NewOsFs("", user.GetHomeDir(), "", nil) + fsVdir := vfs.NewOsFs("", mappedPath, vdirPath, nil) + rel := fsVdir.GetRelativePath(mappedPath) + assert.Equal(t, vdirPath, rel) + rel = fsRoot.GetRelativePath(filepath.Join(mappedPath, "..")) + assert.Equal(t, "/", rel) + // path outside home and virtual dir + rel = fsRoot.GetRelativePath(filepath.Join(mappedPath, "../vdir1")) + assert.Equal(t, "/", rel) + rel = fsVdir.GetRelativePath(filepath.Join(mappedPath, "../vdir1")) + assert.Equal(t, "/vdir", rel) + rel = fsVdir.GetRelativePath(filepath.Join(mappedPath, "file.txt")) + assert.Equal(t, "/vdir/file.txt", rel) + rel = fsRoot.GetRelativePath(filepath.Join(user.HomeDir, "vdir1/file.txt")) + assert.Equal(t, "/vdir1/file.txt", rel) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestUserPerms(t *testing.T) { + user := getTestUser(true) + user.Permissions = 
make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermListItems} + user.Permissions["/p"] = []string{dataprovider.PermDelete} + user.Permissions["/p/1"] = []string{dataprovider.PermDownload, dataprovider.PermUpload} + user.Permissions["/p/2"] = []string{dataprovider.PermCreateDirs} + user.Permissions["/p/3"] = []string{dataprovider.PermChmod} + user.Permissions["/p/3/4"] = []string{dataprovider.PermChtimes} + user.Permissions["/tmp"] = []string{dataprovider.PermRename} + assert.True(t, user.HasPerm(dataprovider.PermListItems, "/")) + assert.True(t, user.HasPerm(dataprovider.PermListItems, ".")) + assert.True(t, user.HasPerm(dataprovider.PermListItems, "")) + assert.True(t, user.HasPerm(dataprovider.PermListItems, "../")) + // path p and /p are the same + assert.True(t, user.HasPerm(dataprovider.PermDelete, "/p")) + assert.True(t, user.HasPerm(dataprovider.PermDownload, "/p/1")) + assert.True(t, user.HasPerm(dataprovider.PermCreateDirs, "p/2")) + assert.True(t, user.HasPerm(dataprovider.PermChmod, "/p/3")) + assert.True(t, user.HasPerm(dataprovider.PermChtimes, "p/3/4/")) + assert.True(t, user.HasPerm(dataprovider.PermChtimes, "p/3/4/../4")) + // undefined paths have permissions of the nearest path + assert.True(t, user.HasPerm(dataprovider.PermListItems, "/p34")) + assert.True(t, user.HasPerm(dataprovider.PermListItems, "/p34/p1/file.dat")) + assert.True(t, user.HasPerm(dataprovider.PermChtimes, "/p/3/4/5/6")) + assert.True(t, user.HasPerm(dataprovider.PermDownload, "/p/1/test/file.dat")) +} + +func TestWildcardPermissions(t *testing.T) { + user := getTestUser(true) + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermListItems} + user.Permissions["/p*"] = []string{dataprovider.PermDelete} + user.Permissions["/p/*"] = []string{dataprovider.PermDownload, dataprovider.PermUpload} + user.Permissions["/p/2"] = []string{dataprovider.PermCreateDirs} + user.Permissions["/pa"] = 
[]string{dataprovider.PermChmod} + user.Permissions["/p/3/4"] = []string{dataprovider.PermChtimes} + assert.True(t, user.HasPerm(dataprovider.PermListItems, "/")) + assert.True(t, user.HasPerm(dataprovider.PermDelete, "/p1")) + assert.True(t, user.HasPerm(dataprovider.PermDelete, "/ppppp")) + assert.False(t, user.HasPerm(dataprovider.PermDelete, "/pa")) + assert.True(t, user.HasPerm(dataprovider.PermChmod, "/pa")) + assert.True(t, user.HasPerm(dataprovider.PermUpload, "/p/1")) + assert.True(t, user.HasPerm(dataprovider.PermUpload, "/p/p")) + assert.False(t, user.HasPerm(dataprovider.PermUpload, "/p/2")) + assert.True(t, user.HasPerm(dataprovider.PermDownload, "/p/3")) + assert.True(t, user.HasPerm(dataprovider.PermDownload, "/p/a/a/a")) + assert.False(t, user.HasPerm(dataprovider.PermDownload, "/p/3/4")) + assert.True(t, user.HasPerm(dataprovider.PermChtimes, "/p/3/4")) + assert.True(t, user.HasPerm(dataprovider.PermDelete, "/pb/a/a/a")) + assert.False(t, user.HasPerm(dataprovider.PermDelete, "/abc/a/a/a")) + assert.True(t, user.HasPerm(dataprovider.PermListItems, "/abc/a/a/a/b")) +} + +func TestRootWildcardPerms(t *testing.T) { + user := getTestUser(true) + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermListItems} + user.Permissions["/*"] = []string{dataprovider.PermDelete} + user.Permissions["/p/*"] = []string{dataprovider.PermDownload, dataprovider.PermUpload} + user.Permissions["/p/2"] = []string{dataprovider.PermCreateDirs} + user.Permissions["/pa"] = []string{dataprovider.PermChmod} + user.Permissions["/p/3/4"] = []string{dataprovider.PermChtimes} + assert.True(t, user.HasPerm(dataprovider.PermListItems, "/")) + assert.True(t, user.HasPerm(dataprovider.PermDelete, "/p1")) + assert.True(t, user.HasPerm(dataprovider.PermDelete, "/ppppp")) + assert.False(t, user.HasPerm(dataprovider.PermDelete, "/pa")) + assert.True(t, user.HasPerm(dataprovider.PermChmod, "/pa")) + assert.True(t, 
user.HasPerm(dataprovider.PermUpload, "/p/1")) + assert.True(t, user.HasPerm(dataprovider.PermUpload, "/p/p")) + assert.False(t, user.HasPerm(dataprovider.PermUpload, "/p/2")) + assert.True(t, user.HasPerm(dataprovider.PermCreateDirs, "/p/2")) + assert.True(t, user.HasPerm(dataprovider.PermCreateDirs, "/p/2/a")) + assert.True(t, user.HasPerm(dataprovider.PermDownload, "/p/3")) + assert.True(t, user.HasPerm(dataprovider.PermDownload, "/p/a/a/a")) + assert.False(t, user.HasPerm(dataprovider.PermDownload, "/p/3/4")) + assert.True(t, user.HasPerm(dataprovider.PermChtimes, "/p/3/4")) + assert.True(t, user.HasPerm(dataprovider.PermDelete, "/pb/a/a/a")) + assert.True(t, user.HasPerm(dataprovider.PermDelete, "/abc/a/a/a")) + assert.False(t, user.HasPerm(dataprovider.PermListItems, "/abc/a/a/a/b")) + assert.True(t, user.HasPerm(dataprovider.PermDelete, "/abc/a/a/a/b")) +} + +func TestFilterFilePatterns(t *testing.T) { + user := getTestUser(true) + pattern := sdk.PatternsFilter{ + Path: "/test", + AllowedPatterns: []string{"*.jpg", "*.png"}, + DeniedPatterns: []string{"*.pdf"}, + } + filters := dataprovider.UserFilters{ + BaseUserFilters: sdk.BaseUserFilters{ + FilePatterns: []sdk.PatternsFilter{pattern}, + }, + } + user.Filters = filters + ok, _ := user.IsFileAllowed("/test/test.jPg") + assert.True(t, ok) + ok, _ = user.IsFileAllowed("/test/test.pdf") + assert.False(t, ok) + ok, _ = user.IsFileAllowed("/test.pDf") + assert.True(t, ok) + + filters.FilePatterns = append(filters.FilePatterns, sdk.PatternsFilter{ + Path: "/", + AllowedPatterns: []string{"*.zip", "*.rar", "*.pdf"}, + DeniedPatterns: []string{"*.gz"}, + }) + user.Filters = filters + ok, _ = user.IsFileAllowed("/test1/test.gz") + assert.False(t, ok) + ok, _ = user.IsFileAllowed("/test1/test.zip") + assert.True(t, ok) + ok, _ = user.IsFileAllowed("/test/sub/test.pdf") + assert.False(t, ok) + ok, _ = user.IsFileAllowed("/test1/test.png") + assert.False(t, ok) + + filters.FilePatterns = append(filters.FilePatterns, 
sdk.PatternsFilter{ + Path: "/test/sub", + DeniedPatterns: []string{"*.tar"}, + }) + user.Filters = filters + ok, _ = user.IsFileAllowed("/test/sub/sub/test.tar") + assert.False(t, ok) + ok, _ = user.IsFileAllowed("/test/sub/test.gz") + assert.True(t, ok) + ok, _ = user.IsFileAllowed("/test/test.zip") + assert.False(t, ok) +} + +func TestUserAllowedLoginMethods(t *testing.T) { + user := getTestUser(true) + user.Filters.DeniedLoginMethods = dataprovider.ValidLoginMethods + allowedMethods := user.GetAllowedLoginMethods() + assert.Equal(t, 0, len(allowedMethods)) + + user.Filters.DeniedLoginMethods = []string{ + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.SSHLoginMethodKeyboardInteractive, + } + allowedMethods = user.GetAllowedLoginMethods() + assert.Equal(t, 4, len(allowedMethods)) + + assert.True(t, slices.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndKeyboardInt)) + assert.True(t, slices.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndPassword)) +} + +func TestUserPartialAuth(t *testing.T) { + user := getTestUser(true) + user.Filters.DeniedLoginMethods = []string{ + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.SSHLoginMethodKeyboardInteractive, + } + assert.True(t, user.IsPartialAuth()) + + user.Filters.DeniedLoginMethods = []string{ + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, + } + assert.False(t, user.IsPartialAuth()) + + user.Filters.DeniedLoginMethods = []string{ + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodPublicKey, + } + assert.False(t, user.IsPartialAuth()) + user.Filters.DeniedLoginMethods = []string{ + dataprovider.SSHLoginMethodPassword, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.SSHLoginMethodKeyboardInteractive, + } + assert.True(t, user.IsPartialAuth()) +} + +func TestUserGetNextAuthMethods(t *testing.T) { + user := getTestUser(true) + user.Filters.DeniedLoginMethods = 
[]string{ + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.SSHLoginMethodKeyboardInteractive, + } + methods := user.GetNextAuthMethods() + require.Len(t, methods, 2) + assert.Equal(t, dataprovider.LoginMethodPassword, methods[0]) + assert.Equal(t, dataprovider.SSHLoginMethodKeyboardInteractive, methods[1]) + + user.Filters.DeniedLoginMethods = []string{ + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.SSHLoginMethodKeyboardInteractive, + dataprovider.SSHLoginMethodKeyAndKeyboardInt, + } + methods = user.GetNextAuthMethods() + require.Len(t, methods, 1) + assert.Equal(t, dataprovider.LoginMethodPassword, methods[0]) + + user.Filters.DeniedLoginMethods = []string{ + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.SSHLoginMethodKeyboardInteractive, + dataprovider.SSHLoginMethodKeyAndPassword, + } + methods = user.GetNextAuthMethods() + require.Len(t, methods, 1) + assert.Equal(t, dataprovider.SSHLoginMethodKeyboardInteractive, methods[0]) + + user.Filters.DeniedLoginMethods = []string{ + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.SSHLoginMethodKeyAndPassword, + dataprovider.SSHLoginMethodKeyAndKeyboardInt, + } + methods = user.GetNextAuthMethods() + require.Len(t, methods, 0) +} + +func TestUserIsLoginMethodAllowed(t *testing.T) { + user := getTestUser(true) + user.Filters.DeniedLoginMethods = []string{ + dataprovider.LoginMethodPassword, + dataprovider.SSHLoginMethodPublicKey, + dataprovider.SSHLoginMethodKeyboardInteractive, + } + assert.False(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolSSH)) + assert.False(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolFTP)) + assert.False(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolWebDAV)) + assert.False(t, 
user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodPublicKey, common.ProtocolSSH)) + assert.False(t, user.IsLoginMethodAllowed(dataprovider.SSHLoginMethodKeyboardInteractive, common.ProtocolSSH)) + + user.Filters.DeniedLoginMethods = []string{ + dataprovider.SSHLoginMethodPublicKey, + dataprovider.SSHLoginMethodKeyboardInteractive, + } + assert.True(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolSSH)) + + user.Filters.DeniedLoginMethods = []string{ + dataprovider.SSHLoginMethodPassword, + } + assert.True(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP)) + assert.True(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolFTP)) + assert.True(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolWebDAV)) + assert.False(t, user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolSSH)) +} + +func TestUserEmptySubDirPerms(t *testing.T) { + user := getTestUser(true) + user.Permissions = make(map[string][]string) + user.Permissions["/emptyperms"] = []string{} + for _, p := range dataprovider.ValidPerms { + assert.False(t, user.HasPerm(p, "/emptyperms")) + } +} + +func TestUserFiltersIPMaskConditions(t *testing.T) { + user := getTestUser(true) + // with no filter login must be allowed even if the remoteIP is invalid + assert.True(t, user.IsLoginFromAddrAllowed("192.168.1.5")) + assert.True(t, user.IsLoginFromAddrAllowed("invalid")) + + user.Filters.DeniedIP = append(user.Filters.DeniedIP, "192.168.1.0/24") + assert.False(t, user.IsLoginFromAddrAllowed("192.168.1.5")) + assert.True(t, user.IsLoginFromAddrAllowed("192.168.2.6")) + + user.Filters.AllowedIP = append(user.Filters.AllowedIP, "192.168.1.5/32") + // if the same ip/mask is both denied and allowed then login must be allowed + assert.True(t, user.IsLoginFromAddrAllowed("192.168.1.5")) + assert.False(t, user.IsLoginFromAddrAllowed("192.168.1.3")) + assert.False(t, 
user.IsLoginFromAddrAllowed("192.168.3.6")) + + user.Filters.DeniedIP = []string{} + assert.True(t, user.IsLoginFromAddrAllowed("192.168.1.5")) + assert.False(t, user.IsLoginFromAddrAllowed("192.168.1.6")) + + user.Filters.DeniedIP = []string{"192.168.0.0/16", "172.16.0.0/16"} + user.Filters.AllowedIP = []string{} + assert.False(t, user.IsLoginFromAddrAllowed("192.168.5.255")) + assert.False(t, user.IsLoginFromAddrAllowed("172.16.1.2")) + assert.True(t, user.IsLoginFromAddrAllowed("172.18.2.1")) + + user.Filters.AllowedIP = []string{"10.4.4.0/24"} + assert.False(t, user.IsLoginFromAddrAllowed("10.5.4.2")) + assert.True(t, user.IsLoginFromAddrAllowed("10.4.4.2")) + assert.True(t, user.IsLoginFromAddrAllowed("invalid")) +} + +func TestGetVirtualFolderForPath(t *testing.T) { + user := getTestUser(true) + mappedPath1 := filepath.Join(os.TempDir(), "vpath1") + mappedPath2 := filepath.Join(os.TempDir(), "vpath1") + mappedPath3 := filepath.Join(os.TempDir(), "vpath3") + vdirPath := "/vdir/sub" + vSubDirPath := path.Join(vdirPath, "subdir", "subdir") + vSubDir1Path := path.Join(vSubDirPath, "subdir", "subdir") + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath1, + }, + VirtualPath: vdirPath, + }) + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath2, + }, + VirtualPath: vSubDir1Path, + }) + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath3, + }, + VirtualPath: vSubDirPath, + }) + folder, err := user.GetVirtualFolderForPath(path.Join(vSubDirPath, "file")) + assert.NoError(t, err) + assert.Equal(t, folder.MappedPath, mappedPath3) + _, err = user.GetVirtualFolderForPath("/file") + assert.Error(t, err) + folder, err = user.GetVirtualFolderForPath(path.Join(vdirPath, "/file")) + assert.NoError(t, err) + 
assert.Equal(t, folder.MappedPath, mappedPath1) + folder, err = user.GetVirtualFolderForPath(path.Join(vSubDirPath+"1", "file")) + assert.NoError(t, err) + assert.Equal(t, folder.MappedPath, mappedPath1) + _, err = user.GetVirtualFolderForPath("/vdir/sub1/file") + assert.Error(t, err) + folder, err = user.GetVirtualFolderForPath(vdirPath) + assert.NoError(t, err) +} + +func TestStatVFS(t *testing.T) { + usePubKey := false + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + testFileSize := int64(65535) + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + stat, err := client.StatVFS("/") + assert.NoError(t, err) + assert.Greater(t, stat.ID, uint32(0)) + assert.Greater(t, stat.Blocks, uint64(0)) + assert.Greater(t, stat.Bsize, uint64(0)) + + _, err = client.StatVFS("missing-path") + assert.Error(t, err) + assert.True(t, errors.Is(err, fs.ErrNotExist)) + } + user.QuotaFiles = 100 + user.Filters.DisableFsChecks = true + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + + _, err = client.StatVFS("missing-path") + assert.Error(t, err) + assert.ErrorIs(t, err, fs.ErrNotExist) + + stat, err := client.StatVFS("/") + assert.NoError(t, err) + assert.Greater(t, stat.ID, uint32(0)) + assert.Greater(t, stat.Blocks, uint64(0)) + assert.Greater(t, stat.Bsize, uint64(0)) + assert.Equal(t, uint64(100), stat.Files) + assert.Equal(t, uint64(99), stat.Ffree) + } + + user.QuotaSize = 8192 + user, _, err = 
// TestStatVFS verifies the SFTP statvfs extension: with no quota the real
// filesystem numbers are reported, while file/size quotas are reflected in
// Files/Ffree and Blocks/Bfree respectively.
func TestStatVFS(t *testing.T) {
	usePubKey := false
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	testFileSize := int64(65535)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		// no quota configured: just sanity-check the reported values
		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))

		_, err = client.StatVFS("missing-path")
		assert.Error(t, err)
		assert.True(t, errors.Is(err, fs.ErrNotExist))
	}
	// enable a file quota and upload one file: Files must report the quota
	// limit and Ffree the remaining slots
	user.QuotaFiles = 100
	user.Filters.DisableFsChecks = true
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = os.Remove(testFilePath)
		assert.NoError(t, err)

		_, err = client.StatVFS("missing-path")
		assert.Error(t, err)
		assert.ErrorIs(t, err, fs.ErrNotExist)

		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))
		assert.Equal(t, uint64(100), stat.Files)
		// one file uploaded out of a 100 file quota
		assert.Equal(t, uint64(99), stat.Ffree)
	}

	// size quota smaller than the uploaded file: no free blocks remain
	user.QuotaSize = 8192
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))
		assert.Equal(t, uint64(100), stat.Files)
		assert.Equal(t, uint64(0), stat.Ffree)
		assert.Equal(t, uint64(2), stat.Blocks)
		assert.Equal(t, uint64(0), stat.Bfree)
	}
	// size quota only (no file quota): Files falls back to a positive value
	user.QuotaFiles = 0
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))
		assert.Greater(t, stat.Files, uint64(0))
		assert.Equal(t, uint64(0), stat.Ffree)
		assert.Equal(t, uint64(2), stat.Blocks)
		assert.Equal(t, uint64(0), stat.Bfree)
	}

	// 1 byte quota: a single exhausted block is reported
	user.QuotaSize = 1
	user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Equal(t, uint64(1), stat.Blocks)
		assert.Equal(t, uint64(1), stat.Bsize)
		assert.Greater(t, stat.Files, uint64(0))
		assert.Equal(t, uint64(0), stat.Ffree)
		assert.Equal(t, uint64(1), stat.Blocks)
		assert.Equal(t, uint64(0), stat.Bfree)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
// TestStatVFSCloudBackend verifies statvfs against a cloud (Azure Blob)
// backend, where fixed synthetic totals are reported and the user's tracked
// quota usage is added to the used files/blocks.
func TestStatVFSCloudBackend(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
	// the SAS URL is never contacted: statvfs is computed from quota data
	u.FsConfig.AzBlobConfig.SASURL = kms.NewPlainSecret("https://myaccount.blob.core.windows.net/sasurl")
	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()

		// record 100 used files / 8192 used bytes directly in the provider
		err = dataprovider.UpdateUserQuota(&user, 100, 8192, true)
		assert.NoError(t, err)
		stat, err := client.StatVFS("/")
		assert.NoError(t, err)
		assert.Greater(t, stat.ID, uint32(0))
		assert.Greater(t, stat.Blocks, uint64(0))
		assert.Greater(t, stat.Bsize, uint64(0))
		// totals = synthetic base (1000000 files / 2147483648 blocks) + usage
		assert.Equal(t, uint64(1000000+100), stat.Files)
		assert.Equal(t, uint64(2147483648+2), stat.Blocks)
		assert.Equal(t, uint64(1000000), stat.Ffree)
		assert.Equal(t, uint64(2147483648), stat.Bfree)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}
usePubKey) + assert.NoError(t, err) + assert.Contains(t, string(out), "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestSSHFileHash(t *testing.T) { + usePubKey := true + localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + u := getTestUserWithCryptFs(usePubKey) + u.Username = u.Username + "_crypt" + cryptUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser, cryptUser} { + conn, client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = sftpUploadFile(testFilePath, testFileName, testFileSize, client) + assert.NoError(t, err) + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermUpload} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + _, err = runSSHCommand("sha512sum "+testFileName, user, usePubKey) + assert.Error(t, err, "hash command with no list permission must fail") + + user.Permissions["/"] = []string{dataprovider.PermAny} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + initialHash, err := computeHashForFile(sha512.New(), testFilePath) + assert.NoError(t, err) + + out, err := runSSHCommand("sha512sum "+testFileName, user, usePubKey) + if assert.NoError(t, err) { + assert.Contains(t, string(out), initialHash) + } + _, err = runSSHCommand("sha512sum invalid_path", user, usePubKey) + assert.Error(t, err, "hash 
// TestSSHFileHash checks the "sha512sum <file>" SSH command against local,
// SFTP and crypt filesystem users: it requires list permission and must
// return the same digest computed locally on the uploaded file.
func TestSSHFileHash(t *testing.T) {
	usePubKey := true
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	u := getTestUserWithCryptFs(usePubKey)
	u.Username = u.Username + "_crypt"
	cryptUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// run the same scenario against each backend
	for _, user := range []dataprovider.User{localUser, sftpUser, cryptUser} {
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()
			testFilePath := filepath.Join(homeBasePath, testFileName)
			testFileSize := int64(65535)
			err = createTestFile(testFilePath, testFileSize)
			assert.NoError(t, err)
			err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
			assert.NoError(t, err)
			// drop list permission: the hash command must be rejected
			user.Permissions = make(map[string][]string)
			user.Permissions["/"] = []string{dataprovider.PermUpload}
			_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
			assert.NoError(t, err)
			_, err = runSSHCommand("sha512sum "+testFileName, user, usePubKey)
			assert.Error(t, err, "hash command with no list permission must fail")

			user.Permissions["/"] = []string{dataprovider.PermAny}
			_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
			assert.NoError(t, err)

			// hash the local source file and expect the server to agree
			initialHash, err := computeHashForFile(sha512.New(), testFilePath)
			assert.NoError(t, err)

			out, err := runSSHCommand("sha512sum "+testFileName, user, usePubKey)
			if assert.NoError(t, err) {
				assert.Contains(t, string(out), initialHash)
			}
			_, err = runSSHCommand("sha512sum invalid_path", user, usePubKey)
			assert.Error(t, err, "hash for an invalid path must fail")

			err = os.Remove(testFilePath)
			assert.NoError(t, err)
		}
	}
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(cryptUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}
// TestSSHCopy exercises the "sftpgo-copy" SSH command across the home dir and
// two virtual folders (vdir1 excluded from quota, vdir2 with its own quota),
// verifying argument validation, symlink rejection, overwrite semantics,
// quota accounting, denied patterns and filesystem permission errors.
func TestSSHCopy(t *testing.T) {
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1/subdir"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2/subdir"
	// vdir1: unlimited quota (excluded from folder accounting)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	// vdir2: tracked folder quota (100 files)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  100,
		QuotaSize:   0,
	})
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           "/",
			DeniedPatterns: []string{"*.denied"},
		},
	}
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testDir := "adir"
	testDir1 := "adir1"
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(131072)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileName1 := "test_file1.dat"
		testFileSize1 := int64(65537)
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		// seed the same two files in the home dir and in both virtual folders
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, testDir1))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir1))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testDir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, testDir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testDir1, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, testDir1, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = client.Symlink(path.Join(testDir, testFileName), testFileName)
		assert.NoError(t, err)
		// baseline quota: home (2) + vdir1 (2, included in user quota) = 4 files
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 4, user.UsedQuotaFiles)
		assert.Equal(t, 2*testFileSize+2*testFileSize1, user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize+testFileSize1, f.UsedQuotaSize)
		assert.Equal(t, 2, f.UsedQuotaFiles)

		_, err = client.Stat(testDir1)
		assert.Error(t, err)
		// invalid invocations: missing target, missing args, symlink source
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s", path.Join(vdirPath1, testDir1)), user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-copy", user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", testFileName, testFileName+".linkcopy"), user, usePubKey)
		assert.Error(t, err)
		// copy a vdir1 dir into the home dir: +2 files in the user quota
		out, err := runSSHCommand(fmt.Sprintf("sftpgo-copy %s %s", path.Join(vdirPath1, testDir1), "."), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			fi, err := client.Stat(testDir1)
			if assert.NoError(t, err) {
				assert.True(t, fi.IsDir())
			}
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 6, user.UsedQuotaFiles)
			assert.Equal(t, 3*testFileSize+3*testFileSize1, user.UsedQuotaSize)
		}
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", "missing\\ dir", "."), user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir1), "."), user, usePubKey)
		if assert.NoError(t, err) {
			// all files are overwritten, quota must remain unchanged
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 6, user.UsedQuotaFiles)
			assert.Equal(t, 3*testFileSize+3*testFileSize1, user.UsedQuotaSize)
		}
		// single file copy from vdir2 to the home dir: +1 user file
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, testDir1, testFileName), testFileName+".copy"), //nolint:goconst
			user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			fi, err := client.Stat(testFileName + ".copy") //nolint:goconst
			if assert.NoError(t, err) {
				assert.True(t, fi.Mode().IsRegular())
			}
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 7, user.UsedQuotaFiles)
			assert.Equal(t, 4*testFileSize+3*testFileSize1, user.UsedQuotaSize)
		}
		// copy vdir1 -> vdir2: user quota unchanged, vdir2 folder quota grows
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir1), path.Join(vdirPath2, testDir1+"copy")), //nolint:goconst
			user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			fi, err := client.Stat(path.Join(vdirPath2, testDir1+"copy"))
			if assert.NoError(t, err) {
				assert.True(t, fi.IsDir())
			}
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 7, user.UsedQuotaFiles)
			assert.Equal(t, 4*testFileSize+3*testFileSize1, user.UsedQuotaSize)
			f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, testFileSize*2+testFileSize1*2, f.UsedQuotaSize)
			assert.Equal(t, 4, f.UsedQuotaFiles)
		}
		// copy inside vdir1: counted in the user quota, folder1 stays at zero
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir1), path.Join(vdirPath1, testDir1+"copy")),
			user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			_, err := client.Stat(path.Join(vdirPath2, testDir1+"copy"))
			assert.NoError(t, err)
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 9, user.UsedQuotaFiles)
			assert.Equal(t, 5*testFileSize+4*testFileSize1, user.UsedQuotaSize)
			f, _, err = httpdtest.GetFolderByName(folderName1, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, int64(0), f.UsedQuotaSize)
			assert.Equal(t, 0, f.UsedQuotaFiles)
		}
		// cross folder copy
		newDir := "newdir"
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, ".."), newDir), user, usePubKey)
		assert.NoError(t, err)
		_, err = client.Stat(newDir)
		assert.NoError(t, err)
		// denied pattern
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(testDir, testFileName), testFileName+".denied"), user, usePubKey)
		assert.Error(t, err)
		if runtime.GOOS != osWindows {
			subPath := filepath.Join(mappedPath1, testDir1, "asubdir", "anothersub", "another")
			err = os.MkdirAll(subPath, os.ModePerm)
			assert.NoError(t, err)
			err = os.Chmod(subPath, 0001)
			assert.NoError(t, err)
			// listing contents for subdirs with no permissions will fail
			_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", vdirPath1, "newdir1"), user, usePubKey)
			assert.Error(t, err)
			err = os.Chmod(subPath, os.ModePerm)
			assert.NoError(t, err)
			// read-only target dir: the copy must fail
			err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir1), 0555)
			assert.NoError(t, err)
			_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir1, testFileName),
				path.Join(testDir1, "anewdir")), user, usePubKey)
			assert.Error(t, err)
			err = os.Chmod(filepath.Join(user.GetHomeDir(), testDir1), os.ModePerm)
			assert.NoError(t, err)

			err = os.Chmod(user.GetHomeDir(), os.ModePerm)
			assert.NoError(t, err)
		}

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}
// TestSSHCopyPermissions verifies that "sftpgo-copy" honors per-directory
// permissions: copying requires the copy/create-dirs perms on the target,
// and symlinks in the source tree are silently skipped.
func TestSSHCopyPermissions(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	// dir1: no create-dirs/copy; dir2: full set needed for copy;
	// dir3: no upload/copy perm
	u.Permissions["/dir1"] = []string{dataprovider.PermUpload, dataprovider.PermDownload, dataprovider.PermListItems}
	u.Permissions["/dir2"] = []string{dataprovider.PermCreateDirs, dataprovider.PermUpload, dataprovider.PermDownload,
		dataprovider.PermListItems, dataprovider.PermCopy}
	u.Permissions["/dir3"] = []string{dataprovider.PermCreateDirs, dataprovider.PermCreateSymlinks, dataprovider.PermDownload,
		dataprovider.PermListItems}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testDir := "tDir"
		testFileSize := int64(131072)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join("/", testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		// test copy file with no permission
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir, testFileName), path.Join("/dir3", testFileName)),
			user, usePubKey)
		assert.Error(t, err)
		// test copy dir with no create dirs perm
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/dir1/"), user, usePubKey)
		assert.Error(t, err)
		// dir2 has the needed permissions
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/dir2/"), user, usePubKey)
		assert.NoError(t, err)
		info, err := client.Stat(path.Join("/dir2", testDir))
		if assert.NoError(t, err) {
			assert.True(t, info.IsDir())
		}
		info, err = client.Stat(path.Join("/dir2", testDir, testFileName))
		if assert.NoError(t, err) {
			assert.True(t, info.Mode().IsRegular())
		}
		// now create a symlink, dir2 has no create symlink permission, but symlinks will be ignored
		err = client.Symlink(path.Join("/", testDir, testFileName), path.Join("/", testDir, testFileName+".link"))
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/dir2/sub"), user, usePubKey)
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/newdir"), user, usePubKey)
		assert.NoError(t, err)
		// now delete the file and copy inside /dir3
		err = client.Remove(path.Join("/", testDir, testFileName))
		assert.NoError(t, err)
		// the symlink will be skipped, so no errors
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join("/", testDir), "/dir3"), user, usePubKey)
		assert.NoError(t, err)

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
// TestSSHCopyQuotaLimits verifies that "sftpgo-copy" is rejected whenever the
// copy would exceed the user quota or a virtual folder quota (file count or
// size), including after the limits are lowered below current usage.
func TestSSHCopyQuotaLimits(t *testing.T) {
	usePubKey := true
	testFileSize := int64(131072)
	testFileSize1 := int64(65536)
	testFileSize2 := int64(32768)
	u := getTestUser(usePubKey)
	// tight limits: 3 files and just over one large+one medium file in size
	u.QuotaFiles = 3
	u.QuotaSize = testFileSize + testFileSize1 + 1
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	// vdir1 has no folder quota (usage counts against the user quota)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	// vdir2 enforces its own file and size quota
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  3,
		QuotaSize:   testFileSize + testFileSize1 + 1,
	})
	u.Filters.FilePatterns = []sdk.PatternsFilter{
		{
			Path:           "/",
			DeniedPatterns: []string{"*.denied"},
		},
	}
	err := os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testDir := "testDir"
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileName1 := "test_file1.dat"
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		testFileName2 := "test_file2.dat"
		testFilePath2 := filepath.Join(homeBasePath, testFileName2)
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		err = createTestFile(testFilePath2, testFileSize2)
		assert.NoError(t, err)
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath2, path.Join(testDir, testFileName2), testFileSize2, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath2, path.Join(testDir, testFileName2+".dupl"), testFileSize2, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath2, path.Join(vdirPath2, testDir, testFileName2), testFileSize2, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath2, path.Join(vdirPath2, testDir, testFileName2+".dupl"), testFileSize2, client)
		assert.NoError(t, err)
		// user quota: 2 files, size: 32768*2, folder2 quota: 2 files, size: 32768*2
		// try to duplicate testDir, this will result in 4 file (over quota) and 32768*4 bytes (not over quota)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", testDir, testDir+"_copy"), user, usePubKey) //nolint:goconst
		assert.Error(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, testDir),
			path.Join(vdirPath2, testDir+"_copy")), user, usePubKey)
		assert.Error(t, err)

		// clean up so usage goes back to zero before the next scenario
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir), user, usePubKey)
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey)
		assert.NoError(t, err)
		// remove partially copied dirs
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir+"_copy"), user, usePubKey)
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir+"_copy")), user, usePubKey)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, user.UsedQuotaFiles)
		assert.Equal(t, int64(0), user.UsedQuotaSize)
		f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 0, f.UsedQuotaFiles)
		assert.Equal(t, int64(0), f.UsedQuotaSize)
		err = client.Mkdir(path.Join(vdirPath1, testDir))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)

		// vdir1 is included in user quota, file limit will be exceeded
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir), "/"), user, usePubKey)
		assert.Error(t, err)

		// vdir2 size limit will be exceeded
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir, testFileName),
			vdirPath2+"/"), user, usePubKey)
		assert.Error(t, err)
		// now decrease the limits
		user.QuotaFiles = 1
		user.QuotaSize = testFileSize * 10
		for idx, f := range user.VirtualFolders {
			if f.Name == folderName2 {
				user.VirtualFolders[idx].QuotaSize = testFileSize
				user.VirtualFolders[idx].QuotaFiles = 10
			}
		}
		user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
		assert.NoError(t, err)
		assert.Equal(t, 1, user.QuotaFiles)
		assert.Equal(t, testFileSize*10, user.QuotaSize)
		if assert.Len(t, user.VirtualFolders, 2) {
			for _, f := range user.VirtualFolders {
				if f.Name == folderName2 {
					assert.Equal(t, testFileSize, f.QuotaSize)
					assert.Equal(t, 10, f.QuotaFiles)
				}
			}
		}
		// both copies must now exceed the reduced limits
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath1, testDir),
			path.Join(vdirPath2, testDir+".copy")), user, usePubKey)
		assert.Error(t, err)

		_, err = runSSHCommand(fmt.Sprintf("sftpgo-copy %v %v", path.Join(vdirPath2, testDir),
			testDir+".copy"), user, usePubKey)
		assert.Error(t, err)

		err = os.Remove(testFilePath)
		assert.NoError(t, err)
		err = os.Remove(testFilePath1)
		assert.NoError(t, err)
		err = os.Remove(testFilePath2)
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}
// TestSSHRemove exercises the "sftpgo-remove" SSH command: removing symlinks,
// files and directories (with quota updates), rejecting the root and virtual
// folder mount points, filesystem permission errors, and the delete perm.
func TestSSHRemove(t *testing.T) {
	usePubKey := false
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1/sub"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2/sub"
	// vdir1: unlimited quota (counts against the user quota)
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	// vdir2: separate folder quota
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  100,
		QuotaSize:   0,
	})
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		testFileSize := int64(131072)
		testFilePath := filepath.Join(homeBasePath, testFileName)
		testFileName1 := "test_file1.dat"
		testFileSize1 := int64(65537)
		testFilePath1 := filepath.Join(homeBasePath, testFileName1)
		testDir := "testdir"
		err = createTestFile(testFilePath, testFileSize)
		assert.NoError(t, err)
		err = createTestFile(testFilePath1, testFileSize1)
		assert.NoError(t, err)
		// seed two files in the home dir and in both virtual folders
		err = client.Mkdir(path.Join(vdirPath1, testDir))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath1, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath1, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath, path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = sftpUploadFile(testFilePath1, path.Join(vdirPath2, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// removing a symlink is allowed
		err = client.Symlink(testFileName, testFileName+".link")
		assert.NoError(t, err)
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testFileName+".link"), user, usePubKey)
		assert.NoError(t, err)
		// the virtual folder mount point, the root and a missing arg must fail
		_, err = runSSHCommand("sftpgo-remove /vdir1", user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-remove /", user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-remove", user, usePubKey)
		assert.Error(t, err)
		// removing a file updates the user quota
		out, err := runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testFileName), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			_, err := client.Stat(testFileName)
			assert.Error(t, err)
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 3, user.UsedQuotaFiles)
			assert.Equal(t, testFileSize+2*testFileSize1, user.UsedQuotaSize)
		}
		// removing a whole dir inside vdir1 releases its quota too
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath1, testDir)), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			_, err := client.Stat(path.Join(vdirPath1, testFileName))
			assert.Error(t, err)
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, 1, user.UsedQuotaFiles)
			assert.Equal(t, testFileSize1, user.UsedQuotaSize)
		}
		_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", vdirPath1), user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-remove /", user, usePubKey)
		assert.Error(t, err)
		_, err = runSSHCommand("sftpgo-remove missing_file", user, usePubKey)
		assert.Error(t, err)
		if runtime.GOOS != osWindows {
			// read-only / no-permission dirs on the real filesystem must fail
			err = os.Chmod(filepath.Join(mappedPath2, testDir), 0555)
			assert.NoError(t, err)
			_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey)
			assert.Error(t, err)
			err = os.Chmod(filepath.Join(mappedPath2, testDir), 0001)
			assert.NoError(t, err)
			_, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey)
			assert.Error(t, err)
			err = os.Chmod(filepath.Join(mappedPath2, testDir), os.ModePerm)
			assert.NoError(t, err)
		}
	}

	// test remove dir with no delete perm
	user.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermDownload, dataprovider.PermListItems}
	_, _, err = httpdtest.UpdateUser(user, http.StatusOK, "")
	assert.NoError(t, err)
	conn, client, err = getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		_, err = runSSHCommand("sftpgo-remove adir", user, usePubKey)
		assert.Error(t, err)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}
assert.Error(t, err) + _, err = runSSHCommand("sftpgo-remove /", user, usePubKey) + assert.Error(t, err) + _, err = runSSHCommand("sftpgo-remove missing_file", user, usePubKey) + assert.Error(t, err) + if runtime.GOOS != osWindows { + err = os.Chmod(filepath.Join(mappedPath2, testDir), 0555) + assert.NoError(t, err) + _, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey) + assert.Error(t, err) + err = os.Chmod(filepath.Join(mappedPath2, testDir), 0001) + assert.NoError(t, err) + _, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey) + assert.Error(t, err) + err = os.Chmod(filepath.Join(mappedPath2, testDir), os.ModePerm) + assert.NoError(t, err) + } + } + + // test remove dir with no delete perm + user.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermDownload, dataprovider.PermListItems} + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + conn, client, err = getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer conn.Close() + defer client.Close() + _, err = runSSHCommand("sftpgo-remove adir", user, usePubKey) + assert.Error(t, err) + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath1) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath2) + assert.NoError(t, err) +} + +func TestSSHRemoveCryptFs(t *testing.T) { + usePubKey := false + u := getTestUserWithCryptFs(usePubKey) + u.QuotaFiles = 100 + mappedPath1 := filepath.Join(os.TempDir(), "vdir1") + folderName1 := filepath.Base(mappedPath1) + vdirPath1 := "/vdir1/sub" + mappedPath2 := 
filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2/sub"
	// vdir1 delegates quota tracking to the user (-1/-1), vdir2 has its own quota
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  100,
		QuotaSize:   0,
	})
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	// the second folder is backed by the crypt filesystem provider
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		// create the same dir tree in the root and in both virtual folders
		testDir := "tdir"
		err = client.Mkdir(testDir)
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath1, testDir))
		assert.NoError(t, err)
		err = client.Mkdir(path.Join(vdirPath2, testDir))
		assert.NoError(t, err)
		testFileSize := int64(32768)
		testFileSize1 := int64(65536)
		testFileName1 := "test_file1.dat"
		err = writeSFTPFile(testFileName, testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(testFileName1, testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, testDir, testFileName),
			testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath1, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName1), testFileSize1, client)
		assert.NoError(t, err)
		// removing a path above the mount point must fail
		_, err = runSSHCommand("sftpgo-remove /vdir2", user, usePubKey)
		assert.Error(t, err)
		out, err := runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testFileName), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
			_, err := client.Stat(testFileName)
			assert.Error(t, err)
		}
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", testDir), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
		}
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath1, testDir)), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
		}
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir, testFileName)), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
		}
		// re-create and remove the whole dir inside vdir2
		err = writeSFTPFile(path.Join(vdirPath2, testDir, testFileName), testFileSize, client)
		assert.NoError(t, err)
		out, err = runSSHCommand(fmt.Sprintf("sftpgo-remove %v", path.Join(vdirPath2, testDir)), user, usePubKey)
		if assert.NoError(t, err) {
			assert.Equal(t, "OK\n", string(out))
		}
		// on a crypt fs the on-disk size exceeds the plaintext size, hence Greater
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, 1, user.UsedQuotaFiles)
		assert.Greater(t, user.UsedQuotaSize, testFileSize1)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name:
		folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestSSHCommandMaxTransfers checks that the per-host transfer limit is
// enforced for SSH/SFTP sessions: with MaxPerHostConnections set to 2, a
// third concurrent file create must fail.
func TestSSHCommandMaxTransfers(t *testing.T) {
	if len(gitPath) == 0 || len(sshPath) == 0 || runtime.GOOS == osWindows {
		t.Skip("git and/or ssh command not found or OS is windows, unable to execute this test")
	}
	oldValue := common.Config.MaxPerHostConnections
	common.Config.MaxPerHostConnections = 2

	usePubKey := true
	u := getTestUser(usePubKey)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	repoName := "testrepo" //nolint:goconst
	clonePath := filepath.Join(homeBasePath, repoName)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(filepath.Join(homeBasePath, repoName))
	assert.NoError(t, err)

	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		// two open uploads are allowed...
		f1, err := client.Create("file1")
		assert.NoError(t, err)
		f2, err := client.Create("file2")
		assert.NoError(t, err)
		_, err = f1.Write([]byte(" "))
		assert.NoError(t, err)
		_, err = f2.Write([]byte(" "))
		assert.NoError(t, err)

		// ...the third one must be rejected
		_, err = client.Create("file3")
		assert.Error(t, err)

		err = f1.Close()
		assert.NoError(t, err)
		err = f2.Close()
		assert.NoError(t, err)
		err = client.Close()
		assert.NoError(t, err)
		err = conn.Close()
		assert.NoError(t, err)
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)

	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(clonePath)
	assert.NoError(t, err)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())

	// restore the global setting for subsequent tests
	common.Config.MaxPerHostConnections = oldValue
}

// Start SCP tests

// TestSCPBasicHandling covers plain upload/download via scp for both a local
// and an SFTP-backed user, including quota and first upload/download
// timestamp tracking.
func TestSCPBasicHandling(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey
:= true
	u := getTestUser(usePubKey)
	u.QuotaSize = 6553600
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	u.QuotaSize = 6553600
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(131074)
	expectedQuotaSize := testFileSize
	expectedQuotaFiles := 1
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	// run the same scenario against the local-fs user and the SFTP-backed user
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
		remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
		localPath := filepath.Join(homeBasePath, "scp_download.dat")
		// test to download a missing file
		err = scpDownload(localPath, remoteDownPath, false, false)
		assert.Error(t, err, "downloading a missing file via scp must fail")
		// no transfer yet: both first-upload and first-download timestamps are unset
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, int64(0), user.FirstUpload)
		assert.Equal(t, int64(0), user.FirstDownload)
		err = scpUpload(testFilePath, remoteUpPath, false, false)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, user.FirstUpload, int64(0))
		assert.Equal(t, int64(0), user.FirstDownload)
		err = scpDownload(localPath, remoteDownPath, false, false)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Greater(t, user.FirstUpload, int64(0))
		assert.Greater(t, user.FirstDownload, int64(0))
		fi, err := os.Stat(localPath)
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, fi.Size())
		}
		err = os.Remove(localPath)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
		assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
		// re-create the local user so the second loop iteration (SFTP-backed
		// user) starts from a clean backend state
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}

	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
}

// TestSCPUploadFileOverwrite verifies that overwriting an existing file via
// scp keeps quota usage stable and that replacing a symlink with a real file
// is accounted correctly.
func TestSCPUploadFileOverwrite(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 1000
	localUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	u = getTestSFTPUser(usePubKey)
	u.QuotaFiles = 1000
	sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(32760)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
		err = scpUpload(testFilePath, remoteUpPath, true, false)
		assert.NoError(t, err)
		// test a new upload that must overwrite the existing file
		err = scpUpload(testFilePath, remoteUpPath, true, false)
		assert.NoError(t, err)
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize, user.UsedQuotaSize)
		assert.Equal(t, 1, user.UsedQuotaFiles)

		remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
		localPath := filepath.Join(homeBasePath, "scp_download.dat")
		err = scpDownload(localPath, remoteDownPath, false, false)
		assert.NoError(t, err)

		fi, err := os.Stat(localPath)
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, fi.Size())
		}
		// now create a symlink via SFTP, replace the symlink with a file via SCP and check quota usage
		conn, client, err := getSftpClient(user, usePubKey)
		if assert.NoError(t, err) {
			defer conn.Close()
			defer client.Close()
			err = client.Symlink(testFileName, testFileName+".link")
			assert.NoError(t, err)
			// the symlink itself does not change quota usage
			user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
			assert.NoError(t, err)
			assert.Equal(t, testFileSize, user.UsedQuotaSize)
			assert.Equal(t, 1, user.UsedQuotaFiles)
		}
		err = scpUpload(testFilePath, remoteUpPath+".link", true, false)
		assert.NoError(t, err)
		// the link was replaced by a real file: one more file, double the size
		user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
		assert.NoError(t, err)
		assert.Equal(t, testFileSize*2, user.UsedQuotaSize)
		assert.Equal(t, 2, user.UsedQuotaFiles)

		err = os.Remove(localPath)
		assert.NoError(t, err)
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPRecursive covers recursive (-r) scp upload/download of a nested
// directory tree, including overwrite and error cases.
func TestSCPRecursive(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
// NOTE(review): the lines up to the closing brace below belong to
// TestSCPRecursive, whose header (with its scp skip guard) is above.
	usePubKey := true
	localUser, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	sftpUser, _, err := httpdtest.AddUser(getTestSFTPUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	testBaseDirName := "test_dir"
	testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
	testBaseDirDownName := "test_dir_down" //nolint:goconst
	testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
	testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
	testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
	testFileSize := int64(131074)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = createTestFile(testFilePath1, testFileSize)
	assert.NoError(t, err)
	for _, user := range []dataprovider.User{localUser, sftpUser} {
		remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testBaseDirName))
		// test to download a missing dir
		err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
		assert.Error(t, err, "downloading a missing dir via scp must fail")

		remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
		err = scpUpload(testBaseDirPath, remoteUpPath, true, false)
		assert.NoError(t, err)
		// overwrite existing dir
		err = scpUpload(testBaseDirPath, remoteUpPath, true, false)
		assert.NoError(t, err)
		err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
		assert.NoError(t, err)
		// test download without passing -r
		err = scpDownload(testBaseDirDownPath, remoteDownPath, true, false)
		assert.Error(t, err, "recursive download without -r must fail")

		fi, err := os.Stat(filepath.Join(testBaseDirDownPath, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t, testFileSize, fi.Size())
		}
		fi, err = os.Stat(filepath.Join(testBaseDirDownPath, testBaseDirName, testFileName))
		if assert.NoError(t, err) {
			assert.Equal(t,
				testFileSize, fi.Size())
		}
		// upload to a non existent dir
		remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/non_existent_dir")
		err = scpUpload(testBaseDirPath, remoteUpPath, true, false)
		assert.Error(t, err, "uploading via scp to a non existent dir must fail")

		err = os.RemoveAll(testBaseDirDownPath)
		assert.NoError(t, err)
		// re-create the local user so the SFTP-backed iteration starts clean
		if user.Username == defaultUsername {
			err = os.RemoveAll(user.GetHomeDir())
			assert.NoError(t, err)
			_, err = httpdtest.RemoveUser(user, http.StatusOK)
			assert.NoError(t, err)
			user.Password = defaultPassword
			user.ID = 0
			user.CreatedAt = 0
			_, resp, err := httpdtest.AddUser(user, http.StatusCreated)
			assert.NoError(t, err, string(resp))
		}
	}

	err = os.RemoveAll(testBaseDirPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(sftpUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(localUser.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(localUser, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPStartDirectory verifies that scp transfers with a relative target
// are resolved against the user's configured start directory.
func TestSCPStartDirectory(t *testing.T) {
	// FIX: skip when scp is unavailable, consistently with every other SCP
	// test in this file; previously the test hard-failed on hosts without scp.
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	startDir := "/sta rt/dir"
	u := getTestUser(usePubKey)
	u.Filters.StartDirectory = startDir
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	testFileSize := int64(131072)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	localPath := filepath.Join(homeBasePath, "scp_download.dat")
	// note: no explicit remote path -> the start directory must be used
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:", user.Username)
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, testFileName)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.NoError(t, err)
	err = scpDownload(localPath, remoteDownPath, false, false)
	assert.NoError(t, err)
	// check that the file is in the start directory
	_, err = os.Stat(filepath.Join(user.HomeDir, startDir, testFileName))
	assert.NoError(t, err)

	err = os.Remove(testFilePath)
assert.NoError(t, err) + err = os.Remove(localPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSCPPatternsFilter(t *testing.T) { + if scpPath == "" { + t.Skip("scp command not found, unable to execute this test") + } + usePubKey := true + u := getTestUser(usePubKey) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFileSize := int64(131072) + testFilePath := filepath.Join(homeBasePath, testFileName) + localPath := filepath.Join(homeBasePath, "scp_download.dat") + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.NoError(t, err) + user.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/", + AllowedPatterns: []string{"*.zip"}, + DeniedPatterns: []string{}, + }, + } + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + err = scpDownload(localPath, remoteDownPath, false, false) + assert.Error(t, err, "scp download must fail") + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.Error(t, err, "scp upload must fail") + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = os.Stat(localPath) + if err == nil { + err = os.Remove(localPath) + assert.NoError(t, err) + } + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSCPTransferQuotaLimits(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.DownloadDataTransfer = 1 + u.UploadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + localDownloadPath := 
filepath.Join(homeBasePath, testDLFileName) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(550000) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.NoError(t, err) + err = scpDownload(localDownloadPath, remoteDownPath, false, false) + assert.NoError(t, err) + // error while download is active + err = scpDownload(localDownloadPath, remoteDownPath, false, false) + assert.Error(t, err) + // error before starting the download + err = scpDownload(localDownloadPath, remoteDownPath, false, false) + assert.Error(t, err) + // error while upload is active + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.Error(t, err) + // error before starting the upload + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.Error(t, err) + + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.UsedDownloadDataTransfer, int64(1024*1024)) + if !assert.Greater(t, user.UsedUploadDataTransfer, int64(1024*1024), user.UsedDownloadDataTransfer) { + printLatestLogs(30) + } + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSCPUploadMaxSize(t *testing.T) { + testFileSize := int64(65535) + usePubKey := true + u := getTestUser(usePubKey) + u.Filters.MaxUploadFileSize = testFileSize + 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + 
	assert.NoError(t, err)
	testFileSize1 := int64(131072)
	testFileName1 := "test_file1.dat"
	testFilePath1 := filepath.Join(homeBasePath, testFileName1)
	err = createTestFile(testFilePath1, testFileSize1)
	assert.NoError(t, err)
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
	// over the limit -> rejected; at the limit -> accepted
	err = scpUpload(testFilePath1, remoteUpPath, false, false)
	assert.Error(t, err)
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.NoError(t, err)
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.Remove(testFilePath1)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}

// TestSCPVirtualFolders checks recursive scp upload/download targeting a
// virtual folder mount point.
func TestSCPVirtualFolders(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	mappedPath := filepath.Join(os.TempDir(), "vdir")
	folderName := filepath.Base(mappedPath)
	vdirPath := "/vdir"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
	})
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath, os.ModePerm)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	// local nested tree to upload: test_dir/test_dir/<file>
	testBaseDirName := "test_dir"
	testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
	testBaseDirDownName := "test_dir_down"
	testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
	testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
	testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
	testFileSize := int64(131074)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t,
		err)
	err = createTestFile(testFilePath1, testFileSize)
	assert.NoError(t, err)
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, vdirPath)
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, vdirPath)
	err = scpUpload(testBaseDirPath, remoteUpPath, true, false)
	assert.NoError(t, err)
	err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
	assert.NoError(t, err)

	// cleanup
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(testBaseDirPath)
	assert.NoError(t, err)
	err = os.RemoveAll(testBaseDirDownPath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}

// TestSCPNestedFolders covers recursive scp with nested virtual folders
// backed by different providers (SFTP and crypt) under the same /vdir tree.
func TestSCPNestedFolders(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	// base user acts as the remote endpoint for the SFTP-backed folder
	baseUser, resp, err := httpdtest.AddUser(getTestUser(false), http.StatusCreated)
	assert.NoError(t, err, string(resp))
	usePubKey := true
	u := getTestUser(usePubKey)
	u.HomeDir += "_folders"
	u.Username += "_folders"
	mappedPathSFTP := filepath.Join(os.TempDir(), "sftp")
	folderNameSFTP := filepath.Base(mappedPathSFTP)
	vdirSFTPPath := "/vdir/sftp"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameSFTP,
		},
		VirtualPath: vdirSFTPPath,
	})
	mappedPathCrypt := filepath.Join(os.TempDir(), "crypt")
	folderNameCrypt := filepath.Base(mappedPathCrypt)
	vdirCryptPath := "/vdir/crypt"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderNameCrypt,
		},
		VirtualPath: vdirCryptPath,
	})

	f1 := vfs.BaseVirtualFolder{
		Name: folderNameSFTP,
		FsConfig: vfs.Filesystem{
			Provider: sdk.SFTPFilesystemProvider,
			SFTPConfig:
vfs.SFTPFsConfig{
				BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{
					Endpoint: sftpServerAddr,
					Username: baseUser.Username,
				},
				Password: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err = httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name: folderNameCrypt,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
		MappedPath: mappedPathCrypt,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)

	user, resp, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err, string(resp))
	// build a local tree with one file per folder backend plus one in /vdir
	baseDirDownPath := filepath.Join(os.TempDir(), "basedir-down")
	err = os.Mkdir(baseDirDownPath, os.ModePerm)
	assert.NoError(t, err)
	baseDir := filepath.Join(os.TempDir(), "basedir")
	err = os.Mkdir(baseDir, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(filepath.Join(baseDir, vdirSFTPPath), os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(filepath.Join(baseDir, vdirCryptPath), os.ModePerm)
	assert.NoError(t, err)
	// distinct sizes so each backend's file can be identified after upload
	err = createTestFile(filepath.Join(baseDir, vdirSFTPPath, testFileName), 32768)
	assert.NoError(t, err)
	err = createTestFile(filepath.Join(baseDir, vdirCryptPath, testFileName), 65535)
	assert.NoError(t, err)
	err = createTestFile(filepath.Join(baseDir, "vdir", testFileName), 65536)
	assert.NoError(t, err)

	remoteRootPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
	err = scpUpload(filepath.Join(baseDir, "vdir"), remoteRootPath, true, false)
	assert.NoError(t, err)

	// verify via SFTP that every file landed in the expected backend
	conn, client, err := getSftpClient(user, usePubKey)
	if assert.NoError(t, err) {
		defer conn.Close()
		defer client.Close()
		info, err := client.Stat(path.Join(vdirCryptPath, testFileName))
		assert.NoError(t, err)
		assert.Equal(t, int64(65535), info.Size())
		info, err = client.Stat(path.Join(vdirSFTPPath, testFileName))
		assert.NoError(t, err)
		assert.Equal(t, int64(32768), info.Size())
		info, err = client.Stat(path.Join("/vdir", testFileName))
		assert.NoError(t, err)
		assert.Equal(t, int64(65536), info.Size())
	}

	err = scpDownload(baseDirDownPath, remoteRootPath, true, true)
	assert.NoError(t, err)

	assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, "vdir", testFileName))
	assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, vdirCryptPath, testFileName))
	assert.FileExists(t, filepath.Join(baseDirDownPath, user.Username, vdirSFTPPath, testFileName))

	if runtime.GOOS != osWindows {
		// an unreadable file in the SFTP-backed folder must break the download
		err = os.Chmod(filepath.Join(baseUser.GetHomeDir(), testFileName), 0001)
		assert.NoError(t, err)
		err = scpDownload(baseDirDownPath, remoteRootPath, true, true)
		assert.Error(t, err)
		err = os.Chmod(filepath.Join(baseUser.GetHomeDir(), testFileName), os.ModePerm)
		assert.NoError(t, err)
	}

	// now change the password for the base user, so SFTP folder will not work
	baseUser.Password = defaultPassword + "_mod"
	_, _, err = httpdtest.UpdateUser(baseUser, http.StatusOK, "1")
	assert.NoError(t, err)

	err = scpUpload(filepath.Join(baseDir, "vdir"), remoteRootPath, true, false)
	assert.Error(t, err)

	err = scpDownload(baseDirDownPath, remoteRootPath, true, true)
	assert.Error(t, err)

	// cleanup
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameSFTP}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(baseUser, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(baseUser.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathCrypt)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPathSFTP)
	assert.NoError(t, err)
	err =
		os.RemoveAll(baseDir)
	assert.NoError(t, err)
	err = os.RemoveAll(baseDirDownPath)
	assert.NoError(t, err)
}

// TestSCPVirtualFoldersQuota checks quota accounting when recursively
// uploading via scp into two virtual folders: one delegating quota to the
// user (-1/-1) and one with unlimited dedicated quota (0/0).
func TestSCPVirtualFoldersQuota(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.QuotaFiles = 100
	mappedPath1 := filepath.Join(os.TempDir(), "vdir1")
	folderName1 := filepath.Base(mappedPath1)
	vdirPath1 := "/vdir1"
	mappedPath2 := filepath.Join(os.TempDir(), "vdir2")
	folderName2 := filepath.Base(mappedPath2)
	vdirPath2 := "/vdir2"
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName1,
		},
		VirtualPath: vdirPath1,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName2,
		},
		VirtualPath: vdirPath2,
		QuotaFiles:  0,
		QuotaSize:   0,
	})
	f1 := vfs.BaseVirtualFolder{
		Name:       folderName1,
		MappedPath: mappedPath1,
	}
	_, _, err := httpdtest.AddFolder(f1, http.StatusCreated)
	assert.NoError(t, err)
	f2 := vfs.BaseVirtualFolder{
		Name:       folderName2,
		MappedPath: mappedPath2,
	}
	_, _, err = httpdtest.AddFolder(f2, http.StatusCreated)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath1, os.ModePerm)
	assert.NoError(t, err)
	err = os.MkdirAll(mappedPath2, os.ModePerm)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testBaseDirName := "test_dir"
	testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
	testBaseDirDownName := "test_dir_down"
	testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
	testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
	testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
	testFileSize := int64(131074)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = createTestFile(testFilePath1, testFileSize)
	assert.NoError(t, err)
	remoteDownPath1 := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", vdirPath1))
	remoteUpPath1 := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, vdirPath1)
	remoteDownPath2 := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", vdirPath2))
	remoteUpPath2 := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, vdirPath2)
	// we upload two times to test overwrite
	err = scpUpload(testBaseDirPath, remoteUpPath1, true, false)
	assert.NoError(t, err)
	err = scpDownload(testBaseDirDownPath, remoteDownPath1, true, true)
	assert.NoError(t, err)
	err = scpUpload(testBaseDirPath, remoteUpPath1, true, false)
	assert.NoError(t, err)
	err = scpDownload(testBaseDirDownPath, remoteDownPath1, true, true)
	assert.NoError(t, err)
	err = scpUpload(testBaseDirPath, remoteUpPath2, true, false)
	assert.NoError(t, err)
	err = scpDownload(testBaseDirDownPath, remoteDownPath2, true, true)
	assert.NoError(t, err)
	// vdir1 delegates quota to the user, vdir2 tracks its own usage
	expectedQuotaFiles := 2
	expectedQuotaSize := testFileSize * 2
	user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles)
	assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize)
	f, _, err := httpdtest.GetFolderByName(folderName1, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, int64(0), f.UsedQuotaSize)
	assert.Equal(t, 0, f.UsedQuotaFiles)
	f, _, err = httpdtest.GetFolderByName(folderName2, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, expectedQuotaSize, f.UsedQuotaSize)
	assert.Equal(t, expectedQuotaFiles, f.UsedQuotaFiles)

	// cleanup
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName1}, http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName2}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(testBaseDirPath)
	assert.NoError(t, err)
	err = os.RemoveAll(testBaseDirDownPath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath1)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath2)
	assert.NoError(t, err)
}

// TestSCPPermsSubDirs checks scp downloads from a sub directory where the
// user has list/upload but no download permission, plus system-permission
// failures on non-Windows platforms.
func TestSCPPermsSubDirs(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Permissions["/"] = []string{dataprovider.PermAny}
	u.Permissions["/somedir"] = []string{dataprovider.PermListItems, dataprovider.PermUpload}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	localPath := filepath.Join(homeBasePath, "scp_download.dat")
	subPath := filepath.Join(user.GetHomeDir(), "somedir")
	testFileSize := int64(65535)
	err = os.MkdirAll(subPath, os.ModePerm)
	assert.NoError(t, err)
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/somedir")
	err = scpDownload(localPath, remoteDownPath, false, true)
	assert.Error(t, err, "download a dir with no permissions must fail")
	// replace the dir with a file of the same name and retry as a plain download
	err = os.Remove(subPath)
	assert.NoError(t, err)
	err = createTestFile(subPath, testFileSize)
	assert.NoError(t, err)
	err = scpDownload(localPath, remoteDownPath, false, false)
	assert.NoError(t, err)
	if runtime.GOOS != osWindows {
		err = os.Chmod(subPath, 0001)
		assert.NoError(t, err)
		err = scpDownload(localPath, remoteDownPath, false, false)
		assert.Error(t, err, "download a file with no system permissions must fail")
		err = os.Chmod(subPath, os.ModePerm)
		assert.NoError(t, err)
	}
	err = os.Remove(localPath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPPermCreateDirs checks that scp uploads fail when the user lacks the
// create-dirs permission and the target directory does not exist.
func TestSCPPermCreateDirs(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(32760)
	testBaseDirName := "test_dir"
	testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
	testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testFileName)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = createTestFile(testFilePath1, testFileSize)
	assert.NoError(t, err)
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp/")
	err = scpUpload(testFilePath, remoteUpPath, true, false)
	assert.Error(t, err, "scp upload must fail, the user cannot create files in a missing dir")
	err = scpUpload(testBaseDirPath, remoteUpPath, true, false)
	assert.Error(t, err, "scp upload must fail, the user cannot create new dirs")

	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.RemoveAll(testBaseDirPath)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}

// TestSCPPermUpload checks that scp uploads fail when the user lacks the
// upload permission.
func TestSCPPermUpload(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	u := getTestUser(usePubKey)
	u.Permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermCreateDirs}
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65536)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp")
	err = scpUpload(testFilePath, remoteUpPath, true, false)
	assert.Error(t, err, "scp upload must fail, the 
user cannot upload") + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestSCPPermOverwrite(t *testing.T) { + if scpPath == "" { + t.Skip("scp command not found, unable to execute this test") + } + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65536) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp") + err = scpUpload(testFilePath, remoteUpPath, true, false) + assert.NoError(t, err) + err = scpUpload(testFilePath, remoteUpPath, true, false) + assert.Error(t, err, "scp upload must fail, the user cannot ovewrite existing files") + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestSCPPermDownload(t *testing.T) { + if scpPath == "" { + t.Skip("scp command not found, unable to execute this test") + } + usePubKey := true + u := getTestUser(usePubKey) + u.Permissions["/"] = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65537) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + err = scpUpload(testFilePath, remoteUpPath, true, false) + assert.NoError(t, err) + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, 
path.Join("/", testFileName)) + localPath := filepath.Join(homeBasePath, "scp_download.dat") + err = scpDownload(localPath, remoteDownPath, false, false) + assert.Error(t, err, "scp download must fail, the user cannot download") + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestSCPQuotaSize(t *testing.T) { + if scpPath == "" { + t.Skip("scp command not found, unable to execute this test") + } + usePubKey := true + testFileSize := int64(65535) + u := getTestUser(usePubKey) + u.QuotaFiles = 1 + u.QuotaSize = testFileSize + 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + testFileSize1 := int64(131072) + testFileName1 := "test_file1.dat" + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + testFileSize2 := int64(32768) + testFileName2 := "test_file2.dat" + testFilePath2 := filepath.Join(homeBasePath, testFileName2) + err = createTestFile(testFilePath2, testFileSize2) + assert.NoError(t, err) + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + err = scpUpload(testFilePath, remoteUpPath, true, false) + assert.NoError(t, err) + err = scpUpload(testFilePath, remoteUpPath+".quota", true, false) + assert.Error(t, err, "user is over quota scp upload must fail") + + // now test quota limits while uploading the current file, we have 1 bytes remaining + user.QuotaSize = testFileSize + 1 + user.QuotaFiles = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + err = scpUpload(testFilePath2, remoteUpPath+".quota", true, false) + assert.Error(t, err, "user is over quota scp 
upload must fail") + // overwriting an existing file will work if the resulting size is lesser or equal than the current one + err = scpUpload(testFilePath1, remoteUpPath, true, false) + assert.Error(t, err) + err = scpUpload(testFilePath2, remoteUpPath, true, false) + assert.NoError(t, err) + err = scpUpload(testFilePath, remoteUpPath, true, false) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + err = os.Remove(testFilePath2) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestSCPEscapeHomeDir(t *testing.T) { + if scpPath == "" { + t.Skip("scp command not found, unable to execute this test") + } + usePubKey := true + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + err = os.MkdirAll(user.GetHomeDir(), os.ModePerm) + assert.NoError(t, err) + testDir := "testDir" + linkPath := filepath.Join(homeBasePath, defaultUsername, testDir) + err = os.Symlink(homeBasePath, linkPath) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDir, testDir)) + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.Error(t, err, "uploading to a dir with a symlink outside home dir must fail") + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir, testFileName)) + localPath := filepath.Join(homeBasePath, "scp_download.dat") + err = scpDownload(localPath, remoteDownPath, false, false) + assert.Error(t, err, "scp download must fail, the requested file has a symlink outside user home") + remoteDownPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, 
// TestSCPUploadPaths checks SCP uploads/downloads addressed to an existing
// remote sub-directory and confirms that uploading into a missing directory
// (without -r) fails.
func TestSCPUploadPaths(t *testing.T) {
	if scpPath == "" {
		t.Skip("scp command not found, unable to execute this test")
	}
	usePubKey := true
	user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated)
	assert.NoError(t, err)
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	testDirName := "testDir"
	testDirPath := filepath.Join(user.GetHomeDir(), testDirName)
	err = os.MkdirAll(testDirPath, os.ModePerm)
	assert.NoError(t, err)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	// relative remote path: resolved against the user's home dir
	remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, testDirName)
	remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testFileName))
	localPath := filepath.Join(homeBasePath, "scp_download.dat")
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.NoError(t, err)
	err = scpDownload(localPath, remoteDownPath, false, false)
	assert.NoError(t, err)
	// upload a file to a missing dir
	remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testDirName, testFileName))
	err = scpUpload(testFilePath, remoteUpPath, false, false)
	assert.Error(t, err, "scp upload to a missing dir must fail")

	// cleanup
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	err = os.Remove(localPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
}
test") + } + usePubKey := true + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + testDirPath := filepath.Join(user.GetHomeDir(), testFileName) + err = os.MkdirAll(testDirPath, os.ModePerm) + assert.NoError(t, err) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.Error(t, err, "copying a file over an existing dir must fail") + + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestSCPRemoteToRemote(t *testing.T) { + if scpPath == "" { + t.Skip("scp command not found, unable to execute this test") + } + if runtime.GOOS == osWindows { + t.Skip("scp between remote hosts is not supported on Windows") + } + usePubKey := true + user, _, err := httpdtest.AddUser(getTestUser(usePubKey), http.StatusCreated) + assert.NoError(t, err) + u := getTestUser(usePubKey) + u.Username += "1" + u.HomeDir += "1" + user1, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + remote1UpPath := fmt.Sprintf("%v@127.0.0.1:%v", user1.Username, path.Join("/", testFileName)) + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.NoError(t, err) + err = scpUpload(remoteUpPath, remote1UpPath, false, true) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = 
os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) +} + +func TestSCPErrors(t *testing.T) { + if scpPath == "" { + t.Skip("scp command not found, unable to execute this test") + } + u := getTestUser(true) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFileSize := int64(524288) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/") + remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName)) + localPath := filepath.Join(homeBasePath, "scp_download.dat") + err = scpUpload(testFilePath, remoteUpPath, false, false) + assert.NoError(t, err) + user.UploadBandwidth = 512 + user.DownloadBandwidth = 512 + _, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + cmd := getScpDownloadCommand(localPath, remoteDownPath, false, false) + go func() { + err := cmd.Run() + assert.Error(t, err, "SCP download must fail") + }() + waitForActiveTransfers(t) + // wait some additional arbitrary time to wait for transfer activity to happen + // it is need to reach all the code in CheckIdleConnections + time.Sleep(100 * time.Millisecond) + err = cmd.Process.Kill() + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, 2*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + cmd = getScpUploadCommand(testFilePath, remoteUpPath, false, false) + go func() { + err := cmd.Run() + assert.Error(t, err, "SCP upload must fail") + }() + waitForActiveTransfers(t) + // wait some additional arbitrary time to wait for transfer activity to happen + // it is need to reach all the code in CheckIdleConnections + time.Sleep(100 * time.Millisecond) + err = 
// waitTCPListening blocks until a TCP connection to address succeeds,
// polling every 100ms. Used to wait for the test servers to come up.
func waitTCPListening(address string) {
	for {
		conn, err := net.Dial("tcp", address)
		if err != nil {
			logger.WarnToConsole("tcp server %v not listening: %v", address, err)
			time.Sleep(100 * time.Millisecond)
			continue
		}
		logger.InfoToConsole("tcp server %v now listening", address)
		conn.Close()
		break
	}
}

// getTestGroup returns a group definition shared by the group-related tests.
func getTestGroup() dataprovider.Group {
	return dataprovider.Group{
		BaseGroup: sdk.BaseGroup{
			Name:        "test_group",
			Description: "test group description",
		},
	}
}

// getTestUser returns the default local-filesystem test user with full
// permissions on "/". With usePubKey the password is cleared and the test
// public key is installed instead.
func getTestUser(usePubKey bool) dataprovider.User {
	user := dataprovider.User{
		BaseUser: sdk.BaseUser{
			Username:       defaultUsername,
			Password:       defaultPassword,
			HomeDir:        filepath.Join(homeBasePath, defaultUsername),
			Status:         1,
			ExpirationDate: 0,
		},
	}
	user.Permissions = make(map[string][]string)
	user.Permissions["/"] = allPerms
	if usePubKey {
		user.PublicKeys = []string{testPubKey}
		user.Password = ""
	}
	return user
}

// getTestSFTPUser returns a user whose storage backend is the SFTP
// filesystem provider, chained to the default local test user.
func getTestSFTPUser(usePubKey bool) dataprovider.User {
	u := getTestUser(usePubKey)
	u.Username = defaultSFTPUsername
	u.FsConfig.Provider = sdk.SFTPFilesystemProvider
	u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr
	u.FsConfig.SFTPConfig.Username = defaultUsername
	u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword)
	if usePubKey {
		u.FsConfig.SFTPConfig.PrivateKey = kms.NewPlainSecret(testPrivateKey)
		u.FsConfig.SFTPConfig.Fingerprints = hostKeyFPs
	}
	return u
}
dataprovider.User, usePubKey bool) ([]byte, error) { + var sshSession *ssh.Session + var output []byte + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if usePubKey { + key, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + if err != nil { + return output, err + } + config.Auth = []ssh.AuthMethod{ssh.PublicKeys(key)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return output, err + } + defer conn.Close() + sshSession, err = conn.NewSession() + if err != nil { + return output, err + } + var stdout, stderr bytes.Buffer + sshSession.Stdout = &stdout + sshSession.Stderr = &stderr + err = sshSession.Run(command) + if err != nil { + return nil, fmt.Errorf("failed to run command %v: %v", command, stderr.Bytes()) + } + return stdout.Bytes(), err +} + +func getSignerForUserCert(certBytes []byte) (ssh.Signer, error) { + signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + if err != nil { + return nil, err + } + cert, _, _, _, err := ssh.ParseAuthorizedKey(certBytes) //nolint:dogsled + if err != nil { + return nil, err + } + return ssh.NewCertSigner(cert.(*ssh.Certificate), signer) +} + +func getSftpClientWithAddr(user dataprovider.User, usePubKey bool, addr string) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if usePubKey { + signer, err := ssh.ParsePrivateKey([]byte(testPrivateKey)) + if err != nil { + return nil, nil, err + } + config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)} + } else { + if user.Password != "" { + if user.Password == "empty" { + config.Auth = []ssh.AuthMethod{ssh.Password("")} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } + } else { + config.Auth = 
[]ssh.AuthMethod{ssh.Password(defaultPassword)} + } + } + conn, err := ssh.Dial("tcp", addr, config) + if err != nil { + return conn, sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + +func getSftpClient(user dataprovider.User, usePubKey bool) (*ssh.Client, *sftp.Client, error) { + return getSftpClientWithAddr(user, usePubKey, sftpServerAddr) +} + +func getKeyboardInteractiveSftpClient(user dataprovider.User, answers []string) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Auth: []ssh.AuthMethod{ + ssh.KeyboardInteractive(func(_, _ string, _ []string, _ []bool) ([]string, error) { + return answers, nil + }), + }, + Timeout: 5 * time.Second, + } + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return nil, sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + +func getCustomAuthSftpClient(user dataprovider.User, authMethods []ssh.AuthMethod, addr string) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Auth: authMethods, + Timeout: 5 * time.Second, + } + var err error + var conn *ssh.Client + if addr != "" { + conn, err = ssh.Dial("tcp", addr, config) + } else { + conn, err = ssh.Dial("tcp", sftpServerAddr, config) + } + if err != nil { + return conn, sftpClient, err + } + sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + +func createTestFile(path string, size int64) error { + baseDir := filepath.Dir(path) + if _, err := os.Stat(baseDir); errors.Is(err, fs.ErrNotExist) { + err = os.MkdirAll(baseDir, os.ModePerm) + if err != nil { + return err + } + } + content := make([]byte, 
size) + _, err := rand.Read(content) + if err != nil { + return err + } + return os.WriteFile(path, content, os.ModePerm) +} + +func appendToTestFile(path string, size int64) error { + content := make([]byte, size) + _, err := rand.Read(content) + if err != nil { + return err + } + f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, os.ModePerm) + if err != nil { + return err + } + defer f.Close() + written, err := io.Copy(f, bytes.NewReader(content)) + if err != nil { + return err + } + if written != size { + return fmt.Errorf("write error, written: %v/%v", written, size) + } + return nil +} + +func checkBasicSFTP(client *sftp.Client) error { + _, err := client.Getwd() + if err != nil { + return err + } + _, err = client.ReadDir(".") + return err +} + +func writeSFTPFile(name string, size int64, client *sftp.Client) error { + content := make([]byte, size) + _, err := rand.Read(content) + if err != nil { + return err + } + f, err := client.Create(name) + if err != nil { + return err + } + _, err = io.Copy(f, bytes.NewBuffer(content)) + if err != nil { + f.Close() + return err + } + err = f.Close() + if err != nil { + return err + } + info, err := client.Stat(name) + if err != nil { + return err + } + if info.Size() != size { + return fmt.Errorf("file size mismatch, wanted %v, actual %v", size, info.Size()) + } + return nil +} + +func sftpUploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) error { + srcFile, err := os.Open(localSourcePath) + if err != nil { + return err + } + defer srcFile.Close() + destFile, err := client.Create(remoteDestPath) + if err != nil { + return err + } + _, err = io.Copy(destFile, srcFile) + if err != nil { + destFile.Close() + return err + } + // we need to close the file to trigger the server side close method + // we cannot defer closing otherwise Stat will fail for upload atomic mode + destFile.Close() + if expectedSize > 0 { + fi, err := client.Stat(remoteDestPath) + if err != nil { + 
return err + } + if fi.Size() != expectedSize { + return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize) + } + } + return err +} + +func sftpUploadResumeFile(localSourcePath string, remoteDestPath string, expectedSize int64, invalidOffset bool, //nolint:unparam + client *sftp.Client) error { + srcFile, err := os.Open(localSourcePath) + if err != nil { + return err + } + defer srcFile.Close() + fi, err := client.Lstat(remoteDestPath) + if err != nil { + return err + } + if !invalidOffset { + _, err = srcFile.Seek(fi.Size(), 0) + if err != nil { + return err + } + } + destFile, err := client.OpenFile(remoteDestPath, os.O_WRONLY|os.O_APPEND) + if err != nil { + return err + } + if !invalidOffset { + _, err = destFile.Seek(fi.Size(), 0) + if err != nil { + return err + } + } + _, err = io.Copy(destFile, srcFile) + if err != nil { + destFile.Close() + return err + } + // we need to close the file to trigger the server side close method + // we cannot defer closing otherwise Stat will fail for upload atomic mode + destFile.Close() + if expectedSize > 0 { + fi, err := client.Lstat(remoteDestPath) + if err != nil { + return err + } + if fi.Size() != expectedSize { + return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize) + } + } + return err +} + +func sftpDownloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) error { + downloadDest, err := os.Create(localDestPath) + if err != nil { + return err + } + defer downloadDest.Close() + sftpSrcFile, err := client.Open(remoteSourcePath) + if err != nil { + return err + } + defer sftpSrcFile.Close() + _, err = io.Copy(downloadDest, sftpSrcFile) + if err != nil { + return err + } + err = downloadDest.Sync() + if err != nil { + return err + } + if expectedSize > 0 { + fi, err := downloadDest.Stat() + if err != nil { + return err + } + if fi.Size() != expectedSize { + return 
fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", fi.Size(), expectedSize) + } + } + return err +} + +func sftpUploadNonBlocking(localSourcePath string, remoteDestPath string, expectedSize int64, client *sftp.Client) <-chan error { + c := make(chan error, 1) + go func() { + c <- sftpUploadFile(localSourcePath, remoteDestPath, expectedSize, client) + }() + return c +} + +func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expectedSize int64, client *sftp.Client) <-chan error { + c := make(chan error, 1) + go func() { + c <- sftpDownloadFile(remoteSourcePath, localDestPath, expectedSize, client) + }() + return c +} + +func scpUpload(localPath, remotePath string, preserveTime, remoteToRemote bool) error { + cmd := getScpUploadCommand(localPath, remotePath, preserveTime, remoteToRemote) + return cmd.Run() +} + +func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error { + cmd := getScpDownloadCommand(localPath, remotePath, preserveTime, recursive) + return cmd.Run() +} + +func getScpDownloadCommand(localPath, remotePath string, preserveTime, recursive bool) *exec.Cmd { + var args []string + if preserveTime { + args = append(args, "-p") + } + if recursive { + args = append(args, "-r") + } + if scpForce { + args = append(args, "-O") + } + args = append(args, "-P") + args = append(args, "2022") + args = append(args, "-o") + args = append(args, "StrictHostKeyChecking=no") + args = append(args, "-i") + args = append(args, privateKeyPath) + args = append(args, remotePath) + args = append(args, localPath) + return exec.Command(scpPath, args...) 
+} + +func getScpUploadCommand(localPath, remotePath string, preserveTime, remoteToRemote bool) *exec.Cmd { + var args []string + if remoteToRemote { + args = append(args, "-3") + } + if preserveTime { + args = append(args, "-p") + } + fi, err := os.Stat(localPath) + if err == nil { + if fi.IsDir() { + args = append(args, "-r") + } + } + if scpForce { + args = append(args, "-O") + } + args = append(args, "-P") + args = append(args, "2022") + args = append(args, "-o") + args = append(args, "StrictHostKeyChecking=no") + args = append(args, "-o") + args = append(args, "HostKeyAlgorithms=+ssh-rsa") + args = append(args, "-i") + args = append(args, privateKeyPath) + args = append(args, localPath) + args = append(args, remotePath) + return exec.Command(scpPath, args...) +} + +func computeHashForFile(hasher hash.Hash, path string) (string, error) { + hash := "" + f, err := os.Open(path) + if err != nil { + return hash, err + } + defer f.Close() + _, err = io.Copy(hasher, f) + if err == nil { + hash = fmt.Sprintf("%x", hasher.Sum(nil)) + } + return hash, err +} + +func waitForActiveTransfers(t *testing.T) { + assert.Eventually(t, func() bool { + for _, stat := range common.Connections.GetStats("") { + if len(stat.Transfers) > 0 { + return true + } + } + return false + }, 1*time.Second, 50*time.Millisecond) +} + +func checkSystemCommands() { + var err error + gitPath, err = exec.LookPath("git") + if err != nil { + logger.Warn(logSender, "", "unable to get git command. GIT tests will be skipped, err: %v", err) + logger.WarnToConsole("unable to get git command. GIT tests will be skipped, err: %v", err) + gitPath = "" + } + + sshPath, err = exec.LookPath("ssh") + if err != nil { + logger.Warn(logSender, "", "unable to get ssh command. GIT tests will be skipped, err: %v", err) + logger.WarnToConsole("unable to get ssh command. 
GIT tests will be skipped, err: %v", err) + gitPath = "" + } + hookCmdPath, err = exec.LookPath("true") + if err != nil { + logger.Warn(logSender, "", "unable to get hook command: %v", err) + logger.WarnToConsole("unable to get hook command: %v", err) + } + scpPath, err = exec.LookPath("scp") + if err != nil { + logger.Warn(logSender, "", "unable to get scp command. SCP tests will be skipped, err: %v", err) + logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err) + scpPath = "" + } else { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, scpPath, "-O") + out, _ := cmd.CombinedOutput() + scpForce = !strings.Contains(string(out), "option -- O") + } +} + +func getKeyboardInteractiveScriptForBuiltinChecks(addPasscode bool, result int) []byte { + content := []byte("#!/bin/sh\n\n") + echos := []bool{false} + q, _ := json.Marshal([]string{"Password: "}) + e, _ := json.Marshal(echos) + content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v,\"check_password\":1}'\n", string(q), string(e)))...) + content = append(content, []byte("read ANSWER\n\n")...) + content = append(content, []byte("if test \"$ANSWER\" != \"OK\"; then\n")...) + content = append(content, []byte("exit 1\n")...) + content = append(content, []byte("fi\n\n")...) + if addPasscode { + q, _ := json.Marshal([]string{"Passcode: "}) + content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v,\"check_password\":2}'\n", string(q), string(e)))...) + content = append(content, []byte("read ANSWER\n\n")...) + content = append(content, []byte("if test \"$ANSWER\" != \"OK\"; then\n")...) + content = append(content, []byte("exit 1\n")...) + content = append(content, []byte("fi\n\n")...) + } + content = append(content, []byte(fmt.Sprintf("echo '{\"auth_result\":%v}'\n", result))...) 
// getKeyboardInteractiveScriptContent builds a shell script emulating a
// keyboard-interactive auth hook: it prints the questions/echos JSON, reads
// one answer per question, optionally sleeps, then prints the auth result.
func getKeyboardInteractiveScriptContent(questions []string, sleepTime int, nonJSONResponse bool, result int) []byte {
	content := []byte("#!/bin/sh\n\n")
	q, _ := json.Marshal(questions)
	echos := []bool{}
	for index := range questions {
		// alternate echo on/off per question
		echos = append(echos, index%2 == 0)
	}
	e, _ := json.Marshal(echos)
	if nonJSONResponse {
		// deliberately malformed output to exercise the error path
		content = append(content, []byte(fmt.Sprintf("echo 'questions: %v echos: %v\n", string(q), string(e)))...)
	} else {
		content = append(content, []byte(fmt.Sprintf("echo '{\"questions\":%v,\"echos\":%v}'\n", string(q), string(e)))...)
	}
	for index := range questions {
		content = append(content, []byte(fmt.Sprintf("read ANSWER%v\n", index))...)
	}
	if sleepTime > 0 {
		content = append(content, []byte(fmt.Sprintf("sleep %v\n", sleepTime))...)
	}
	content = append(content, []byte(fmt.Sprintf("echo '{\"auth_result\":%v}'\n", result))...)
	return content
}

// getExtAuthScriptContent builds a shell script emulating an external auth
// hook: it echoes the user JSON when SFTPGO_AUTHD_USERNAME matches, an empty
// username JSON otherwise. nonJSONResponse/emptyResponse exercise error paths.
func getExtAuthScriptContent(user dataprovider.User, nonJSONResponse, emptyResponse bool, username string) []byte {
	extAuthContent := []byte("#!/bin/sh\n\n")
	if emptyResponse {
		return extAuthContent
	}
	extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("if test \"$SFTPGO_AUTHD_USERNAME\" = \"%v\"; then\n", user.Username))...)
	if username != "" {
		// return a different username than the authenticated one
		user.Username = username
	}
	u, _ := json.Marshal(user)
	if nonJSONResponse {
		extAuthContent = append(extAuthContent, []byte("echo 'text response'\n")...)
	} else {
		extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...)
	}
	extAuthContent = append(extAuthContent, []byte("else\n")...)
	if nonJSONResponse {
		extAuthContent = append(extAuthContent, []byte("echo 'text response'\n")...)
	} else {
		extAuthContent = append(extAuthContent, []byte("echo '{\"username\":\"\"}'\n")...)
	}
	extAuthContent = append(extAuthContent, []byte("fi\n")...)
	return extAuthContent
}

// getPreLoginScriptContent builds a shell script emulating a pre-login hook
// that echoes the given user as JSON (or a non-JSON line on request).
func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte {
	content := []byte("#!/bin/sh\n\n")
	if nonJSONResponse {
		content = append(content, []byte("echo 'text response'\n")...)
		return content
	}
	if len(user.Username) > 0 {
		u, _ := json.Marshal(user)
		content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...)
	}
	return content
}

// getExitCodeScriptContent builds a shell script that exits with exitCode.
func getExitCodeScriptContent(exitCode int) []byte {
	content := []byte("#!/bin/sh\n\n")
	content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...)
	return content
}

// getCheckPwdScriptsContents builds a shell script emulating a check-password
// hook: it prints the status/to_verify JSON and exits 0 on success, 1 otherwise.
func getCheckPwdScriptsContents(status int, toVerify string) []byte {
	content := []byte("#!/bin/sh\n\n")
	content = append(content, []byte(fmt.Sprintf("echo '{\"status\":%v,\"to_verify\":\"%v\"}'\n", status, toVerify))...)
	if status > 0 {
		content = append(content, []byte("exit 0")...)
	} else {
		content = append(content, []byte("exit 1")...)
	}
	return content
}

// printLatestLogs dumps up to maxNumberOfLines of the tail of the test log
// file to the console, for post-mortem debugging of failed runs.
func printLatestLogs(maxNumberOfLines int) {
	var lines []string
	f, err := os.Open(logFilePath)
	if err != nil {
		return
	}
	defer f.Close()
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		lines = append(lines, scanner.Text()+"\r\n")
		// keep only the trailing window
		for len(lines) > maxNumberOfLines {
			lines = lines[1:]
		}
	}
	if scanner.Err() != nil {
		logger.WarnToConsole("Unable to print latest logs: %v", scanner.Err())
		return
	}
	for _, line := range lines {
		logger.DebugToConsole("%s", line)
	}
}

// getHostKeyFingerprint parses the private key file at name and returns the
// SHA256 fingerprint of its public key.
func getHostKeyFingerprint(name string) (string, error) {
	privateBytes, err := os.ReadFile(name)
	if err != nil {
		return "", err
	}

	private, err := ssh.ParsePrivateKey(privateBytes)
	if err != nil {
		return "", err
	}
	return ssh.FingerprintSHA256(private.PublicKey()), nil
}
fingerprint for host key %q: %v", k, err) + os.Exit(1) + } + hostKeyFPs = append(hostKeyFPs, fp) + } +} + +func createInitialFiles(scriptArgs string) { + pubKeyPath = filepath.Join(homeBasePath, "ssh_key.pub") + privateKeyPath = filepath.Join(homeBasePath, "ssh_key") + trustedCAUserKey = filepath.Join(homeBasePath, "ca_user_key") + gitWrapPath = filepath.Join(homeBasePath, "gitwrap.sh") + extAuthPath = filepath.Join(homeBasePath, "extauth.sh") + preLoginPath = filepath.Join(homeBasePath, "prelogin.sh") + postConnectPath = filepath.Join(homeBasePath, "postconnect.sh") + checkPwdPath = filepath.Join(homeBasePath, "checkpwd.sh") + preDownloadPath = filepath.Join(homeBasePath, "predownload.sh") + preUploadPath = filepath.Join(homeBasePath, "preupload.sh") + revokeUserCerts = filepath.Join(homeBasePath, "revoked_certs.json") + err := os.WriteFile(pubKeyPath, []byte(testPubKey+"\n"), 0600) + if err != nil { + logger.WarnToConsole("unable to save public key to file: %v", err) + } + err = os.WriteFile(privateKeyPath, []byte(testPrivateKey+"\n"), 0600) + if err != nil { + logger.WarnToConsole("unable to save private key to file: %v", err) + } + err = os.WriteFile(gitWrapPath, []byte(fmt.Sprintf("%v -i %v -oStrictHostKeyChecking=no %v\n", + sshPath, privateKeyPath, scriptArgs)), os.ModePerm) + if err != nil { + logger.WarnToConsole("unable to save gitwrap shell script: %v", err) + } + err = os.WriteFile(trustedCAUserKey, []byte(testCAUserKey), 0600) + if err != nil { + logger.WarnToConsole("unable to save trusted CA user key: %v", err) + } + err = os.WriteFile(revokeUserCerts, []byte(`[]`), 0644) + if err != nil { + logger.WarnToConsole("unable to save revoked user certs: %v", err) + } +} diff --git a/internal/sftpd/ssh_cmd.go b/internal/sftpd/ssh_cmd.go new file mode 100644 index 00000000..c82f3b89 --- /dev/null +++ b/internal/sftpd/ssh_cmd.go @@ -0,0 +1,325 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify 
+// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package sftpd + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "fmt" + "hash" + "io" + "runtime/debug" + "slices" + "strings" + "time" + + "github.com/google/shlex" + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/metric" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + scpCmdName = "scp" + sshCommandLogSender = "SSHCommand" +) + +type sshCommand struct { + command string + args []string + connection *Connection + startTime time.Time +} + +func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommands []string) bool { + var msg sshSubsystemExecMsg + if err := ssh.Unmarshal(payload, &msg); err == nil { + name, args, err := parseCommandPayload(msg.Command) + connection.Log(logger.LevelDebug, "new ssh command: %q args: %v num args: %d user: %s, error: %v", + name, args, len(args), connection.User.Username, err) + if err == nil && slices.Contains(enabledSSHCommands, name) { + connection.command = msg.Command + if name == scpCmdName && len(args) >= 2 { + connection.SetProtocol(common.ProtocolSCP) + scpCommand := scpCommand{ + sshCommand: sshCommand{ + command: name, + connection: connection, + startTime: time.Now(), + args: args}, + } + go scpCommand.handle() //nolint:errcheck 
			return true
		}
		if name != scpCmdName {
			connection.SetProtocol(common.ProtocolSSH)
			sshCommand := sshCommand{
				command:    name,
				connection: connection,
				startTime:  time.Now(),
				args:       args,
			}
			go sshCommand.handle() //nolint:errcheck
			return true
		}
	} else {
		connection.Log(logger.LevelInfo, "ssh command not enabled/supported: %q", name)
	}
	err := connection.CloseFS()
	// NOTE(review): this message mentions unmarshalling, but this path is also
	// reached when the command is simply not enabled; err here is the CloseFS
	// error, not the unmarshal error — consider rewording upstream.
	connection.Log(logger.LevelError, "unable to unmarshal ssh command, close fs, err: %v", err)
	return false
}

// handle executes the parsed SSH command for this connection. It registers
// the connection, dispatches to the per-command handler and reports the exit
// status back to the client. Any panic is recovered and converted into a
// generic failure so a single command cannot take down the service.
func (c *sshCommand) handle() (err error) {
	defer func() {
		if r := recover(); r != nil {
			logger.Error(logSender, "", "panic in handle ssh command: %q stack trace: %v", r, string(debug.Stack()))
			err = common.ErrGenericFailure
		}
	}()
	// register the connection so it is tracked/limited like any other; on
	// failure the filesystem must still be released
	if err := common.Connections.Add(c.connection); err != nil {
		defer c.connection.CloseFS() //nolint:errcheck
		logger.Info(logSender, "", "unable to add SSH command connection: %v", err)
		return c.sendErrorResponse(err)
	}
	defer common.Connections.Remove(c.connection.GetID())

	c.connection.UpdateLastActivity()
	if slices.Contains(sshHashCommands, c.command) {
		return c.handleHashCommands()
	} else if c.command == "cd" {
		// "cd" is a no-op: just report success
		c.sendExitStatus(nil)
	} else if c.command == "pwd" {
		// hard coded response to the start directory
		c.connection.channel.Write([]byte(util.CleanPath(c.connection.User.Filters.StartDirectory) + "\n")) //nolint:errcheck
		c.sendExitStatus(nil)
	} else if c.command == "sftpgo-copy" {
		return c.handleSFTPGoCopy()
	} else if c.command == "sftpgo-remove" {
		return c.handleSFTPGoRemove()
	}
	return
}

// handleSFTPGoCopy implements the custom "sftpgo-copy <source> <destination>"
// command: it validates the two path arguments and delegates to the
// connection's Copy.
func (c *sshCommand) handleSFTPGoCopy() error {
	sshSourcePath := c.getSourcePath()
	sshDestPath := c.getDestPath()
	if sshSourcePath == "" || sshDestPath == "" || len(c.args) != 2 {
		// NOTE(review): the usage text appears truncated by extraction; upstream
		// presumably includes the placeholder argument names — verify.
		return c.sendErrorResponse(errors.New("usage sftpgo-copy "))
	}
	c.connection.Log(logger.LevelDebug, "requested copy %q -> %q", sshSourcePath, sshDestPath)
	if err := c.connection.Copy(sshSourcePath,
sshDestPath); err != nil { + return c.sendErrorResponse(err) + } + c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck + c.sendExitStatus(nil) + return nil +} + +func (c *sshCommand) handleSFTPGoRemove() error { + sshDestPath, err := c.getRemovePath() + if err != nil { + return c.sendErrorResponse(err) + } + if err := c.connection.RemoveAll(sshDestPath); err != nil { + return c.sendErrorResponse(err) + } + c.connection.channel.Write([]byte("OK\n")) //nolint:errcheck + c.sendExitStatus(nil) + return nil +} + +func (c *sshCommand) handleHashCommands() error { + var h hash.Hash + switch c.command { + case "md5sum": + h = md5.New() + case "sha1sum": + h = sha1.New() + case "sha256sum": + h = sha256.New() + case "sha384sum": + h = sha512.New384() + default: + h = sha512.New() + } + var response string + if len(c.args) == 0 { + // without args we need to read the string to hash from stdin + buf := make([]byte, 4096) + n, err := c.connection.channel.Read(buf) + if err != nil && err != io.EOF { + return c.sendErrorResponse(err) + } + h.Write(buf[:n]) //nolint:errcheck + response = fmt.Sprintf("%x -\n", h.Sum(nil)) + } else { + sshPath := c.getDestPath() + if ok, policy := c.connection.User.IsFileAllowed(sshPath); !ok { + c.connection.Log(logger.LevelInfo, "hash not allowed for file %q", sshPath) + return c.sendErrorResponse(c.connection.GetErrorForDeniedFile(policy)) + } + fs, fsPath, err := c.connection.GetFsAndResolvedPath(sshPath) + if err != nil { + return c.sendErrorResponse(err) + } + if !c.connection.User.HasPerm(dataprovider.PermListItems, sshPath) { + return c.sendErrorResponse(c.connection.GetPermissionDeniedError()) + } + hash, err := c.computeHashForFile(fs, h, fsPath) + if err != nil { + return c.sendErrorResponse(c.connection.GetFsError(fs, err)) + } + response = fmt.Sprintf("%v %v\n", hash, sshPath) + } + c.connection.channel.Write([]byte(response)) //nolint:errcheck + c.sendExitStatus(nil) + return nil +} + +// for the supported commands, the 
destination path, if any, is the last argument
func (c *sshCommand) getDestPath() string {
	if len(c.args) == 0 {
		return ""
	}
	return c.cleanCommandPath(c.args[len(c.args)-1])
}

// for the supported commands, the source path, if any, is the second-last argument
func (c *sshCommand) getSourcePath() string {
	if len(c.args) < 2 {
		return ""
	}
	return c.cleanCommandPath(c.args[len(c.args)-2])
}

// cleanCommandPath strips surrounding single/double quotes from a user
// supplied path and normalizes it, preserving a trailing "/" present in the
// input (the suffix distinguishes directory targets).
func (c *sshCommand) cleanCommandPath(name string) string {
	name = strings.Trim(name, "'")
	name = strings.Trim(name, "\"")
	result := c.connection.User.GetCleanedPath(name)
	if strings.HasSuffix(name, "/") && !strings.HasSuffix(result, "/") {
		result += "/"
	}
	return result
}

// getRemovePath validates the argument for "sftpgo-remove" and returns the
// virtual path to remove, without any trailing "/" (except for the root).
func (c *sshCommand) getRemovePath() (string, error) {
	sshDestPath := c.getDestPath()
	if sshDestPath == "" || len(c.args) != 1 {
		err := errors.New("usage sftpgo-remove <destination path>")
		return "", err
	}
	if len(sshDestPath) > 1 {
		sshDestPath = strings.TrimSuffix(sshDestPath, "/")
	}
	return sshDestPath, nil
}

// sendErrorResponse writes a human readable error line to the client, then
// sends the exit status and returns the original error.
func (c *sshCommand) sendErrorResponse(err error) error {
	errorString := fmt.Sprintf("%v: %v %v\n", c.command, c.getDestPath(), err)
	c.connection.channel.Write([]byte(errorString)) //nolint:errcheck
	c.sendExitStatus(err)
	return err
}

// sendExitStatus sends the SSH "exit-status" request (0 on success, 1 on
// error), closes the channel and, for non-scp commands, executes the
// configured action notification and logs the completed command.
func (c *sshCommand) sendExitStatus(err error) {
	status := uint32(0)
	vCmdPath := c.getDestPath()
	cmdPath := ""
	targetPath := ""
	vTargetPath := ""
	if c.command == "sftpgo-copy" {
		// for copy the destination is the notification target and the source
		// becomes the primary path
		vTargetPath = vCmdPath
		vCmdPath = c.getSourcePath()
	}
	if err != nil {
		status = uint32(1)
		c.connection.Log(logger.LevelError, "command failed: %q args: %v user: %s err: %v",
			c.command, c.args, c.connection.User.Username, err)
	}
	exitStatus := sshSubsystemExitStatus{
		Status: status,
	}
	_, errClose := c.connection.channel.(ssh.Channel).SendRequest("exit-status", false, ssh.Marshal(&exitStatus))
	c.connection.Log(logger.LevelDebug, "exit status sent, error: %v", errClose)
	c.connection.channel.Close()
	// for
scp we notify single uploads/downloads + if c.command != scpCmdName { + elapsed := time.Since(c.startTime).Nanoseconds() / 1000000 + metric.SSHCommandCompleted(err) + if vCmdPath != "" { + _, p, errFs := c.connection.GetFsAndResolvedPath(vCmdPath) + if errFs == nil { + cmdPath = p + } + } + if vTargetPath != "" { + _, p, errFs := c.connection.GetFsAndResolvedPath(vTargetPath) + if errFs == nil { + targetPath = p + } + } + common.ExecuteActionNotification(c.connection.BaseConnection, common.OperationSSHCmd, cmdPath, vCmdPath, //nolint:errcheck + targetPath, vTargetPath, c.command, 0, err, elapsed, nil) + if err == nil { + logger.CommandLog(sshCommandLogSender, cmdPath, targetPath, c.connection.User.Username, "", c.connection.ID, + common.ProtocolSSH, -1, -1, "", "", c.connection.command, -1, c.connection.GetLocalAddress(), + c.connection.GetRemoteAddress(), elapsed) + } + } +} + +func (c *sshCommand) computeHashForFile(fs vfs.Fs, hasher hash.Hash, path string) (string, error) { + hash := "" + f, r, _, err := fs.Open(path, 0) + if err != nil { + return hash, err + } + var reader io.ReadCloser + if f != nil { + reader = f + } else { + reader = r + } + defer reader.Close() + _, err = io.Copy(hasher, reader) + if err == nil { + hash = fmt.Sprintf("%x", hasher.Sum(nil)) + } + return hash, err +} + +func parseCommandPayload(command string) (string, []string, error) { + parts, err := shlex.Split(command) + if err == nil && len(parts) == 0 { + err = fmt.Errorf("invalid command: %q", command) + } + if err != nil { + return "", []string{}, err + } + if len(parts) < 2 { + return parts[0], []string{}, nil + } + return parts[0], parts[1:], nil +} diff --git a/internal/sftpd/transfer.go b/internal/sftpd/transfer.go new file mode 100644 index 00000000..465ad8c3 --- /dev/null +++ b/internal/sftpd/transfer.go @@ -0,0 +1,193 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero 
General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package sftpd + +import ( + "fmt" + "io" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +type writerAtCloser interface { + io.WriterAt + io.Closer +} + +type readerAtCloser interface { + io.ReaderAt + io.Closer +} + +type failingReader struct { + innerReader readerAtCloser + errRead error +} + +func (r *failingReader) ReadAt(_ []byte, _ int64) (n int, err error) { + return 0, r.errRead +} + +func (r *failingReader) Close() error { + if r.innerReader == nil { + return nil + } + return r.innerReader.Close() +} + +// transfer defines the transfer details. 
+// It implements the io.ReaderAt and io.WriterAt interfaces to handle SFTP downloads and uploads +type transfer struct { + *common.BaseTransfer + writerAt writerAtCloser + readerAt readerAtCloser + isFinished bool +} + +func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader, + errForRead error) *transfer { + var writer writerAtCloser + var reader readerAtCloser + if baseTransfer.File != nil { + writer = baseTransfer.File + if errForRead == nil { + reader = baseTransfer.File + } else { + reader = &failingReader{ + innerReader: baseTransfer.File, + errRead: errForRead, + } + } + } else if pipeWriter != nil { + writer = pipeWriter + } else if pipeReader != nil { + if errForRead == nil { + reader = pipeReader + } else { + reader = &failingReader{ + innerReader: pipeReader, + errRead: errForRead, + } + } + } + if baseTransfer.File == nil && errForRead != nil && pipeReader == nil { + reader = &failingReader{ + innerReader: nil, + errRead: errForRead, + } + } + return &transfer{ + BaseTransfer: baseTransfer, + writerAt: writer, + readerAt: reader, + isFinished: false, + } +} + +// ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent. +// It handles download bandwidth throttling too +func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) { + t.Connection.UpdateLastActivity() + + n, err = t.readerAt.ReadAt(p, off) + t.BytesSent.Add(int64(n)) + + if err == nil { + err = t.CheckRead() + } + if err != nil && err != io.EOF { + if t.GetType() == common.TransferDownload { + t.TransferError(err) + } + err = t.ConvertError(err) + return + } + t.HandleThrottle() + return +} + +// WriteAt writes len(p) bytes to the uploaded file starting at byte offset off and updates the bytes received. 
// It handles upload bandwidth throttling too
func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) {
	t.Connection.UpdateLastActivity()
	// offsets below MinWriteOffset are invalid (e.g. writing before a resume point)
	if off < t.MinWriteOffset {
		err := fmt.Errorf("invalid write offset: %v minimum valid value: %v", off, t.MinWriteOffset)
		t.TransferError(err)
		return 0, err
	}

	n, err = t.writerAt.WriteAt(p, off)
	t.BytesReceived.Add(int64(n))

	// CheckWrite can enforce quota/size limits even when the write succeeded
	if err == nil {
		err = t.CheckWrite()
	}
	if err != nil {
		t.TransferError(err)
		err = t.ConvertError(err)
		return
	}
	t.HandleThrottle()
	return
}

// Close it is called when the transfer is completed.
// It closes the underlying file, logs the transfer info, updates the user quota (for uploads)
// and executes any defined action.
// If there is an error no action will be executed and, in atomic mode, we try to delete
// the temporary file
func (t *transfer) Close() error {
	// a second Close is a no-op error so the cleanup below runs exactly once
	if err := t.setFinished(); err != nil {
		return err
	}
	err := t.closeIO()
	errBaseClose := t.BaseTransfer.Close()
	if errBaseClose != nil {
		err = errBaseClose
	}
	return t.Connection.GetFsError(t.Fs, err)
}

// closeIO closes whichever underlying reader/writer this transfer owns and
// records a write-side close failure as the transfer error.
func (t *transfer) closeIO() error {
	var err error
	if t.File != nil {
		err = t.File.Close()
	} else if t.writerAt != nil {
		err = t.writerAt.Close()
		t.Lock()
		// we set ErrTransfer here so quota is not updated, in this case the uploads are atomic
		if err != nil && t.ErrTransfer == nil {
			t.ErrTransfer = err
		}
		t.Unlock()
	} else if t.readerAt != nil {
		err = t.readerAt.Close()
		// propagate backend metadata (if the reader exposes any) to the base transfer
		if metadater, ok := t.readerAt.(vfs.Metadater); ok {
			t.SetMetadata(metadater.Metadata())
		}
	}
	return err
}

// setFinished atomically marks the transfer as finished; it returns
// common.ErrTransferClosed if it was already finished.
func (t *transfer) setFinished() error {
	t.Lock()
	defer t.Unlock()
	if t.isFinished {
		return common.ErrTransferClosed
	}
	t.isFinished = true
	return nil
}
diff --git a/internal/smtp/oauth2.go b/internal/smtp/oauth2.go
new file mode 100644
index 00000000..87997947
--- /dev/null
+++ b/internal/smtp/oauth2.go
@@ -0,0 +1,165 @@
// Copyright (C) 2019
Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package smtp provides supports for sending emails +package smtp + +import ( + "context" + "errors" + "fmt" + "slices" + "sync" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "golang.org/x/oauth2/microsoft" + + "github.com/drakkan/sftpgo/v2/internal/logger" +) + +// Supported OAuth2 providers +const ( + OAuth2ProviderGoogle = iota + OAuth2ProviderMicrosoft +) + +var supportedOAuth2Providers = []int{OAuth2ProviderGoogle, OAuth2ProviderMicrosoft} + +// OAuth2Config defines OAuth2 settings +type OAuth2Config struct { + Provider int `json:"provider" mapstructure:"provider"` + // Tenant for Microsoft provider, if empty "common" is used + Tenant string `json:"tenant" mapstructure:"tenant"` + // ClientID is the application's ID + ClientID string `json:"client_id" mapstructure:"client_id"` + // ClientSecret is the application's secret + ClientSecret string `json:"client_secret" mapstructure:"client_secret"` + // Token to use to get/renew access tokens + RefreshToken string `json:"refresh_token" mapstructure:"refresh_token"` + mu *sync.RWMutex + config *oauth2.Config + accessToken *oauth2.Token +} + +// Validate validates and initializes the configuration +func (c *OAuth2Config) Validate() error { + if !slices.Contains(supportedOAuth2Providers, c.Provider) { + return fmt.Errorf("smtp oauth2: unsupported provider %d", c.Provider) + } + if c.ClientID == "" { + 
		return errors.New("smtp oauth2: client id is required")
	}
	if c.ClientSecret == "" {
		return errors.New("smtp oauth2: client secret is required")
	}
	if c.RefreshToken == "" {
		return errors.New("smtp oauth2: refresh token is required")
	}
	c.initialize()
	return nil
}

// isEqual reports whether the two configurations have identical settings
// (runtime state such as the cached access token is not compared).
func (c *OAuth2Config) isEqual(other *OAuth2Config) bool {
	if c.Provider != other.Provider {
		return false
	}
	if c.Tenant != other.Tenant {
		return false
	}
	if c.ClientID != other.ClientID {
		return false
	}
	if c.ClientSecret != other.ClientSecret {
		return false
	}
	if c.RefreshToken != other.RefreshToken {
		return false
	}
	return true
}

// getAccessToken returns a valid OAuth2 access token, renewing it through the
// refresh token when the cached one is missing or expires within 30 seconds.
func (c *OAuth2Config) getAccessToken() (string, error) {
	c.mu.RLock()
	// fast path: cached token still valid for at least 30 more seconds
	if c.accessToken.Expiry.After(time.Now().Add(30 * time.Second)) {
		accessToken := c.accessToken.AccessToken
		c.mu.RUnlock()

		return accessToken, nil
	}
	logger.Debug(logSender, "", "renew oauth2 token required, current token expires at %s", c.accessToken.Expiry)
	// copy the token before releasing the lock so the refresh works on a snapshot.
	// NOTE(review): after RUnlock two callers may refresh concurrently; the last
	// writer wins — presumably acceptable, verify against provider rate limits.
	token := new(oauth2.Token)
	*token = *c.accessToken
	c.mu.RUnlock()

	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	newToken, err := c.config.TokenSource(ctx, token).Token()
	if err != nil {
		logger.Error(logSender, "", "unable to get new token: %v", err)
		return "", err
	}
	accessToken := newToken.AccessToken
	refreshToken := newToken.RefreshToken
	if refreshToken != "" && refreshToken != token.RefreshToken {
		// the provider rotated the refresh token: cache it and persist it
		// asynchronously to the data provider
		c.mu.Lock()
		c.RefreshToken = refreshToken
		c.accessToken = newToken
		c.mu.Unlock()

		logger.Debug(logSender, "", "oauth2 refresh token changed")
		go updateRefreshToken(refreshToken)
	}
	if accessToken != token.AccessToken {
		c.mu.Lock()
		c.accessToken = newToken
		c.mu.Unlock()

		logger.Debug(logSender, "", "new oauth2 token saved, expires at %s", c.accessToken.Expiry)
	}
	return accessToken, nil
}

// initialize prepares the runtime state: the lock, the oauth2 client config
// and a placeholder token that forces a refresh on first use.
func (c *OAuth2Config) initialize() {
	c.mu = new(sync.RWMutex)
	c.config = c.GetOAuth2()
c.accessToken = &oauth2.Token{ + TokenType: "Bearer", + RefreshToken: c.RefreshToken, + } +} + +// GetOAuth2 returns the oauth2 configuration for the provided parameters. +func (c *OAuth2Config) GetOAuth2() *oauth2.Config { + var endpoint oauth2.Endpoint + var scopes []string + + switch c.Provider { + case OAuth2ProviderMicrosoft: + endpoint = microsoft.AzureADEndpoint(c.Tenant) + scopes = []string{"offline_access", "https://outlook.office.com/SMTP.Send"} + default: + endpoint = google.Endpoint + scopes = []string{"https://mail.google.com/"} + } + + return &oauth2.Config{ + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + Scopes: scopes, + Endpoint: endpoint, + } +} diff --git a/internal/smtp/smtp.go b/internal/smtp/smtp.go new file mode 100644 index 00000000..94984249 --- /dev/null +++ b/internal/smtp/smtp.go @@ -0,0 +1,450 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +// Package smtp provides supports for sending emails +package smtp + +import ( + "bytes" + "context" + "errors" + "fmt" + "html/template" + "path/filepath" + "sync" + "time" + + "github.com/rs/xid" + "github.com/wneessen/go-mail" + + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" +) + +const ( + logSender = "smtp" +) + +// EmailContentType defines the support content types for email body +type EmailContentType int + +// Supported email body content type +const ( + EmailContentTypeTextPlain EmailContentType = iota + EmailContentTypeTextHTML +) + +const ( + templateEmailDir = "email" + templatePasswordReset = "reset-password.html" + templatePasswordExpiration = "password-expiration.html" + dialTimeout = 10 * time.Second +) + +var ( + config = &activeConfig{} + initialConfig *Config + emailTemplates = make(map[string]*template.Template) +) + +type activeConfig struct { + sync.RWMutex + config *Config +} + +func (c *activeConfig) isEnabled() bool { + c.RLock() + defer c.RUnlock() + + return c.config != nil && c.config.Host != "" +} + +func (c *activeConfig) Set(cfg *dataprovider.SMTPConfigs) { + var config *Config + if cfg != nil { + config = &Config{ + Host: cfg.Host, + Port: cfg.Port, + From: cfg.From, + User: cfg.User, + Password: cfg.Password.GetPayload(), + AuthType: cfg.AuthType, + Encryption: cfg.Encryption, + Domain: cfg.Domain, + Debug: cfg.Debug, + OAuth2: OAuth2Config{ + Provider: cfg.OAuth2.Provider, + Tenant: cfg.OAuth2.Tenant, + ClientID: cfg.OAuth2.ClientID, + ClientSecret: cfg.OAuth2.ClientSecret.GetPayload(), + RefreshToken: cfg.OAuth2.RefreshToken.GetPayload(), + }, + } + config.OAuth2.initialize() + } + + c.Lock() + defer c.Unlock() + + if config != nil && config.Host != "" { + if c.config != nil && c.config.isEqual(config) { + return + } + c.config = 
config + logger.Info(logSender, "", "activated new config, server %s:%d", c.config.Host, c.config.Port) + } else { + logger.Debug(logSender, "", "activating initial config") + c.config = initialConfig + if c.config == nil || c.config.Host == "" { + logger.Debug(logSender, "", "configuration disabled, email capabilities will not be available") + } + } +} + +func (c *activeConfig) getSMTPClientAndMsg(to, bcc []string, subject, body string, contentType EmailContentType, + attachments ...*mail.File, +) (*mail.Client, *mail.Msg, error) { + c.RLock() + defer c.RUnlock() + + if c.config == nil || c.config.Host == "" { + return nil, nil, errors.New("smtp: not configured") + } + + return c.config.getSMTPClientAndMsg(to, bcc, subject, body, contentType, attachments...) +} + +func (c *activeConfig) sendEmail(to, bcc []string, subject, body string, contentType EmailContentType, attachments ...*mail.File) error { + client, msg, err := c.getSMTPClientAndMsg(to, bcc, subject, body, contentType, attachments...) + if err != nil { + return err + } + + ctx, cancelFn := context.WithTimeout(context.Background(), dialTimeout) + defer cancelFn() + + return client.DialAndSendWithContext(ctx, msg) +} + +// IsEnabled returns true if an SMTP server is configured +func IsEnabled() bool { + return config.isEnabled() +} + +// Activate sets the specified config as active +func Activate(c *dataprovider.SMTPConfigs) { + config.Set(c) +} + +// Config defines the SMTP configuration to use to send emails +type Config struct { + // Location of SMTP email server. Leavy empty to disable email sending capabilities + Host string `json:"host" mapstructure:"host"` + // Port of SMTP email server + Port int `json:"port" mapstructure:"port"` + // From address, for example "SFTPGo ". 
+ // Many SMTP servers reject emails without a `From` header so, if not set, + // SFTPGo will try to use the username as fallback, this may or may not be appropriate + From string `json:"from" mapstructure:"from"` + // SMTP username + User string `json:"user" mapstructure:"user"` + // SMTP password. Leaving both username and password empty the SMTP authentication + // will be disabled + Password string `json:"password" mapstructure:"password"` + // 0 Plain + // 1 Login + // 2 CRAM-MD5 + // 3 OAuth2 + AuthType int `json:"auth_type" mapstructure:"auth_type"` + // 0 no encryption + // 1 TLS + // 2 start TLS + Encryption int `json:"encryption" mapstructure:"encryption"` + // Domain to use for HELO command, if empty localhost will be used + Domain string `json:"domain" mapstructure:"domain"` + // Path to the email templates. This can be an absolute path or a path relative to the config dir. + // Templates are searched within a subdirectory named "email" in the specified path + TemplatesPath string `json:"templates_path" mapstructure:"templates_path"` + // Set to 1 to enable debug logs + Debug int `json:"debug" mapstructure:"debug"` + // OAuth2 related settings + OAuth2 OAuth2Config `json:"oauth2" mapstructure:"oauth2"` +} + +func (c *Config) isEqual(other *Config) bool { + if c.Host != other.Host { + return false + } + if c.Port != other.Port { + return false + } + if c.From != other.From { + return false + } + if c.User != other.User { + return false + } + if c.Password != other.Password { + return false + } + if c.AuthType != other.AuthType { + return false + } + if c.Encryption != other.Encryption { + return false + } + if c.Domain != other.Domain { + return false + } + if c.Debug != other.Debug { + return false + } + return c.OAuth2.isEqual(&other.OAuth2) +} + +func (c *Config) validate() error { + if c.Port <= 0 || c.Port > 65535 { + return fmt.Errorf("smtp: invalid port %d", c.Port) + } + if c.AuthType < 0 || c.AuthType > 3 { + return fmt.Errorf("smtp: invalid 
auth type %d", c.AuthType) + } + if c.Encryption < 0 || c.Encryption > 2 { + return fmt.Errorf("smtp: invalid encryption %d", c.Encryption) + } + if c.From == "" && c.User == "" { + return errors.New(`smtp: from address and user cannot both be empty`) + } + if c.AuthType == 3 { + return c.OAuth2.Validate() + } + return nil +} + +func (c *Config) loadTemplates(configDir string) error { + if c.TemplatesPath == "" { + logger.Debug(logSender, "", "templates path empty, using default") + c.TemplatesPath = "templates" + } + templatesPath := util.FindSharedDataPath(c.TemplatesPath, configDir) + if templatesPath == "" { + return fmt.Errorf("smtp: invalid templates path %q", templatesPath) + } + loadTemplates(filepath.Join(templatesPath, templateEmailDir)) + return nil +} + +// Initialize initializes and validates the SMTP configuration +func (c *Config) Initialize(configDir string, isService bool) error { + if !isService && c.Host == "" { + if err := loadConfigFromProvider(); err != nil { + return err + } + if !config.isEnabled() { + return nil + } + // If not running as a service, templates will only be loaded if required. + return c.loadTemplates(configDir) + } + // In service mode SMTP can be enabled from the WebAdmin at runtime so we + // always load templates. 
	// (continuation of Config.Initialize — the head of this method is outside this chunk)
	if err := c.loadTemplates(configDir); err != nil {
		return err
	}
	if c.Host == "" {
		// no static host configured: fall back entirely to the data provider configuration
		return loadConfigFromProvider()
	}
	if err := c.validate(); err != nil {
		return err
	}
	// remember the validated file-based config and clear any provider override
	initialConfig = c
	config.Set(nil)
	logger.Debug(logSender, "", "configuration successfully initialized, host: %q, port: %d, username: %q, auth: %d, encryption: %d, helo: %q",
		c.Host, c.Port, c.User, c.AuthType, c.Encryption, c.Domain)
	return loadConfigFromProvider()
}

// getMailClientOptions converts the Config fields into the corresponding
// go-mail client options: port, encryption policy, credentials, SMTP auth
// mechanism, HELO domain and optional debug logging.
func (c *Config) getMailClientOptions() []mail.Option {
	options := []mail.Option{mail.WithPort(c.Port), mail.WithoutNoop()}

	// Encryption: 1 => implicit SSL/TLS, 2 => mandatory STARTTLS, any other value => no TLS
	switch c.Encryption {
	case 1:
		options = append(options, mail.WithSSL())
	case 2:
		options = append(options, mail.WithTLSPolicy(mail.TLSMandatory))
	default:
		options = append(options, mail.WithTLSPolicy(mail.NoTLS))
	}
	if c.User != "" {
		options = append(options, mail.WithUsername(c.User))
	}
	if c.Password != "" {
		options = append(options, mail.WithPassword(c.Password))
	}
	// SMTP auth is only enabled when at least one credential is set.
	// AuthType: 1 => LOGIN, 2 => CRAM-MD5, 3 => XOAUTH2, any other value => PLAIN
	if c.User != "" || c.Password != "" {
		switch c.AuthType {
		case 1:
			options = append(options, mail.WithSMTPAuth(mail.SMTPAuthLogin))
		case 2:
			options = append(options, mail.WithSMTPAuth(mail.SMTPAuthCramMD5))
		case 3:
			options = append(options, mail.WithSMTPAuth(mail.SMTPAuthXOAUTH2))
		default:
			options = append(options, mail.WithSMTPAuth(mail.SMTPAuthPlain))
		}
	}
	if c.Domain != "" {
		options = append(options, mail.WithHELO(c.Domain))
	}
	if c.Debug > 0 {
		options = append(options,
			mail.WithLogger(&logger.MailAdapter{
				ConnectionID: xid.New().String(),
			}),
			mail.WithDebugLog())
	}
	return options
}

// getSMTPClientAndMsg builds the mail client and the message to send.
// The sender defaults to c.User when c.From is empty. For AuthType 3
// (XOAUTH2, see getMailClientOptions) a fresh OAuth2 access token replaces
// the configured password.
func (c *Config) getSMTPClientAndMsg(to, bcc []string, subject, body string, contentType EmailContentType,
	attachments ...*mail.File) (*mail.Client, *mail.Msg, error) {
	msg := mail.NewMsg()
	msg.SetUserAgent(version.GetServerVersion(" ", false))

	var from string
	if c.From != "" {
		from = c.From
	} else {
		from = c.User
	}
	if err := msg.From(from); err != nil {
		return nil, nil, fmt.Errorf("invalid from address: %w", err)
	}
	if err := msg.To(to...); err != nil {
		return nil, nil, err
	}
	if len(bcc) > 0 {
		if err := msg.Bcc(bcc...); err != nil {
			return nil, nil, err
		}
	}
	msg.Subject(subject)
	msg.SetDate()
	msg.SetMessageID()
	msg.SetAttachments(attachments)

	switch contentType {
	case EmailContentTypeTextPlain:
		msg.SetBodyString(mail.TypeTextPlain, body)
	case EmailContentTypeTextHTML:
		msg.SetBodyString(mail.TypeTextHTML, body)
	default:
		return nil, nil, fmt.Errorf("smtp: unsupported body content type %v", contentType)
	}

	client, err := mail.NewClient(c.Host, c.getMailClientOptions()...)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to create mail client: %w", err)
	}
	if c.AuthType == 3 {
		token, err := c.OAuth2.getAccessToken()
		if err != nil {
			return nil, nil, fmt.Errorf("unable to get oauth2 access token: %w", err)
		}
		client.SetPassword(token)
	}
	return client, msg, nil
}

// SendEmail tries to send an email using the specified parameters
func (c *Config) SendEmail(to, bcc []string, subject, body string, contentType EmailContentType, attachments ...*mail.File) error {
	client, msg, err := c.getSMTPClientAndMsg(to, bcc, subject, body, contentType, attachments...)
	if err != nil {
		return err
	}
	// bound the whole dial+send with dialTimeout
	ctx, cancelFn := context.WithTimeout(context.Background(), dialTimeout)
	defer cancelFn()

	return client.DialAndSendWithContext(ctx, msg)
}

// loadTemplates loads the password reset and password expiration email
// templates from the specified path into the package-level emailTemplates map.
func loadTemplates(templatesPath string) {
	logger.Debug(logSender, "", "loading templates from %q", templatesPath)

	passwordResetPath := filepath.Join(templatesPath, templatePasswordReset)
	pwdResetTmpl := util.LoadTemplate(nil, passwordResetPath)
	passwordExpirationPath := filepath.Join(templatesPath, templatePasswordExpiration)
	pwdExpirationTmpl := util.LoadTemplate(nil, passwordExpirationPath)

	emailTemplates[templatePasswordReset] = pwdResetTmpl
	emailTemplates[templatePasswordExpiration] = pwdExpirationTmpl
}

// RenderPasswordResetTemplate executes the password reset template
func RenderPasswordResetTemplate(buf *bytes.Buffer, data any) error {
	if !IsEnabled() {
		return errors.New("smtp: not configured")
	}
	return emailTemplates[templatePasswordReset].Execute(buf, data)
}

// RenderPasswordExpirationTemplate executes the password expiration template
func RenderPasswordExpirationTemplate(buf *bytes.Buffer, data any) error {
	if !IsEnabled() {
		return errors.New("smtp: not configured")
	}
	return emailTemplates[templatePasswordExpiration].Execute(buf, data)
}

// SendEmail tries to send an email using the specified parameters.
func SendEmail(to, bcc []string, subject, body string, contentType EmailContentType, attachments ...*mail.File) error {
	// delegate to the active configuration holder
	return config.sendEmail(to, bcc, subject, body, contentType, attachments...)
}

// loadConfigFromProvider reads the SMTP configuration stored in the data
// provider, decrypts its secrets and makes it the active configuration.
func loadConfigFromProvider() error {
	configs, err := dataprovider.GetConfigs()
	if err != nil {
		logger.Error(logSender, "", "unable to load config from provider: %v", err)
		return fmt.Errorf("smtp: unable to load config from provider: %w", err)
	}
	configs.SetNilsToEmpty()
	if err := configs.SMTP.TryDecrypt(); err != nil {
		logger.Error(logSender, "", "unable to decrypt smtp config: %v", err)
		return fmt.Errorf("smtp: unable to decrypt smtp config: %w", err)
	}
	config.Set(configs.SMTP)
	return nil
}

// updateRefreshToken persists a new OAuth2 refresh token in the data
// provider. Failures are logged and otherwise ignored (best-effort update).
func updateRefreshToken(token string) {
	configs, err := dataprovider.GetConfigs()
	if err != nil {
		logger.Error(logSender, "", "unable to load config from provider, updating refresh token not possible: %v", err)
		return
	}
	configs.SetNilsToEmpty()
	if configs.SMTP.IsEmpty() {
		logger.Warn(logSender, "", "unable to update refresh token, smtp not configured in the data provider")
		return
	}
	configs.SMTP.OAuth2.RefreshToken = kms.NewPlainSecret(token)
	if err := dataprovider.UpdateConfigs(&configs, dataprovider.ActionExecutorSystem, "", ""); err != nil {
		logger.Error(logSender, "", "unable to save new refresh token: %v", err)
		return
	}
	logger.Info(logSender, "", "refresh token updated")
}
diff --git a/internal/telemetry/router.go b/internal/telemetry/router.go
new file mode 100644
index 00000000..6149f293
--- /dev/null
+++ b/internal/telemetry/router.go
@@ -0,0 +1,74 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

package telemetry

import (
	"net/http"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/go-chi/render"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/metric"
)

// initializeRouter builds the package-level chi router: an unauthenticated
// "/healthz" endpoint plus, behind optional HTTP basic auth, the metrics
// endpoint and (when enableProfiler is true) the pprof profiler.
func initializeRouter(enableProfiler bool) {
	router = chi.NewRouter()

	router.Use(middleware.GetHead)
	router.Use(logger.NewStructuredLogger(logger.GetLogger()))
	router.Use(middleware.Recoverer)

	// health check endpoint, always reachable without authentication
	router.Group(func(r chi.Router) {
		r.Get("/healthz", func(w http.ResponseWriter, r *http.Request) {
			render.PlainText(w, r, "ok")
		})
	})

	// authenticated endpoints: metrics and, optionally, the profiler
	router.Group(func(router chi.Router) {
		router.Use(checkAuth)
		metric.AddMetricsEndpoint(metricsPath, router)

		if enableProfiler {
			logger.InfoToConsole("enabling the built-in profiler")
			logger.Info(logSender, "", "enabling the built-in profiler")
			router.Mount(pprofBasePath, middleware.Profiler())
		}
	})
}

// checkAuth is a middleware that rejects requests failing HTTP basic
// authentication with a 401 and the appropriate WWW-Authenticate header.
func checkAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !validateCredentials(r) {
			w.Header().Set(common.HTTPAuthenticationHeader, "Basic realm=\"SFTPGo telemetry\"")
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// validateCredentials returns true when basic auth is disabled or the
// request carries valid basic auth credentials.
func validateCredentials(r *http.Request) bool {
	if !httpAuth.IsEnabled() {
		return true
	}
	username, password, ok := r.BasicAuth()
	if !ok {
		return false
	}
	return httpAuth.ValidateCredentials(username, password)
}
diff --git a/internal/telemetry/telemetry.go b/internal/telemetry/telemetry.go
new file mode 100644
index 00000000..233636aa
--- /dev/null
+++ b/internal/telemetry/telemetry.go
@@ -0,0 +1,159 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .

// Package telemetry provides telemetry information for SFTPGo, such as:
// - health information (for health checks)
// - metrics
// - profiling information
package telemetry

import (
	"crypto/tls"
	"log"
	"net/http"
	"path/filepath"
	"runtime"
	"time"

	"github.com/go-chi/chi/v5"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

const (
	logSender     = "telemetry"
	metricsPath   = "/metrics"
	pprofBasePath = "/debug"
)

var (
	router   *chi.Mux                // HTTP router, built by initializeRouter
	httpAuth common.HTTPAuthProvider // optional basic auth provider
	certMgr  *common.CertManager     // TLS certificate manager, nil when HTTPS is disabled
)

// Conf telemetry server configuration.
type Conf struct {
	// The port used for serving HTTP requests. 0 disables the HTTP server. Default: 0
	BindPort int `json:"bind_port" mapstructure:"bind_port"`
	// The address to listen on. A blank value means listen on all available network interfaces. Default: "127.0.0.1"
	BindAddress string `json:"bind_address" mapstructure:"bind_address"`
	// Enable the built-in profiler.
	// The profiler will be accessible via HTTP/HTTPS using the base URL "/debug/pprof/"
	EnableProfiler bool `json:"enable_profiler" mapstructure:"enable_profiler"`
	// Path to a file used to store usernames and password for basic authentication.
	// This can be an absolute path or a path relative to the config dir.
	// We support HTTP basic authentication and the file format must conform to the one generated using the Apache
	// htpasswd tool. The supported password formats are bcrypt ($2y$ prefix) and md5 crypt ($apr1$ prefix).
	// If empty HTTP authentication is disabled
	AuthUserFile string `json:"auth_user_file" mapstructure:"auth_user_file"`
	// If files containing a certificate and matching private key for the server are provided the server will expect
	// HTTPS connections.
	// Certificate and key files can be reloaded on demand sending a "SIGHUP" signal on Unix based systems and a
	// "paramchange" request to the running service on Windows.
	CertificateFile    string `json:"certificate_file" mapstructure:"certificate_file"`
	CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"`
	// TLSCipherSuites is a list of supported cipher suites for TLS version 1.2.
	// If CipherSuites is nil/empty, a default list of secure cipher suites
	// is used, with a preference order based on hardware performance.
	// Note that TLS 1.3 ciphersuites are not configurable.
	// The supported ciphersuites names are defined here:
	//
	// https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53
	//
	// any invalid name will be silently ignored.
	// The order matters, the ciphers listed first will be the preferred ones.
	TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"`
	// Defines the minimum TLS version. 13 means TLS 1.3, default is TLS 1.2
	MinTLSVersion int `json:"min_tls_version" mapstructure:"min_tls_version"`
	// HTTP protocols to enable in preference order. Supported values: http/1.1, h2
	Protocols []string `json:"tls_protocols" mapstructure:"tls_protocols"`
}

// ShouldBind returns true if the service must be started
func (c Conf) ShouldBind() bool {
	if c.BindPort > 0 {
		return true
	}
	// an absolute bind address without a port means a Unix-domain socket (not supported on Windows)
	if filepath.IsAbs(c.BindAddress) && runtime.GOOS != "windows" {
		return true
	}
	return false
}

// Initialize configures and starts the telemetry server.
func (c Conf) Initialize(configDir string) error {
	var err error
	logger.Info(logSender, "", "initializing telemetry server with config %+v", c)
	authUserFile := getConfigPath(c.AuthUserFile, configDir)
	httpAuth, err = common.NewBasicAuthProvider(authUserFile)
	if err != nil {
		return err
	}
	certificateFile := getConfigPath(c.CertificateFile, configDir)
	certificateKeyFile := getConfigPath(c.CertificateKeyFile, configDir)
	initializeRouter(c.EnableProfiler)
	httpServer := &http.Server{
		Handler:           router,
		ReadHeaderTimeout: 30 * time.Second,
		ReadTimeout:       60 * time.Second,
		WriteTimeout:      60 * time.Second,
		IdleTimeout:       60 * time.Second,
		MaxHeaderBytes:    1 << 14, // 16KB
		ErrorLog:          log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0),
	}
	// HTTPS is enabled only when both certificate and key are configured
	if certificateFile != "" && certificateKeyFile != "" {
		keyPairs := []common.TLSKeyPair{
			{
				Cert: certificateFile,
				Key:  certificateKeyFile,
				ID:   common.DefaultTLSKeyPaidID,
			},
		}
		certMgr, err = common.NewCertManager(keyPairs, configDir, logSender)
		if err != nil {
			return err
		}
		config := &tls.Config{
			GetCertificate: certMgr.GetCertificateFunc(common.DefaultTLSKeyPaidID),
			MinVersion:     util.GetTLSVersion(c.MinTLSVersion),
			NextProtos:     util.GetALPNProtocols(c.Protocols),
			CipherSuites:   util.GetTLSCiphersFromNames(c.TLSCipherSuites),
		}
		logger.Debug(logSender, "", "configured TLS cipher suites: %v", config.CipherSuites)
		httpServer.TLSConfig = config
		return util.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, true, nil, logSender)
	}
	return util.HTTPListenAndServe(httpServer, c.BindAddress, c.BindPort, false, nil, logSender)
}

// ReloadCertificateMgr reloads the certificate manager
func ReloadCertificateMgr() error {
	if certMgr != nil {
		return certMgr.Reload()
	}
	return nil
}

// getConfigPath returns the input path made absolute against configDir when
// relative; an invalid path resolves to the empty string.
func getConfigPath(name, configDir string) string {
	if !util.IsFileInputValid(name) {
		return ""
	}
	if name != "" && !filepath.IsAbs(name) {
		return filepath.Join(configDir, name)
	}
	return name
}
diff --git a/internal/telemetry/telemetry_test.go b/internal/telemetry/telemetry_test.go
new file mode 100644
index 00000000..0e302283
--- /dev/null
+++ b/internal/telemetry/telemetry_test.go
@@ -0,0 +1,181 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
package telemetry

import (
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"runtime"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/drakkan/sftpgo/v2/internal/common"
	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
)

// self-signed test certificate and key used to exercise the HTTPS code paths
const (
	httpsCert = `-----BEGIN CERTIFICATE-----
MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw
RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu
dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw
OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD
VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA
IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA
NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM
3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME
GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG
SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY
/8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI
dV4vKmHUzwK/eIx+8Ay3neE=
-----END CERTIFICATE-----`
	httpsKey = `-----BEGIN EC PARAMETERS-----
BgUrgQQAIg==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3
UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq
WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV
CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI=
-----END EC PRIVATE KEY-----`
)

// TestInitialization checks that Conf.Initialize fails for invalid bind
// addresses, missing auth files and invalid/expired TLS certificates, and
// that ReloadCertificateMgr succeeds whether or not a cert manager is set.
func TestInitialization(t *testing.T) {
	configDir := filepath.Join(".", "..", "..")
	providerConf := dataprovider.Config{
		Driver:      dataprovider.MemoryDataProviderName,
		BackupsPath: "backups",
	}
	err := dataprovider.Initialize(providerConf, configDir, false)
	require.NoError(t, err)
	commonConfig := common.Configuration{}
	err = common.Initialize(commonConfig, 0)
	require.NoError(t, err)
	c := Conf{
		BindPort:       10000,
		BindAddress:    "invalid address",
		EnableProfiler: false,
	}
	err = c.Initialize(configDir)
	require.Error(t, err)

	c.AuthUserFile = "missing"
	err = c.Initialize(".")
	require.Error(t, err)

	err = ReloadCertificateMgr()
	require.NoError(t, err)

	c.AuthUserFile = ""
	c.CertificateFile = "crt"
	c.CertificateKeyFile = "key"

	err = c.Initialize(".")
	require.Error(t, err)

	certPath := filepath.Join(os.TempDir(), "test.crt")
	keyPath := filepath.Join(os.TempDir(), "test.key")
	err = os.WriteFile(certPath, []byte(httpsCert), os.ModePerm)
	require.NoError(t, err)
	err = os.WriteFile(keyPath, []byte(httpsKey), os.ModePerm)
	require.NoError(t, err)

	c.CertificateFile = certPath
	c.CertificateKeyFile = keyPath

	// valid key pair but still an invalid bind address: Initialize must fail
	err = c.Initialize(".")
	require.Error(t, err)

	err = ReloadCertificateMgr()
	require.NoError(t, err)

	err = os.Remove(certPath)
	require.NoError(t, err)
	err = os.Remove(keyPath)
	require.NoError(t, err)
}

// TestShouldBind verifies the port/Unix-socket binding decision logic.
func TestShouldBind(t *testing.T) {
	c := Conf{
		BindPort:       10000,
		EnableProfiler: false,
	}
	require.True(t, c.ShouldBind())

	c.BindPort = 0
	require.False(t, c.ShouldBind())

	if runtime.GOOS != "windows" {
		c.BindAddress = "/absolute/path"
		require.True(t, c.ShouldBind())
	}
}

// TestRouter exercises the router: unauthenticated health check, basic auth
// on /metrics and the profiler, and the auth-disabled fallback.
func TestRouter(t *testing.T) {
	authUserFile := filepath.Join(os.TempDir(), "http_users.txt")
	authUserData := []byte("test1:$2y$05$bcHSED7aO1cfLto6ZdDBOOKzlwftslVhtpIkRhAtSa4GuLmk5mola\n")
	err := os.WriteFile(authUserFile, authUserData, os.ModePerm)
	require.NoError(t, err)

	httpAuth, err = common.NewBasicAuthProvider(authUserFile)
	require.NoError(t, err)

	initializeRouter(true)
	testServer := httptest.NewServer(router)
	defer testServer.Close()

	req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
	require.NoError(t, err)
	rr := httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)
	require.Equal(t, "ok", rr.Body.String())

	req, err = http.NewRequest(http.MethodGet, "/metrics", nil)
	require.NoError(t, err)
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusUnauthorized, rr.Code)

	req.SetBasicAuth("test1", "password1")
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	req, err = http.NewRequest(http.MethodGet, pprofBasePath+"/pprof/", nil)
	require.NoError(t, err)
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusUnauthorized, rr.Code)

	req.SetBasicAuth("test1", "password1")
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	// with an empty auth file basic auth is disabled and /metrics is open
	httpAuth, err = common.NewBasicAuthProvider("")
	require.NoError(t, err)

	req, err = http.NewRequest(http.MethodGet, "/metrics", nil)
	require.NoError(t, err)
	rr = httptest.NewRecorder()
	testServer.Config.Handler.ServeHTTP(rr, req)
	require.Equal(t, http.StatusOK, rr.Code)

	err = os.Remove(authUserFile)
	require.NoError(t, err)
}
diff --git a/internal/util/errors.go b/internal/util/errors.go
new file mode 100644
index 00000000..79e54265
--- /dev/null
+++ b/internal/util/errors.go
@@ -0,0 +1,135 @@
// Copyright (C) 2019 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
package util

import (
	"errors"
	"fmt"
)

const (
	// templateLoadErrorHints is appended to template load failures to help
	// users fix their configuration
	templateLoadErrorHints = "Try setting the absolute templates path in your configuration file " +
		"or specifying the config directory adding the `-c` flag to the serve options. For example: " +
		"sftpgo serve -c \"\""
)

// MaxRecursion defines the maximum number of allowed recursions
const MaxRecursion = 1000

// errors definitions
var (
	ErrValidation       = NewValidationError("")
	ErrNotFound         = NewRecordNotFoundError("")
	ErrMethodDisabled   = NewMethodDisabledError("")
	ErrGeneric          = NewGenericError("")
	ErrRecursionTooDeep = errors.New("recursion too deep")
)

// ValidationError raised if input data is not valid
type ValidationError struct {
	err string
}

// Validation error details
func (e *ValidationError) Error() string {
	return fmt.Sprintf("Validation error: %s", e.err)
}

// GetErrorString returns the unmodified error string
func (e *ValidationError) GetErrorString() string {
	return e.err
}

// Is reports if target matches
func (e *ValidationError) Is(target error) bool {
	_, ok := target.(*ValidationError)
	return ok
}

// NewValidationError returns a validation error
func NewValidationError(errorString string) *ValidationError {
	return &ValidationError{
		err: errorString,
	}
}

// RecordNotFoundError raised if a requested object is not found
type RecordNotFoundError struct {
	err string
}

// Error returns the not found error details
func (e *RecordNotFoundError) Error() string {
	return fmt.Sprintf("not found: %s", e.err)
}

// Is reports if target matches
func (e *RecordNotFoundError) Is(target error) bool {
	_, ok := target.(*RecordNotFoundError)
	return ok
}

// NewRecordNotFoundError returns a not found error
func NewRecordNotFoundError(errorString string) *RecordNotFoundError {
	return &RecordNotFoundError{
		err: errorString,
	}
}

// MethodDisabledError raised if a method is disabled in config file.
// For example, if user management is disabled, this error is raised
// every time a user operation is done using the REST API
type MethodDisabledError struct {
	err string
}

// Method disabled error details
func (e *MethodDisabledError) Error() string {
	return fmt.Sprintf("Method disabled error: %s", e.err)
}

// Is reports if target matches
func (e *MethodDisabledError) Is(target error) bool {
	_, ok := target.(*MethodDisabledError)
	return ok
}

// NewMethodDisabledError returns a method disabled error
func NewMethodDisabledError(errorString string) *MethodDisabledError {
	return &MethodDisabledError{
		err: errorString,
	}
}

// GenericError raised for not well categorized error
type GenericError struct {
	err string
}

// Error returns the error string unmodified
func (e *GenericError) Error() string {
	return e.err
}

// Is reports if target matches
func (e *GenericError) Is(target error) bool {
	_, ok := target.(*GenericError)
	return ok
}

// NewGenericError returns a generic error
func NewGenericError(errorString string) *GenericError {
	return &GenericError{
		err: errorString,
	}
}
diff --git a/internal/util/i18n.go b/internal/util/i18n.go
new file mode 100644
index 00000000..2b318310
--- /dev/null
+++ b/internal/util/i18n.go
@@ -0,0 +1,380 @@
// Copyright (C) 2023 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see .
package util

import (
	"encoding/json"
	"errors"
)

// localization id for the Web frontend
const (
	I18nSetupTitle                     = "title.setup"
	I18nLoginTitle                     = "title.login"
	I18nShareLoginTitle                = "title.share_login"
	I18nFilesTitle                     = "title.files"
	I18nSharesTitle                    = "title.shares"
	I18nShareAddTitle                  = "title.add_share"
	I18nShareUpdateTitle               = "title.update_share"
	I18nProfileTitle                   = "title.profile"
	I18nUsersTitle                     = "title.users"
	I18nGroupsTitle                    = "title.groups"
	I18nFoldersTitle                   = "title.folders"
	I18nChangePwdTitle                 = "title.change_password"
	I18n2FATitle                       = "title.two_factor_auth"
	I18nEditFileTitle                  = "title.edit_file"
	I18nViewFileTitle                  = "title.view_file"
	I18nForgotPwdTitle                 = "title.recovery_password"
	I18nResetPwdTitle                  = "title.reset_password"
	I18nSharedFilesTitle               = "title.shared_files"
	I18nShareUploadTitle               = "title.upload_to_share"
	I18nShareDownloadTitle             = "title.download_shared_file"
	I18nShareAccessErrorTitle          = "title.share_access_error"
	I18nInvalidAuthReqTitle            = "title.invalid_auth_request"
	I18nError403Title                  = "title.error403"
	I18nError400Title                  = "title.error400"
	I18nError404Title                  = "title.error404"
	I18nError416Title                  = "title.error416"
	I18nError429Title                  = "title.error429"
	I18nError500Title                  = "title.error500"
	I18nErrorPDFTitle                  = "title.errorPDF"
	I18nErrorEditorTitle               = "title.error_editor"
	I18nAddUserTitle                   = "title.add_user"
	I18nUpdateUserTitle                = "title.update_user"
	I18nAddAdminTitle                  = "title.add_admin"
	I18nUpdateAdminTitle               = "title.update_admin"
	I18nTemplateUserTitle              = "title.template_user"
	I18nMaintenanceTitle               = "title.maintenance"
	I18nConfigsTitle                   = "title.configs"
	I18nOAuth2Title                    = "title.oauth2_success"
	I18nOAuth2ErrorTitle               = "title.oauth2_error"
	I18nSessionsTitle                  = "title.connections"
	I18nRolesTitle                     = "title.roles"
	I18nAdminsTitle                    = "title.admins"
	I18nIPListsTitle                   = "title.ip_lists"
	I18nAddIPListTitle                 = "title.add_ip_list"
	I18nUpdateIPListTitle              = "title.update_ip_list"
	I18nDefenderTitle                  = "title.defender"
	I18nEventsTitle                    = "title.logs"
	I18nActionsTitle                   = "title.event_actions"
	I18nRulesTitle                     = "title.event_rules"
	I18nAddActionTitle                 = "title.add_action"
	I18nUpdateActionTitle              = "title.update_action"
	I18nAddRuleTitle                   = "title.add_rule"
	I18nUpdateRuleTitle                = "title.update_rule"
	I18nStatusTitle                    = "status.desc"
	I18nErrorSetupInstallCode          = "setup.install_code_mismatch"
	I18nInvalidAuth                    = "general.invalid_auth_request"
	I18nError429Message                = "general.error429"
	I18nError400Message                = "general.error400"
	I18nError403Message                = "general.error403"
	I18nError404Message                = "general.error404"
	I18nError416Message                = "general.error416"
	I18nError500Message                = "general.error500"
	I18nErrorPDFMessage                = "general.errorPDF"
	I18nErrorInvalidToken              = "general.invalid_token"
	I18nErrorInvalidForm               = "general.invalid_form"
	I18nErrorInvalidCredentials        = "general.invalid_credentials"
	I18nErrorInvalidCSRF               = "general.invalid_csrf"
	I18nErrorFsGeneric                 = "fs.err_generic"
	I18nErrorDirListGeneric            = "fs.dir_list.err_generic"
	I18nErrorDirList403                = "fs.dir_list.err_403"
	I18nErrorDirList429                = "fs.dir_list.err_429"
	I18nErrorDirListUser               = "fs.dir_list.err_user"
	I18nErrorFsValidation              = "fs.err_validation"
	I18nErrorChangePwdRequiredFields   = "change_pwd.required_fields"
	I18nErrorChangePwdNoMatch          = "change_pwd.no_match"
	I18nErrorChangePwdGeneric          = "change_pwd.generic"
	I18nErrorChangePwdNoDifferent      = "change_pwd.no_different"
	I18nErrorChangePwdCurrentNoMatch   = "change_pwd.current_no_match"
	I18nErrorChangePwdRequired         = "change_pwd.required"
	I18nErrorUsernameRequired          = "general.username_required"
	I18nErrorPasswordRequired          = "general.password_required"
	I18nErrorPermissionsRequired       = "general.permissions_required"
	I18nErrorGetUser                   = "general.err_user"
	I18nErrorPwdResetForbidded         = "login.reset_pwd_forbidden"
	I18nErrorPwdResetNoEmail           = "login.reset_pwd_no_email"
	I18nErrorPwdResetSendEmail         = "login.reset_pwd_send_email_err"
	I18nErrorPwdResetGeneric           = "login.reset_pwd_err_generic"
	I18nErrorProtocolForbidden         = "general.err_protocol_forbidden"
	I18nErrorPwdLoginForbidden         = "general.pwd_login_forbidden"
	I18nErrorIPForbidden               = "general.ip_forbidden"
	I18nErrorConnectionForbidden       = "general.connection_forbidden"
	I18nErrorReservedUsername          = "user.username_reserved"
	I18nErrorInvalidEmail              = "general.email_invalid"
	I18nErrorInvalidInput              = "general.invalid_input"
	I18nErrorInvalidUser               = "user.username_invalid"
	I18nErrorInvalidName               = "general.name_invalid"
	I18nErrorHomeRequired              = "user.home_required"
	I18nErrorHomeInvalid               = "user.home_invalid"
	I18nErrorPubKeyInvalid             = "user.pub_key_invalid"
	I18nErrorPrivKeyInvalid            = "user.priv_key_invalid"
	I18nErrorKeySizeInvalid            = "user.key_invalid_size"
	I18nErrorKeyInsecure               = "user.key_insecure"
	I18nErrorPrimaryGroup              = "user.err_primary_group"
	I18nErrorDuplicateGroup            = "user.err_duplicate_group"
	I18nErrorNoPermission              = "user.no_permissions"
	I18nErrorNoRootPermission          = "user.no_root_permissions"
	I18nErrorGenericPermission         = "user.err_permissions_generic"
	I18nError2FAInvalid                = "user.2fa_invalid"
	I18nErrorRecoveryCodesInvalid      = "user.recovery_codes_invalid"
	I18nErrorFolderNameRequired        = "general.foldername_required"
	I18nErrorFolderMountPathRequired   = "user.folder_path_required"
	I18nErrorDuplicatedFolders         = "user.folder_duplicated"
	I18nErrorOverlappedFolders         = "user.folder_overlapped"
	I18nErrorFolderQuotaSizeInvalid    = "user.folder_quota_size_invalid"
	I18nErrorFolderQuotaFileInvalid    = "user.folder_quota_file_invalid"
	I18nErrorFolderQuotaInvalid        = "user.folder_quota_invalid"
	I18nErrorPasswordComplexity        = "general.err_password_complexity"
	I18nErrorIPFiltersInvalid          = "user.ip_filters_invalid"
	I18nErrorSourceBWLimitInvalid      = "user.src_bw_limits_invalid"
	I18nErrorShareExpirationInvalid    = "user.share_expiration_invalid"
	I18nErrorFilePatternPathInvalid    = "user.file_pattern_path_invalid"
	I18nErrorFilePatternDuplicated     = "user.file_pattern_duplicated"
	I18nErrorFilePatternInvalid        = "user.file_pattern_invalid"
	I18nErrorDisableActive2FA          = "user.disable_active_2fa"
	I18nErrorPwdChangeConflict         = "user.pwd_change_conflict"
	I18nError2FAConflict               = "user.two_factor_conflict"
	I18nErrorLoginAfterReset           = "login.reset_ok_login_error"
	I18nErrorShareScope                = "share.scope_invalid"
	I18nErrorShareMaxTokens            = "share.max_tokens_invalid"
	I18nErrorShareExpiration           = "share.expiration_invalid"
	I18nErrorShareNoPwd                = "share.err_no_password"
	I18nErrorShareExpirationOutOfRange = "share.expiration_out_of_range"
	I18nErrorShareGeneric              = "share.generic"
	I18nErrorNameRequired              = "general.name_required"
	I18nErrorSharePathRequired         = "share.path_required"
	I18nErrorShareWriteScope           = "share.path_write_scope"
	I18nErrorShareNestedPaths          = "share.nested_paths"
	I18nErrorShareExpirationPast       = "share.expiration_past"
	I18nErrorInvalidIPMask             = "general.allowed_ip_mask_invalid"
	I18nErrorShareUsage                = "share.usage_exceed"
	I18nErrorShareExpired              = "share.expired"
	I18nErrorLoginFromIPDenied         = "login.ip_not_allowed"
	I18nError2FARequired               = "login.two_factor_required"
	I18nError2FARequiredGeneric        = "login.two_factor_required_generic"
	I18nErrorNoOIDCFeature             = "general.no_oidc_feature"
	I18nErrorNoPermissions             = "general.no_permissions"
	I18nErrorShareBrowsePaths          = "share.browsable_multiple_paths"
	I18nErrorShareBrowseNoDir          = "share.browsable_non_dir"
	I18nErrorShareInvalidPath          = "share.invalid_path"
	I18nErrorPathInvalid               = "general.path_invalid"
	I18nErrorQuotaRead                 = "general.err_quota_read"
	I18nErrorEditDir                   = "general.error_edit_dir"
	I18nErrorEditSize                  = "general.error_edit_size"
	I18nProfileUpdated                 = "general.profile_updated"
	I18nShareLoginOK                   = "general.share_ok"
	I18n2FADisabled                    = "2fa.disabled"
	I18nOIDCTokenExpired               = "oidc.token_expired"
	I18nOIDCTokenInvalidAdmin          = "oidc.token_invalid_webadmin"
	I18nOIDCTokenInvalidUser           = "oidc.token_invalid_webclient"
	I18nOIDCErrTokenExchange           = "oidc.token_exchange_err"
	I18nOIDCTokenInvalid               = "oidc.token_invalid"
	I18nOIDCTokenInvalidRoleAdmin      = "oidc.role_admin_err"
	I18nOIDCTokenInvalidRoleUser       = "oidc.role_user_err"
	I18nOIDCErrGetUser                 = "oidc.get_user_err"
	I18nErrorInvalidQuotaSize          = "user.invalid_quota_size"
	I18nErrorTimeOfDayInvalid          = "user.time_of_day_invalid"
	I18nErrorTimeOfDayConflict         = "user.time_of_day_conflict"
	I18nErrorInvalidMaxFilesize        = "filters.max_upload_size_invalid"
	I18nErrorInvalidHomeDir            = "storage.home_dir_invalid"
	I18nErrorBucketRequired            = "storage.bucket_required"
	I18nErrorRegionRequired            = "storage.region_required"
	I18nErrorKeyPrefixInvalid          = "storage.key_prefix_invalid"
	I18nErrorULPartSizeInvalid         = "storage.ul_part_size_invalid"
	I18nErrorDLPartSizeInvalid         = "storage.dl_part_size_invalid"
	I18nErrorULConcurrencyInvalid      = "storage.ul_concurrency_invalid"
	I18nErrorDLConcurrencyInvalid      = "storage.dl_concurrency_invalid"
	I18nErrorAccessKeyRequired         = "storage.access_key_required"
	I18nErrorAccessSecretRequired      = "storage.access_secret_required"
	I18nErrorFsCredentialsRequired     = "storage.credentials_required"
	I18nErrorContainerRequired         = "storage.container_required"
	I18nErrorAccountNameRequired       = "storage.account_name_required"
	I18nErrorSASURLInvalid             = "storage.sas_url_invalid"
	I18nErrorPassphraseRequired        = "storage.passphrase_required"
	I18nErrorEndpointInvalid           = "storage.endpoint_invalid"
	I18nErrorEndpointRequired          = "storage.endpoint_required"
	I18nErrorFsUsernameRequired        = "storage.username_required"
	I18nAddGroupTitle                  = "title.add_group"
	I18nUpdateGroupTitle               = "title.update_group"
	I18nRoleAddTitle                   = "title.add_role"
	I18nRoleUpdateTitle                = "title.update_role"
	I18nErrorInvalidTLSCert            = "user.tls_cert_invalid"
	I18nAddFolderTitle                 = "title.add_folder"
	I18nUpdateFolderTitle              = "title.update_folder"
	I18nTemplateFolderTitle            = "title.template_folder"
	I18nErrorDuplicatedUsername        = "general.duplicated_username"
	I18nErrorDuplicatedName            = "general.duplicated_name"
	I18nErrorDuplicatedIPNet           = "ip_list.duplicated"
	I18nErrorRoleAdminPerms            = "admin.role_permissions"
	I18nBackupOK                       = "maintenance.backup_ok"
	I18nErrorFolderTemplate            = "virtual_folders.template_no_folder"
	I18nErrorUserTemplate              = "user.template_no_user"
	I18nConfigsOK                      = "general.configs_saved"
	I18nOAuth2ErrorVerifyState         = "oauth2.auth_verify_error"
	I18nOAuth2ErrorValidateState       = "oauth2.auth_validation_error"
	I18nOAuth2InvalidState             = "oauth2.auth_invalid"
	I18nOAuth2ErrTokenExchange         = "oauth2.token_exchange_err"
	I18nOAuth2ErrNoRefreshToken        = "oauth2.no_refresh_token"
	I18nOAuth2OK                       = "oauth2.success"
	I18nErrorAdminSelfPerms            = "admin.self_permissions"
	I18nErrorAdminSelfDisable          = "admin.self_disable"
	I18nErrorAdminSelfRole             = "admin.self_role"
	I18nErrorIPInvalid                 = "ip_list.ip_invalid"
	I18nErrorNetInvalid                = "ip_list.net_invalid"
	I18nFTPTLSDisabled                 = "status.tls_disabled"
	I18nFTPTLSExplicit                 = "status.tls_explicit"
	I18nFTPTLSImplicit                 = "status.tls_implicit"
	I18nFTPTLSMixed                    = "status.tls_mixed"
	I18nErrorBackupFile                = "maintenance.backup_invalid_file"
	I18nErrorRestore                   = "maintenance.restore_error"
	I18nErrorACMEGeneric               = "acme.generic_error"
	I18nErrorSMTPRequiredFields        = "smtp.err_required_fields"
	I18nErrorClientIDRequired          = "oauth2.client_id_required"
	I18nErrorClientSecretRequired      = "oauth2.client_secret_required"
	I18nErrorRefreshTokenRequired      = "oauth2.refresh_token_required"
	I18nErrorURLRequired               = "actions.http_url_required"
	I18nErrorURLInvalid                = "actions.http_url_invalid"
	I18nErrorHTTPPartNameRequired      = "actions.http_part_name_required"
	I18nErrorHTTPPartBodyRequired      = "actions.http_part_body_required"
	I18nErrorMultipartBody             = "actions.http_multipart_body_error"
	I18nErrorMultipartCType            = "actions.http_multipart_ctype_error"
	I18nErrorPathDuplicated            = "actions.path_duplicated"
	I18nErrorCommandRequired           = "actions.command_required"
	I18nErrorCommandInvalid            = "actions.command_invalid"
	I18nErrorEmailRecipientRequired    = "actions.email_recipient_required"
	I18nErrorEmailSubjectRequired      = "actions.email_subject_required"
	I18nErrorEmailBodyRequired         = "actions.email_body_required"
	I18nErrorRetentionDirRequired      = "actions.retention_directory_required"
	I18nErrorPathRequired              = "actions.path_required"
	I18nErrorSourceDestMatch           = "actions.source_dest_different"
	I18nErrorRootNotAllowed            = "actions.root_not_allowed"
	I18nErrorArchiveNameRequired       = "actions.archive_name_required"
	I18nErrorIDPTemplateRequired       = "actions.idp_template_required"
	I18nActionTypeHTTP                 = "actions.types.http"
	I18nActionTypeEmail                = "actions.types.email"
	I18nActionTypeBackup               = "actions.types.backup"
	I18nActionTypeUserQuotaReset       = "actions.types.user_quota_reset"
	I18nActionTypeFolderQuotaReset     = "actions.types.folder_quota_reset"
	I18nActionTypeTransferQuotaReset   = "actions.types.transfer_quota_reset"
	I18nActionTypeDataRetentionCheck   = "actions.types.data_retention_check"
	I18nActionTypeFilesystem           = "actions.types.filesystem"
	I18nActionTypePwdExpirationCheck   = "actions.types.password_expiration_check"
	I18nActionTypeUserExpirationCheck  = "actions.types.user_expiration_check"
	I18nActionTypeUserInactivityCheck  = "actions.types.user_inactivity_check"
	I18nActionTypeIDPCheck             = "actions.types.idp_check"
	I18nActionTypeCommand              = "actions.types.command"
	I18nActionTypeRotateLogs           = "actions.types.rotate_logs"
	I18nActionFsTypeRename             = "actions.fs_types.rename"
	I18nActionFsTypeDelete             = "actions.fs_types.delete"
	I18nActionFsTypePathExists         = "actions.fs_types.path_exists"
	I18nActionFsTypeCompress           = "actions.fs_types.compress"
	I18nActionFsTypeCopy               = "actions.fs_types.copy"
	I18nActionFsTypeCreateDirs         = "actions.fs_types.create_dirs"
	I18nActionThresholdRequired        = "actions.inactivity_threshold_required"
	I18nActionThresholdsInvalid        = "actions.inactivity_thresholds_invalid"
	I18nTriggerFsEvent                 = "rules.triggers.fs_event"
	I18nTriggerProviderEvent           = "rules.triggers.provider_event"
	I18nTriggerIPBlockedEvent          = "rules.triggers.ip_blocked"
	I18nTriggerCertificateRenewEvent   = "rules.triggers.certificate_renewal"
	I18nTriggerOnDemandEvent           = "rules.triggers.on_demand"
	I18nTriggerIDPLoginEvent           = "rules.triggers.idp_login"
	I18nTriggerScheduleEvent           = "rules.triggers.schedule"
	I18nErrorInvalidMinSize            = "rules.invalid_fs_min_size"
	I18nErrorInvalidMaxSize            = "rules.invalid_fs_max_size"
	I18nErrorRuleActionRequired        = "rules.action_required"
	I18nErrorRuleFsEventRequired       = "rules.fs_event_required"
	I18nErrorRuleProviderEventRequired = "rules.provider_event_required"
	I18nErrorRuleScheduleRequired      = "rules.schedule_required"
	I18nErrorRuleScheduleInvalid       = "rules.schedule_invalid"
	I18nErrorRuleDuplicateActions      = "rules.duplicate_actions"
	I18nErrorEvSyncFailureActions      = "rules.sync_failure_actions"
	I18nErrorEvSyncUnsupported         = "rules.sync_unsupported"
	I18nErrorEvSyncUnsupportedFs       = "rules.sync_unsupported_fs_event"
	I18nErrorRuleFailureActionsOnly    = "rules.only_failure_actions"
	I18nErrorRuleSyncActionRequired    = "rules.sync_action_required"
	I18nErrorInvalidPNG                = "branding.invalid_png"
	I18nErrorInvalidPNGSize            = "branding.invalid_png_size"
	I18nErrorInvalidDisclaimerURL      = "branding.invalid_disclaimer_url"
)

// NewI18nError returns a I18nError wrapping the provided error.
// If err is already (or wraps) an I18nError it is returned unchanged,
// preserving the original message and arguments.
func NewI18nError(err error, message string, options ...I18nErrorOption) *I18nError {
	var errI18n *I18nError
	if errors.As(err, &errI18n) {
		return errI18n
	}
	errI18n = &I18nError{
		err:     err,
		Message: message,
		args:    nil,
	}
	for _, opt := range options {
		opt(errI18n)
	}
	return errI18n
}

// I18nErrorOption defines a functional option type that allows to configure the I18nError.
type I18nErrorOption func(*I18nError)

// I18nErrorArgs is a functional option to set I18nError arguments.
+func I18nErrorArgs(args map[string]any) I18nErrorOption { + return func(e *I18nError) { + e.args = args + } +} + +// I18nError is an error wrapper that add a message to use for localization. +type I18nError struct { + err error + Message string + args map[string]any +} + +// Error returns the wrapped error string. +func (e *I18nError) Error() string { + return e.err.Error() +} + +// Unwrap returns the underlying error +func (e *I18nError) Unwrap() error { + return e.err +} + +// Is reports if target matches +func (e *I18nError) Is(target error) bool { + if errors.Is(e.err, target) { + return true + } + _, ok := target.(*I18nError) + return ok +} + +// HasArgs returns true if the error has i18n args. +func (e *I18nError) HasArgs() bool { + return len(e.args) > 0 +} + +// Args returns the provided args in JSON format +func (e *I18nError) Args() string { + if len(e.args) > 0 { + data, err := json.Marshal(e.args) + if err == nil { + return BytesToString(data) + } + } + return "{}" +} diff --git a/internal/util/resources.go b/internal/util/resources.go new file mode 100644 index 00000000..8cddd946 --- /dev/null +++ b/internal/util/resources.go @@ -0,0 +1,85 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +//go:build !bundle + +package util + +import ( + "html/template" + "os" + "path/filepath" + "runtime" + + "github.com/drakkan/sftpgo/v2/internal/logger" +) + +// FindSharedDataPath searches for the specified directory name in searchDir +// and in system-wide shared data directories. +// If name is an absolute path it is returned unmodified. +func FindSharedDataPath(name, searchDir string) string { + if !IsFileInputValid(name) { + return "" + } + if name != "" && !filepath.IsAbs(name) { + searchList := []string{searchDir} + if additionalSharedDataSearchPath != "" { + searchList = append(searchList, additionalSharedDataSearchPath) + } + if runtime.GOOS != osWindows { + searchList = append(searchList, "/usr/share/sftpgo") + searchList = append(searchList, "/usr/local/share/sftpgo") + } + searchList = RemoveDuplicates(searchList, false) + for _, basePath := range searchList { + res := filepath.Join(basePath, name) + _, err := os.Stat(res) + if err == nil { + logger.Debug(logSender, "", "found share data path for name %q: %q", name, res) + return res + } + } + return filepath.Join(searchDir, name) + } + return name +} + +// LoadTemplate parses the given template paths. +// It behaves like template.Must but it writes a log before exiting. +func LoadTemplate(base *template.Template, paths ...string) *template.Template { + if base != nil { + baseTmpl, err := base.Clone() + if err != nil { + showTemplateLoadingError(err) + } + t, err := baseTmpl.ParseFiles(paths...) + if err != nil { + showTemplateLoadingError(err) + } + return t + } + + t, err := template.ParseFiles(paths...) 
+ if err != nil { + showTemplateLoadingError(err) + } + return t +} + +func showTemplateLoadingError(err error) { + logger.ErrorToConsole("error loading required template: %v", err) + logger.ErrorToConsole(templateLoadErrorHints) + logger.Error(logSender, "", "error loading required template: %v", err) + os.Exit(1) +} diff --git a/internal/util/resources_embedded.go b/internal/util/resources_embedded.go new file mode 100644 index 00000000..685eaf72 --- /dev/null +++ b/internal/util/resources_embedded.go @@ -0,0 +1,58 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build bundle + +package util + +import ( + "html/template" + "os" + + "github.com/drakkan/sftpgo/v2/internal/bundle" + "github.com/drakkan/sftpgo/v2/internal/logger" +) + +// FindSharedDataPath searches for the specified directory name in searchDir +// and in system-wide shared data directories. +// If name is an absolute path it is returned unmodified. +func FindSharedDataPath(name, _ string) string { + return name +} + +// LoadTemplate parses the given template paths. +// It behaves like template.Must but it writes a log before exiting. +// You can optionally provide a base template (e.g. 
to define some custom functions) +func LoadTemplate(base *template.Template, paths ...string) *template.Template { + var t *template.Template + var err error + + templateFs := bundle.GetTemplatesFs() + if base != nil { + base, err = base.Clone() + if err == nil { + t, err = base.ParseFS(templateFs, paths...) + } + } else { + t, err = template.ParseFS(templateFs, paths...) + } + + if err != nil { + logger.ErrorToConsole("error loading required template: %v", err) + logger.ErrorToConsole(templateLoadErrorHints) + logger.Error(logSender, "", "error loading required template: %v", err) + os.Exit(1) + } + return t +} diff --git a/internal/util/util.go b/internal/util/util.go new file mode 100644 index 00000000..82cbbba2 --- /dev/null +++ b/internal/util/util.go @@ -0,0 +1,1013 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +// Package util provides some common utility methods +package util + +import ( + "bytes" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/subtle" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "hash" + "io" + "io/fs" + "math" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "path" + "path/filepath" + "regexp" + "runtime" + "slices" + "strconv" + "strings" + "time" + "unicode" + "unsafe" + + "github.com/google/uuid" + "github.com/lithammer/shortuuid/v4" + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/v2/internal/logger" +) + +const ( + logSender = "util" + osWindows = "windows" + pubKeySuffix = ".pub" +) + +var ( + emailRegex = regexp.MustCompile("^(?:(?:(?:(?:[a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+(?:\\.([a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+)*)|(?:(?:\\x22)(?:(?:(?:(?:\\x20|\\x09)*(?:\\x0d\\x0a))?(?:\\x20|\\x09)+)?(?:(?:[\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x7f]|\\x21|[\\x23-\\x5b]|[\\x5d-\\x7e]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(?:(?:[\\x01-\\x09\\x0b\\x0c\\x0d-\\x7f]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}]))))*(?:(?:(?:\\x20|\\x09)*(?:\\x0d\\x0a))?(\\x20|\\x09)+)?(?:\\x22))))@(?:(?:(?:[a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(?:(?:[a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])(?:[a-zA-Z]|\\d|-|\\.|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*(?:[a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.)+(?:(?:[a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(?:(?:[a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])(?:[a-zA-Z]|\\d|-|\\.|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*(?:
[a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.?$")
	// this can be set at build time
	additionalSharedDataSearchPath = ""
	// CertsBasePath defines base path for certificates obtained using the built-in ACME protocol.
	// It is empty if ACME support is disabled
	CertsBasePath string
	// Defines the TLS ciphers used by default for TLS 1.0-1.2 if no preference is specified.
	defaultTLSCiphers = []uint16{
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
	}
)

// IEC Sizes.
// kibis of bits
const (
	oneByte = 1 << (iota * 10)
	kiByte
	miByte
	giByte
	tiByte
	piByte
	eiByte
)

// SI Sizes.
const (
	iByte  = 1
	kbByte = iByte * 1000
	mByte  = kbByte * 1000
	gByte  = mByte * 1000
	tByte  = gByte * 1000
	pByte  = tByte * 1000
	eByte  = pByte * 1000
)

// bytesSizeTable maps lower-case size suffixes (IEC and SI, with or without
// the trailing "b") to their multiplier in bytes; "" means plain bytes.
var bytesSizeTable = map[string]uint64{
	"b":   oneByte,
	"kib": kiByte,
	"kb":  kbByte,
	"mib": miByte,
	"mb":  mByte,
	"gib": giByte,
	"gb":  gByte,
	"tib": tiByte,
	"tb":  tByte,
	"pib": piByte,
	"pb":  pByte,
	"eib": eiByte,
	"eb":  eByte,
	// Without suffix
	"":   oneByte,
	"ki": kiByte,
	"k":  kbByte,
	"mi": miByte,
	"m":  mByte,
	"gi": giByte,
	"g":  gByte,
	"ti": tiByte,
	"t":  tByte,
	"pi": piByte,
	"p":  pByte,
	"ei": eiByte,
	"e":  eByte,
}

// IsStringPrefixInSlice searches a string prefix in a slice and returns true
// if a matching prefix is found
func IsStringPrefixInSlice(obj string, list []string) bool {
	return slices.ContainsFunc(list, func(prefix string) bool {
		return strings.HasPrefix(obj, prefix)
	})
}

// RemoveDuplicates returns the input slice with duplicate elements removed,
// preserving the order of first occurrence.
// NOTE: deduplication happens in place, the backing array of obj is reused
// and its contents are modified; use the returned slice afterwards.
// If trim is true, elements are trimmed of leading/trailing whitespace
// before being compared (and stored trimmed).
func RemoveDuplicates(obj []string, trim bool) []string {
	if len(obj) == 0 {
		return obj
	}
	seen := make(map[string]bool)
	validIdx := 0
	for _, item := range obj {
		if trim {
			item = strings.TrimSpace(item)
		}
		if !seen[item] {
			seen[item] = true
			obj[validIdx] = item
			validIdx++
		}
	}
	return obj[:validIdx]
}

// IsNameValid validates that a name/username contains only safe characters.
// It rejects: empty names, names longer than 255 bytes, control characters,
// path separators and characters invalid on Windows (: * ? " < > |),
// "." and "..", Windows reserved device names (CON, PRN, AUX, NUL, COM1-9,
// LPT1-9, also with any extension) and names ending with a space or a dot.
func IsNameValid(name string) bool {
	if name == "" {
		return false
	}
	// limit is in bytes, not runes
	if len(name) > 255 {
		return false
	}
	for _, r := range name {
		if unicode.IsControl(r) {
			return false
		}

		switch r {
		case '/', '\\':
			return false
		case ':', '*', '?', '"', '<', '>', '|':
			return false
		}
	}

	if name == "." || name == ".." {
		return false
	}

	// Windows reserved device names are invalid even with an extension,
	// e.g. "CON.txt"
	upperName := strings.ToUpper(name)
	baseName := strings.Split(upperName, ".")[0]

	switch baseName {
	case "CON", "PRN", "AUX", "NUL",
		"COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9",
		"LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9":
		return false
	}

	// Windows silently strips trailing spaces/dots: reject them
	if strings.HasSuffix(name, " ") || strings.HasSuffix(name, ".") {
		return false
	}

	return true
}

// GetTimeAsMsSinceEpoch returns unix timestamp as milliseconds from a time struct
func GetTimeAsMsSinceEpoch(t time.Time) int64 {
	return t.UnixMilli()
}

// GetTimeFromMsecSinceEpoch return a time struct from a unix timestamp with millisecond precision
func GetTimeFromMsecSinceEpoch(msec int64) time.Time {
	return time.Unix(0, msec*1000000)
}

// GetDurationAsString returns a string representation for a time.Duration,
// formatted as "HH:MM:SS", or "MM:SS" if shorter than one hour.
// The duration is rounded to the nearest second.
func GetDurationAsString(d time.Duration) string {
	d = d.Round(time.Second)
	h := d / time.Hour
	d -= h * time.Hour
	m := d / time.Minute
	d -= m * time.Minute
	s := d / time.Second
	if h > 0 {
		return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
	}
	return fmt.Sprintf("%02d:%02d", m, s)
}

// ByteCountSI returns humanized size in SI (decimal) format
func ByteCountSI(b int64) string {
	return byteCount(b, 1000, true)
}

// ByteCountIEC returns humanized size in IEC (binary) format
func ByteCountIEC(b int64) string {
	return byteCount(b, 1024, false)
}

// byteCount renders b using the given unit (1000 for SI, 1024 for IEC).
// When maxPrecision is true the value is formatted with the minimum number
// of digits needed (and non-positive values are returned as bare numbers);
// otherwise one decimal digit is used.
func byteCount(b int64, unit int64, maxPrecision bool) string {
	if b <= 0 && maxPrecision {
		return strconv.FormatInt(b, 10)
	}
	if b < unit {
		return fmt.Sprintf("%d B", b)
	}
	div, exp := unit, 0
	for n := b / unit; n >= unit; n /= unit {
		div *= unit
		exp++
	}
	var val string
	if maxPrecision {
		val = strconv.FormatFloat(float64(b)/float64(div), 'f', -1, 64)
	} else {
		val = fmt.Sprintf("%.1f", float64(b)/float64(div))
	}
	// exp cannot exceed the "E" index since b fits in an int64
	if unit == 1000 {
		return fmt.Sprintf("%s %cB", val, "KMGTPE"[exp])
	}
	return fmt.Sprintf("%s %ciB", val, "KMGTPE"[exp])
}

// ParseBytes parses a string representation of bytes into the number
// of bytes it represents.
//
// ParseBytes("42 MB") -> 42000000, nil
// ParseBytes("42 mib") -> 44040192, nil
//
// copied from here:
//
// https://github.com/dustin/go-humanize/blob/master/bytes.go
//
// with minor modifications
func ParseBytes(s string) (int64, error) {
	s = strings.TrimSpace(s)
	lastDigit := 0
	hasComma := false
	for _, r := range s {
		if !unicode.IsDigit(r) && r != '.' && r != ',' {
			break
		}
		if r == ',' {
			hasComma = true
		}
		lastDigit++
	}

	num := s[:lastDigit]
	if hasComma {
		// commas are treated as thousands separators and removed
		num = strings.ReplaceAll(num, ",", "")
	}

	f, err := strconv.ParseFloat(num, 64)
	if err != nil {
		return 0, err
	}

	extra := strings.ToLower(strings.TrimSpace(s[lastDigit:]))
	if m, ok := bytesSizeTable[extra]; ok {
		f *= float64(m)
		if f >= math.MaxInt64 {
			return 0, fmt.Errorf("value too large: %v", s)
		}
		if f < 0 {
			return 0, fmt.Errorf("negative value not allowed: %v", s)
		}
		return int64(f), nil
	}

	return 0, fmt.Errorf("unhandled size name: %v", extra)
}

// BytesToString converts []byte to string without allocations.
// https://github.com/kubernetes/kubernetes/blob/e4b74dd12fa8cb63c174091d5536a10b8ec19d34/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go#L278
// Use only if strictly required, this method uses unsafe.
// The input slice must not be modified after the conversion.
func BytesToString(b []byte) string {
	// unsafe.SliceData relies on cap whereas we want to rely on len
	if len(b) == 0 {
		return ""
	}
	// https://github.com/golang/go/blob/4ed358b57efdad9ed710be7f4fc51495a7620ce2/src/strings/builder.go#L41
	return unsafe.String(unsafe.SliceData(b), len(b))
}

// StringToBytes convert string to []byte without allocations.
// https://github.com/kubernetes/kubernetes/blob/e4b74dd12fa8cb63c174091d5536a10b8ec19d34/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go#L289
// Use only if strictly required, this method uses unsafe.
// The returned slice must not be modified.
func StringToBytes(s string) []byte {
	// unsafe.StringData is unspecified for the empty string, so we provide a strict interpretation
	if s == "" {
		return nil
	}
	// https://github.com/golang/go/blob/4ed358b57efdad9ed710be7f4fc51495a7620ce2/src/os/file.go#L300
	return unsafe.Slice(unsafe.StringData(s), len(s))
}

// GetIPFromRemoteAddress returns the IP from the remote address.
+// If the given remote address cannot be parsed it will be returned unchanged +func GetIPFromRemoteAddress(remoteAddress string) string { + ip, _, err := net.SplitHostPort(remoteAddress) + if err == nil { + return ip + } + return remoteAddress +} + +// GetIPFromNetAddr returns the IP from the network address +func GetIPFromNetAddr(upstream net.Addr) (net.IP, error) { + if upstream == nil { + return nil, errors.New("invalid address") + } + upstreamString, _, err := net.SplitHostPort(upstream.String()) + if err != nil { + return nil, err + } + + upstreamIP := net.ParseIP(upstreamString) + if upstreamIP == nil { + return nil, fmt.Errorf("invalid IP address: %q", upstreamString) + } + + return upstreamIP, nil +} + +// NilIfEmpty returns nil if the input string is empty +func NilIfEmpty(s string) *string { + if s == "" { + return nil + } + return &s +} + +// GetStringFromPointer returns the string value or empty if nil +func GetStringFromPointer(val *string) string { + if val == nil { + return "" + } + return *val +} + +// GetIntFromPointer returns the int value or zero +func GetIntFromPointer(val *int64) int64 { + if val == nil { + return 0 + } + return *val +} + +// GetTimeFromPointer returns the time value or now +func GetTimeFromPointer(val *time.Time) time.Time { + if val == nil { + return time.Unix(0, 0) + } + return *val +} + +// GenerateRSAKeys generate rsa private and public keys and write the +// private key to specified file and the public key to the specified +// file adding the .pub suffix +func GenerateRSAKeys(file string) error { + if err := createDirPathIfMissing(file, 0700); err != nil { + return err + } + key, err := rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + return err + } + + o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer o.Close() + + priv := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + } + + if err := pem.Encode(o, priv); err != nil { 
+ return err + } + + pub, err := ssh.NewPublicKey(&key.PublicKey) + if err != nil { + return err + } + return os.WriteFile(file+pubKeySuffix, ssh.MarshalAuthorizedKey(pub), 0600) +} + +// GenerateECDSAKeys generate ecdsa private and public keys and write the +// private key to specified file and the public key to the specified +// file adding the .pub suffix +func GenerateECDSAKeys(file string) error { + if err := createDirPathIfMissing(file, 0700); err != nil { + return err + } + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + + keyBytes, err := x509.MarshalECPrivateKey(key) + if err != nil { + return err + } + priv := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyBytes, + } + + o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer o.Close() + + if err := pem.Encode(o, priv); err != nil { + return err + } + + pub, err := ssh.NewPublicKey(&key.PublicKey) + if err != nil { + return err + } + return os.WriteFile(file+pubKeySuffix, ssh.MarshalAuthorizedKey(pub), 0600) +} + +// GenerateEd25519Keys generate ed25519 private and public keys and write the +// private key to specified file and the public key to the specified +// file adding the .pub suffix +func GenerateEd25519Keys(file string) error { + pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return err + } + keyBytes, err := x509.MarshalPKCS8PrivateKey(privKey) + if err != nil { + return err + } + priv := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: keyBytes, + } + o, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer o.Close() + + if err := pem.Encode(o, priv); err != nil { + return err + } + pub, err := ssh.NewPublicKey(pubKey) + if err != nil { + return err + } + return os.WriteFile(file+pubKeySuffix, ssh.MarshalAuthorizedKey(pub), 0600) +} + +// IsDirOverlapped returns true if dir1 and dir2 overlap +func 
IsDirOverlapped(dir1, dir2 string, fullCheck bool, separator string) bool { + if dir1 == dir2 { + return true + } + if fullCheck { + if len(dir1) > len(dir2) { + if strings.HasPrefix(dir1, dir2+separator) { + return true + } + } + if len(dir2) > len(dir1) { + if strings.HasPrefix(dir2, dir1+separator) { + return true + } + } + } + return false +} + +// GetDirsForVirtualPath returns all the directory for the given path in reverse order +// for example if the path is: /1/2/3/4 it returns: +// [ "/1/2/3/4", "/1/2/3", "/1/2", "/1", "/" ] +func GetDirsForVirtualPath(virtualPath string) []string { + if virtualPath == "" || virtualPath == "." { + virtualPath = "/" + } else { + if !path.IsAbs(virtualPath) { + virtualPath = CleanPath(virtualPath) + } + } + dirsForPath := []string{virtualPath} + for virtualPath != "/" { + virtualPath = path.Dir(virtualPath) + dirsForPath = append(dirsForPath, virtualPath) + } + return dirsForPath +} + +// CleanPath returns a clean POSIX (/) absolute path to work with +func CleanPath(p string) string { + return CleanPathWithBase("/", p) +} + +// CleanPathWithBase returns a clean POSIX (/) absolute path to work with. +// The specified base will be used if the provided path is not absolute +func CleanPathWithBase(base, p string) string { + p = strings.ReplaceAll(p, "\\", "/") + if !path.IsAbs(p) { + p = path.Join(base, p) + } + return path.Clean(p) +} + +// IsFileInputValid returns true this is a valid file name. +// This method must be used before joining a file name, generally provided as +// user input, with a directory +func IsFileInputValid(fileInput string) bool { + cleanInput := filepath.Clean(fileInput) + if cleanInput == "." || cleanInput == ".." { + return false + } + return true +} + +// CleanDirInput sanitizes user input for directories. +// On Windows it removes any trailing `"`. +// We try to help windows users that set an invalid path such as "C:\ProgramData\SFTPGO\". 
+// This will only help if the invalid path is the last argument, for example in this command: +// sftpgo.exe serve -c "C:\ProgramData\SFTPGO\" -l "sftpgo.log" +// the -l flag will be ignored and the -c flag will get the value `C:\ProgramData\SFTPGO" -l sftpgo.log` +// since the backslash after SFTPGO escape the double quote. This is definitely a bad user input +func CleanDirInput(dirInput string) string { + if runtime.GOOS == osWindows { + for strings.HasSuffix(dirInput, "\"") { + dirInput = strings.TrimSuffix(dirInput, "\"") + } + } + return filepath.Clean(dirInput) +} + +func createDirPathIfMissing(file string, perm os.FileMode) error { + dirPath := filepath.Dir(file) + if _, err := os.Stat(dirPath); errors.Is(err, fs.ErrNotExist) { + err = os.MkdirAll(dirPath, perm) + if err != nil { + return err + } + } + return nil +} + +// GenerateRandomBytes generates random bytes with the specified length +func GenerateRandomBytes(length int) []byte { + b := make([]byte, length) + _, err := io.ReadFull(rand.Reader, b) + if err != nil { + PanicOnError(fmt.Errorf("failed to read random data (see https://go.dev/issue/66821): %w", err)) + } + return b +} + +// GenerateOpaqueString generates a cryptographically secure opaque string +func GenerateOpaqueString() string { + randomBytes := sha256.Sum256(GenerateRandomBytes(32)) + return hex.EncodeToString(randomBytes[:]) +} + +// GenerateUniqueID returns an unique ID +func GenerateUniqueID() string { + u, err := uuid.NewRandom() + if err != nil { + PanicOnError(fmt.Errorf("failed to read random data (see https://go.dev/issue/66821): %w", err)) + } + return shortuuid.DefaultEncoder.Encode(u) +} + +// HTTPListenAndServe is a wrapper for ListenAndServe that support both tcp +// and Unix-domain sockets +func HTTPListenAndServe(srv *http.Server, address string, port int, isTLS bool, + listenerWrapper func(net.Listener) (net.Listener, error), + logSender string, +) error { + var listener net.Listener + var err error + + if 
filepath.IsAbs(address) && runtime.GOOS != osWindows { + if !IsFileInputValid(address) { + return fmt.Errorf("invalid socket address %q", address) + } + err = createDirPathIfMissing(address, 0770) + if err != nil { + logger.ErrorToConsole("error creating Unix-domain socket parent dir: %v", err) + logger.Error(logSender, "", "error creating Unix-domain socket parent dir: %v", err) + } + os.Remove(address) + listener, err = net.Listen("unix", address) + if err == nil { + // should a chmod err be fatal? + if errChmod := os.Chmod(address, 0770); errChmod != nil { + logger.Warn(logSender, "", "unable to set the Unix-domain socket group writable: %v", errChmod) + } + } + } else { + CheckTCP4Port(port) + listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", address, port)) + } + if err != nil { + return err + } + if listenerWrapper != nil { + listener, err = listenerWrapper(listener) + if err != nil { + return err + } + } + logger.Info(logSender, "", "server listener registered, address: %s TLS enabled: %t", listener.Addr().String(), isTLS) + + defer listener.Close() + + if isTLS { + return srv.ServeTLS(listener, "", "") + } + return srv.Serve(listener) +} + +// GetTLSCiphersFromNames returns the TLS ciphers from the specified names +func GetTLSCiphersFromNames(cipherNames []string) []uint16 { + var ciphers []uint16 + + for _, name := range RemoveDuplicates(cipherNames, false) { + for _, c := range tls.CipherSuites() { + if c.Name == strings.TrimSpace(name) { + ciphers = append(ciphers, c.ID) + } + } + for _, c := range tls.InsecureCipherSuites() { + if c.Name == strings.TrimSpace(name) { + ciphers = append(ciphers, c.ID) + } + } + } + + if len(ciphers) == 0 { + // return a secure default + return defaultTLSCiphers + } + + return ciphers +} + +// GetALPNProtocols returns the ALPN protocols, any invalid protocol will be +// silently ignored. 
If no protocol or no valid protocol is provided the default +// is http/1.1, h2 +func GetALPNProtocols(protocols []string) []string { + var result []string + for _, p := range protocols { + switch p { + case "http/1.1", "h2": + result = append(result, p) + } + } + if len(result) == 0 { + return []string{"http/1.1", "h2"} + } + return result +} + +// EncodeTLSCertToPem returns the specified certificate PEM encoded. +// This can be verified using openssl x509 -in cert.crt -text -noout +func EncodeTLSCertToPem(tlsCert *x509.Certificate) (string, error) { + if len(tlsCert.Raw) == 0 { + return "", errors.New("invalid x509 certificate, no der contents") + } + publicKeyBlock := pem.Block{ + Type: "CERTIFICATE", + Bytes: tlsCert.Raw, + } + return BytesToString(pem.EncodeToMemory(&publicKeyBlock)), nil +} + +// CheckTCP4Port quits the app if bind on the given IPv4 port fails. +// This is a ugly hack to avoid to bind on an already used port. +// It is required on Windows only. Upstream does not consider this +// behaviour a bug: +// https://github.com/golang/go/issues/45150 +func CheckTCP4Port(port int) { + if runtime.GOOS != osWindows { + return + } + listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", port)) + if err != nil { + logger.ErrorToConsole("unable to bind on tcp4 address: %v", err) + logger.Error(logSender, "", "unable to bind on tcp4 address: %v", err) + os.Exit(1) + } + listener.Close() +} + +// IsByteArrayEmpty return true if the byte array is empty or a new line +func IsByteArrayEmpty(b []byte) bool { + if len(b) == 0 { + return true + } + if bytes.Equal(b, []byte("\n")) { + return true + } + if bytes.Equal(b, []byte("\r\n")) { + return true + } + return false +} + +// GetSSHPublicKeyAsString returns an SSH public key serialized as string +func GetSSHPublicKeyAsString(pubKey []byte) (string, error) { + if len(pubKey) == 0 { + return "", nil + } + k, err := ssh.ParsePublicKey(pubKey) + if err != nil { + return "", err + } + return 
BytesToString(ssh.MarshalAuthorizedKey(k)), nil +} + +// GetRealIP returns the ip address as result of parsing the specified +// header and using the specified depth +func GetRealIP(r *http.Request, header string, depth int) string { + if header == "" { + return "" + } + var ipAddresses []string + + for _, h := range r.Header.Values(header) { + for ipStr := range strings.SplitSeq(h, ",") { + ipStr = strings.TrimSpace(ipStr) + ipAddresses = append(ipAddresses, ipStr) + } + } + + idx := len(ipAddresses) - 1 - depth + if idx >= 0 { + ip := strings.TrimSpace(ipAddresses[idx]) + if ip == "" || net.ParseIP(ip) == nil { + return "" + } + return ip + } + + return "" +} + +// GetHTTPLocalAddress returns the local address for an http.Request +// or empty if it cannot be determined +func GetHTTPLocalAddress(r *http.Request) string { + if r == nil { + return "" + } + localAddr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr) + if ok { + return localAddr.String() + } + return "" +} + +// ParseAllowedIPAndRanges returns a list of functions that allow to find if an +// IP is equal or is contained within the allowed list +func ParseAllowedIPAndRanges(allowed []string) ([]func(net.IP) bool, error) { + res := make([]func(net.IP) bool, len(allowed)) + for i, allowFrom := range allowed { + if strings.LastIndex(allowFrom, "/") > 0 { + _, ipRange, err := net.ParseCIDR(allowFrom) + if err != nil { + return nil, fmt.Errorf("given string %q is not a valid IP range: %v", allowFrom, err) + } + + res[i] = ipRange.Contains + } else { + allowed := net.ParseIP(allowFrom) + if allowed == nil { + return nil, fmt.Errorf("given string %q is not a valid IP address", allowFrom) + } + + res[i] = allowed.Equal + } + } + + return res, nil +} + +// GetRedactedURL returns the url redacting the password if any +func GetRedactedURL(rawurl string) string { + if !strings.HasPrefix(rawurl, "http") { + return rawurl + } + u, err := url.Parse(rawurl) + if err != nil { + return rawurl + } + return 
u.Redacted()
+}
+
+// GetTLSVersion returns the TLS version from an integer value:
+// - 10 means TLS 1.0
+// - 11 means TLS 1.1
+// - 12 means TLS 1.2
+// - 13 means TLS 1.3
+// default is TLS 1.2
+func GetTLSVersion(val int) uint16 {
+	switch val {
+	case 13:
+		return tls.VersionTLS13
+	case 11:
+		return tls.VersionTLS11
+	case 10:
+		return tls.VersionTLS10
+	default:
+		return tls.VersionTLS12
+	}
+}
+
+// IsEmailValid returns true if the specified email address is valid
+func IsEmailValid(email string) bool {
+	return emailRegex.MatchString(email)
+}
+
+// SanitizeDomain returns the specified domain name in a form suitable to save as file
+func SanitizeDomain(domain string) string {
+	return strings.NewReplacer(":", "_", "*", "_", ",", "_", " ", "_").Replace(domain)
+}
+
+// PanicOnError calls panic if err is not nil
+func PanicOnError(err error) {
+	if err != nil {
+		panic(fmt.Errorf("unexpected error: %w", err))
+	}
+}
+
+// GetAbsolutePath returns an absolute path using the current dir as base
+// if name defines a relative path
+func GetAbsolutePath(name string) (string, error) {
+	if name == "" {
+		return name, errors.New("input path cannot be empty")
+	}
+	if filepath.IsAbs(name) {
+		return name, nil
+	}
+	curDir, err := os.Getwd()
+	if err != nil {
+		return name, err
+	}
+	return filepath.Join(curDir, name), nil
+}
+
+// GetACMECertificateKeyPair returns the path to the ACME TLS crt and key for the specified domain
+func GetACMECertificateKeyPair(domain string) (string, string) {
+	if CertsBasePath == "" {
+		return "", ""
+	}
+	domain = SanitizeDomain(domain) // certificates are stored under the sanitized domain name
+	return filepath.Join(CertsBasePath, domain+".crt"), filepath.Join(CertsBasePath, domain+".key")
+}
+
+// GetLastIPForPrefix returns the last IP for the given prefix
+// https://github.com/go4org/netipx/blob/8449b0a6169f5140fb0340cb4fc0de4c9b281ef6/netipx.go#L173
+func GetLastIPForPrefix(p netip.Prefix) netip.Addr {
+	if !p.IsValid() {
+		return netip.Addr{}
+	}
+	a16 := p.Addr().As16()
+	var off 
uint8
+	var bits uint8 = 128
+	if p.Addr().Is4() {
+		off = 12
+		bits = 32
+	}
+	for b := uint8(p.Bits()); b < bits; b++ {
+		byteNum, bitInByte := b/8, 7-(b%8)
+		a16[off+byteNum] |= 1 << uint(bitInByte) // set every host bit after the prefix
+	}
+	if p.Addr().Is4() {
+		return netip.AddrFrom16(a16).Unmap()
+	}
+	return netip.AddrFrom16(a16) // doesn't unmap
+}
+
+// JSONEscape returns the JSON escaped format for the input string
+func JSONEscape(val string) string {
+	if val == "" {
+		return val
+	}
+	b, err := json.Marshal(val)
+	if err != nil {
+		return ""
+	}
+	return BytesToString(b[1 : len(b)-1]) // strip the surrounding quotes added by Marshal
+}
+
+// ReadConfigFromFile reads a configuration parameter from the specified file
+func ReadConfigFromFile(name, configDir string) (string, error) {
+	if !IsFileInputValid(name) {
+		return "", fmt.Errorf("invalid file input: %q", name)
+	}
+	if configDir == "" {
+		if !filepath.IsAbs(name) {
+			return "", fmt.Errorf("%q must be an absolute file path", name)
+		}
+	} else {
+		if name != "" && !filepath.IsAbs(name) {
+			name = filepath.Join(configDir, name) // relative paths are resolved against the config dir
+		}
+	}
+	val, err := os.ReadFile(name)
+	if err != nil {
+		return "", err
+	}
+	return strings.TrimSpace(BytesToString(val)), nil
+}
+
+// SlicesEqual checks if the provided slices contain the same elements,
+// also in different order; note that duplicate counts are not compared.
+func SlicesEqual(s1, s2 []string) bool {
+	if len(s1) != len(s2) {
+		return false
+	}
+	for _, v := range s1 {
+		if !slices.Contains(s2, v) {
+			return false
+		}
+	}
+
+	return true
+}
+
+// VerifyFileChecksum computes the hash of the given file using the provided
+// hash algorithm and compares it against the expected checksum (in hex format).
+// It returns an error if the checksum does not match or if the operation fails. 
+func VerifyFileChecksum(filePath string, h hash.Hash, expectedHex string, maxSize int64) error {
+	expected, err := hex.DecodeString(expectedHex)
+	if err != nil {
+		return fmt.Errorf("invalid checksum %q: %w", expectedHex, err)
+	}
+
+	f, err := os.Open(filePath)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	if maxSize > 0 {
+		fi, err := f.Stat()
+		if err != nil {
+			return err
+		}
+		if fi.Size() > maxSize { // refuse to hash files larger than the configured limit
+			return fmt.Errorf("file too large: %s", ByteCountIEC(fi.Size()))
+		}
+	}
+
+	if _, err := io.Copy(h, f); err != nil {
+		return err
+	}
+
+	actual := h.Sum(nil)
+	if subtle.ConstantTimeCompare(actual, expected) != 1 { // constant-time compare avoids leaking the digest
+		return errors.New("checksum mismatch")
+	}
+
+	return nil
+}
diff --git a/internal/util/util_fallback.go b/internal/util/util_fallback.go
new file mode 100644
index 00000000..d673f88c
--- /dev/null
+++ b/internal/util/util_fallback.go
@@ -0,0 +1,31 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>. 
+
+//go:build !unix
+
+package util
+
+import (
+	"runtime"
+
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+)
+
+// SetUmask sets the specified umask: no-op fallback, the umask is not supported on this OS
+func SetUmask(val string) {
+	if val == "" {
+		return
+	}
+	logger.Debug(logSender, "", "umask not supported on OS %q", runtime.GOOS)
+}
diff --git a/internal/util/util_unix.go b/internal/util/util_unix.go
new file mode 100644
index 00000000..071c690d
--- /dev/null
+++ b/internal/util/util_unix.go
@@ -0,0 +1,38 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+//go:build unix
+
+package util
+
+import (
+	"strconv"
+	"syscall"
+
+	"github.com/drakkan/sftpgo/v2/internal/logger"
+)
+
+// SetUmask sets the specified umask; val is parsed as an octal string
+func SetUmask(val string) {
+	if val == "" {
+		return
+	}
+	umask, err := strconv.ParseUint(val, 8, 31)
+	if err != nil {
+		logger.Error(logSender, "", "invalid umask %q: %v", val, err)
+		return
+	}
+	logger.Debug(logSender, "", "set umask to: %d, configured value: %q", umask, val)
+	syscall.Umask(int(umask))
+}
diff --git a/internal/version/version.go b/internal/version/version.go
new file mode 100644
index 00000000..c9e066ea
--- /dev/null
+++ b/internal/version/version.go
@@ -0,0 +1,108 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3. 
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+// Package version defines SFTPGo version details
+package version
+
+import "strings"
+
+const (
+	version = "2.7.99-dev"
+	appName = "SFTPGo"
+)
+
+var (
+	commit = "" // injected at build time via -ldflags "-X .../internal/version.commit=..."
+	date   = "" // injected at build time via -ldflags "-X .../internal/version.date=..."
+	info   Info
+)
+
+var (
+	config string
+)
+
+// Info defines version details
+type Info struct {
+	Version    string   `json:"version"`
+	BuildDate  string   `json:"build_date"`
+	CommitHash string   `json:"commit_hash"`
+	Features   []string `json:"features"`
+}
+
+// GetAsString returns the string representation of the version
+func GetAsString() string {
+	var sb strings.Builder
+	sb.WriteString(info.Version)
+	if info.CommitHash != "" {
+		sb.WriteString("-")
+		sb.WriteString(info.CommitHash)
+	}
+	if info.BuildDate != "" {
+		sb.WriteString("-")
+		sb.WriteString(info.BuildDate)
+	}
+	if len(info.Features) > 0 {
+		sb.WriteString(" ")
+		sb.WriteString(strings.Join(info.Features, " "))
+	}
+	return sb.String()
+}
+
+func init() {
+	info = Info{
+		Version:    version,
+		CommitHash: commit,
+		BuildDate:  date,
+	}
+}
+
+// AddFeature adds a feature description
+func AddFeature(feature string) {
+	info.Features = append(info.Features, feature)
+}
+
+// Get returns the Info struct
+func Get() Info {
+	return info
+}
+
+// SetConfig sets the version configuration
+func SetConfig(val string) {
+	config = val
+}
+
+// GetServerVersion returns the server version according to the configuration
+// and the provided parameters. 
+func GetServerVersion(separator string, addHash bool) string {
+	var sb strings.Builder
+	sb.WriteString(appName)
+	if config != "short" { // the "short" configuration omits the version number
+		sb.WriteString(separator)
+		sb.WriteString(info.Version)
+	}
+	if addHash {
+		sb.WriteString(separator)
+		sb.WriteString(info.CommitHash)
+	}
+	return sb.String()
+}
+
+// GetVersionHash returns the server identification string with the commit hash.
+func GetVersionHash() string {
+	var sb strings.Builder
+	sb.WriteString(appName)
+	sb.WriteString("-")
+	sb.WriteString(info.CommitHash)
+	return sb.String()
+}
diff --git a/internal/vfs/azblobfs.go b/internal/vfs/azblobfs.go
new file mode 100644
index 00000000..46b7ae78
--- /dev/null
+++ b/internal/vfs/azblobfs.go
@@ -0,0 +1,1261 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>. 
+ +//go:build !noazblob + +package vfs + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "mime" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" + "github.com/google/uuid" + "github.com/pkg/sftp" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/metric" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" +) + +const ( + azureDefaultEndpoint = "blob.core.windows.net" + azFolderKey = "hdi_isfolder" +) + +var ( + azureBlobDefaultPageSize = int32(5000) +) + +// AzureBlobFs is a Fs implementation for Azure Blob storage. 
+type AzureBlobFs struct { + connectionID string + localTempDir string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string + config *AzBlobFsConfig + containerClient *container.Client + ctxTimeout time.Duration + ctxLongTimeout time.Duration +} + +func init() { + version.AddFeature("+azblob") +} + +// NewAzBlobFs returns an AzBlobFs object that allows to interact with Azure Blob storage +func NewAzBlobFs(connectionID, localTempDir, mountPath string, config AzBlobFsConfig) (Fs, error) { + if localTempDir == "" { + localTempDir = getLocalTempDir() + } + fs := &AzureBlobFs{ + connectionID: connectionID, + localTempDir: localTempDir, + mountPath: getMountPath(mountPath), + config: &config, + ctxTimeout: 30 * time.Second, + ctxLongTimeout: 90 * time.Second, + } + if err := fs.config.validate(); err != nil { + return fs, err + } + + if err := fs.config.tryDecrypt(); err != nil { + return fs, err + } + + fs.setConfigDefaults() + + if fs.config.SASURL.GetPayload() != "" { + return fs.initFromSASURL() + } + + var endpoint string + if fs.config.UseEmulator { + endpoint = fmt.Sprintf("%s/%s", fs.config.Endpoint, fs.config.AccountName) + } else { + endpoint = fmt.Sprintf("https://%s.%s/", fs.config.AccountName, fs.config.Endpoint) + } + containerURL := runtime.JoinPaths(endpoint, fs.config.Container) + if fs.config.AccountKey.GetPayload() != "" { + credential, err := blob.NewSharedKeyCredential(fs.config.AccountName, fs.config.AccountKey.GetPayload()) + if err != nil { + return fs, fmt.Errorf("invalid credentials: %v", err) + } + svc, err := container.NewClientWithSharedKeyCredential(containerURL, credential, getAzContainerClientOptions()) + if err != nil { + return fs, fmt.Errorf("unable to create the storage client using shared key credentials: %v", err) + } + fs.containerClient = svc + return fs, err + } + credential, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return fs, fmt.Errorf("invalid default azure 
credentials: %v", err) + } + svc, err := container.NewClient(containerURL, credential, getAzContainerClientOptions()) + if err != nil { + return fs, fmt.Errorf("unable to create the storage client using azure credentials: %v", err) + } + fs.containerClient = svc + return fs, err +} + +func (fs *AzureBlobFs) initFromSASURL() (Fs, error) { + parts, err := blob.ParseURL(fs.config.SASURL.GetPayload()) + if err != nil { + return fs, fmt.Errorf("invalid SAS URL: %w", err) + } + if parts.BlobName != "" { + return fs, fmt.Errorf("SAS URL with blob name not supported") + } + if parts.ContainerName != "" { + if fs.config.Container != "" && fs.config.Container != parts.ContainerName { + return fs, fmt.Errorf("container name in SAS URL %q and container provided %q do not match", + parts.ContainerName, fs.config.Container) + } + svc, err := container.NewClientWithNoCredential(fs.config.SASURL.GetPayload(), getAzContainerClientOptions()) + if err != nil { + return fs, fmt.Errorf("invalid credentials: %v", err) + } + fs.config.Container = parts.ContainerName + fs.containerClient = svc + return fs, nil + } + if fs.config.Container == "" { + return fs, errors.New("container is required with this SAS URL") + } + sasURL := runtime.JoinPaths(fs.config.SASURL.GetPayload(), fs.config.Container) + svc, err := container.NewClientWithNoCredential(sasURL, getAzContainerClientOptions()) + if err != nil { + return fs, fmt.Errorf("invalid credentials: %v", err) + } + fs.containerClient = svc + return fs, nil +} + +// Name returns the name for the Fs implementation +func (fs *AzureBlobFs) Name() string { + if !fs.config.SASURL.IsEmpty() { + return fmt.Sprintf("%s with SAS URL, container %q", azBlobFsName, fs.config.Container) + } + return fmt.Sprintf("%s container %q", azBlobFsName, fs.config.Container) +} + +// ConnectionID returns the connection ID associated to this Fs implementation +func (fs *AzureBlobFs) ConnectionID() string { + return fs.connectionID +} + +// Stat returns a FileInfo 
describing the named file +func (fs *AzureBlobFs) Stat(name string) (os.FileInfo, error) { + if name == "" || name == "/" || name == "." { + return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } + if fs.config.KeyPrefix == name+"/" { + return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } + + attrs, err := fs.headObject(name) + if err == nil { + contentType := util.GetStringFromPointer(attrs.ContentType) + isDir := checkDirectoryMarkers(contentType, attrs.Metadata) + lastModified := util.GetTimeFromPointer(attrs.LastModified) + if val := getAzureLastModified(attrs.Metadata); val > 0 { + lastModified = util.GetTimeFromMsecSinceEpoch(val) + } + info := NewFileInfo(name, isDir, util.GetIntFromPointer(attrs.ContentLength), lastModified, false) + if !isDir { + info.setMetadataFromPointerVal(attrs.Metadata) + } + return info, nil + } + if !fs.IsNotExist(err) { + return nil, err + } + // now check if this is a prefix (virtual directory) + hasContents, err := fs.hasContents(name) + if err != nil { + return nil, err + } + if hasContents { + return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } + return nil, os.ErrNotExist +} + +// Lstat returns a FileInfo describing the named file +func (fs *AzureBlobFs) Lstat(name string) (os.FileInfo, error) { + return fs.Stat(name) +} + +// Open opens the named file for reading +func (fs *AzureBlobFs) Open(name string, offset int64) (File, PipeReader, func(), error) { + r, w, err := createPipeFn(fs.localTempDir, fs.config.DownloadPartSize*int64(fs.config.DownloadConcurrency)+1) + if err != nil { + return nil, nil, nil, err + } + p := NewPipeReader(r) + ctx, cancelFn := context.WithCancel(context.Background()) + + go func() { + defer cancelFn() + + blockBlob := fs.containerClient.NewBlockBlobClient(name) + err := fs.handleMultipartDownload(ctx, blockBlob, offset, w, p) + w.CloseWithError(err) //nolint:errcheck + fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %+v", name, 
w.GetWrittenBytes(), err) + metric.AZTransferCompleted(w.GetWrittenBytes(), 1, err) + }() + + return nil, p, cancelFn, nil +} + +// Create creates or opens the named file for writing +func (fs *AzureBlobFs) Create(name string, flag, checks int) (File, PipeWriter, func(), error) { + if checks&CheckParentDir != 0 { + _, err := fs.Stat(path.Dir(name)) + if err != nil { + return nil, nil, nil, err + } + } + r, w, err := createPipeFn(fs.localTempDir, fs.config.UploadPartSize+1024*1024) + if err != nil { + return nil, nil, nil, err + } + ctx, cancelFn := context.WithCancel(context.Background()) + + var p PipeWriter + if checks&CheckResume != 0 { + p = newPipeWriterAtOffset(w, 0) + } else { + p = NewPipeWriter(w) + } + headers := blob.HTTPHeaders{} + var contentType string + var metadata map[string]*string + if flag == -1 { + contentType = dirMimeType + metadata = map[string]*string{ + azFolderKey: util.NilIfEmpty("true"), + } + } else { + contentType = mime.TypeByExtension(path.Ext(name)) + } + if contentType != "" { + headers.BlobContentType = &contentType + } + + go func() { + defer cancelFn() + + blockBlob := fs.containerClient.NewBlockBlobClient(name) + err := fs.handleMultipartUpload(ctx, r, blockBlob, &headers, metadata) + r.CloseWithError(err) //nolint:errcheck + p.Done(err) + fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %v, err: %+v", name, r.GetReadedBytes(), err) + metric.AZTransferCompleted(r.GetReadedBytes(), 0, err) + }() + + if checks&CheckResume != 0 { + readCh := make(chan error, 1) + + go func() { + n, err := fs.downloadToWriter(name, p) + pw := p.(*pipeWriterAtOffset) + pw.offset = 0 + pw.writeOffset = n + readCh <- err + }() + + err = <-readCh + if err != nil { + cancelFn() + p.Close() + fsLog(fs, logger.LevelDebug, "download before resume failed, writer closed and read cancelled") + return nil, nil, nil, err + } + } + + if uploadMode&16 != 0 { + return nil, p, nil, nil + } + return nil, p, cancelFn, nil +} + +// Rename 
renames (moves) source to target. +func (fs *AzureBlobFs) Rename(source, target string, checks int) (int, int64, error) { + if source == target { + return -1, -1, nil + } + if checks&CheckParentDir != 0 { + _, err := fs.Stat(path.Dir(target)) + if err != nil { + return -1, -1, err + } + } + fi, err := fs.Stat(source) + if err != nil { + return -1, -1, err + } + return fs.renameInternal(source, target, fi, 0, checks&CheckUpdateModTime != 0) +} + +// Remove removes the named file or (empty) directory. +func (fs *AzureBlobFs) Remove(name string, isDir bool) error { + if isDir { + hasContents, err := fs.hasContents(name) + if err != nil { + return err + } + if hasContents { + return fmt.Errorf("cannot remove non empty directory: %q", name) + } + } + + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + blobBlock := fs.containerClient.NewBlockBlobClient(name) + var deletSnapshots blob.DeleteSnapshotsOptionType + if !isDir { + deletSnapshots = blob.DeleteSnapshotsOptionTypeInclude + } + _, err := blobBlock.Delete(ctx, &blob.DeleteOptions{ + DeleteSnapshots: &deletSnapshots, + }) + if err != nil && isDir { + if fs.isBadRequestError(err) { + deletSnapshots = blob.DeleteSnapshotsOptionTypeInclude + _, err = blobBlock.Delete(ctx, &blob.DeleteOptions{ + DeleteSnapshots: &deletSnapshots, + }) + } + } + metric.AZDeleteObjectCompleted(err) + return err +} + +// Mkdir creates a new directory with the specified name and default permissions +func (fs *AzureBlobFs) Mkdir(name string) error { + _, err := fs.Stat(name) + if !fs.IsNotExist(err) { + return err + } + return fs.mkdirInternal(name) +} + +// Symlink creates source as a symbolic link to target. 
+func (*AzureBlobFs) Symlink(_, _ string) error { + return ErrVfsUnsupported +} + +// Readlink returns the destination of the named symbolic link +func (*AzureBlobFs) Readlink(_ string) (string, error) { + return "", ErrVfsUnsupported +} + +// Chown changes the numeric uid and gid of the named file. +func (*AzureBlobFs) Chown(_ string, _ int, _ int) error { + return ErrVfsUnsupported +} + +// Chmod changes the mode of the named file to mode. +func (*AzureBlobFs) Chmod(_ string, _ os.FileMode) error { + return ErrVfsUnsupported +} + +// Chtimes changes the access and modification times of the named file. +func (fs *AzureBlobFs) Chtimes(name string, _, mtime time.Time, isUploading bool) error { + if isUploading { + return nil + } + props, err := fs.headObject(name) + if err != nil { + return err + } + metadata := props.Metadata + if metadata == nil { + metadata = make(map[string]*string) + } + found := false + for k := range metadata { + if strings.EqualFold(k, lastModifiedField) { + metadata[k] = to.Ptr(strconv.FormatInt(mtime.UnixMilli(), 10)) + found = true + break + } + } + if !found { + metadata[lastModifiedField] = to.Ptr(strconv.FormatInt(mtime.UnixMilli(), 10)) + } + + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + _, err = fs.containerClient.NewBlockBlobClient(name).SetMetadata(ctx, metadata, &blob.SetMetadataOptions{}) + return err +} + +// Truncate changes the size of the named file. +// Truncate by path is not supported, while truncating an opened +// file is handled inside base transfer +func (*AzureBlobFs) Truncate(_ string, _ int64) error { + return ErrVfsUnsupported +} + +// ReadDir reads the directory named by dirname and returns +// a list of directory entries. 
+func (fs *AzureBlobFs) ReadDir(dirname string) (DirLister, error) { + // dirname must be already cleaned + prefix := fs.getPrefix(dirname) + pager := fs.containerClient.NewListBlobsHierarchyPager("/", &container.ListBlobsHierarchyOptions{ + Include: container.ListBlobsInclude{ + Metadata: true, + }, + Prefix: &prefix, + MaxResults: &azureBlobDefaultPageSize, + }) + + return &azureBlobDirLister{ + paginator: pager, + timeout: fs.ctxTimeout, + prefix: prefix, + prefixes: make(map[string]bool), + }, nil +} + +// IsUploadResumeSupported returns true if resuming uploads is supported. +// Resuming uploads is not supported on Azure Blob +func (*AzureBlobFs) IsUploadResumeSupported() bool { + return false +} + +// IsConditionalUploadResumeSupported returns if resuming uploads is supported +// for the specified size +func (*AzureBlobFs) IsConditionalUploadResumeSupported(size int64) bool { + return size <= resumeMaxSize +} + +// IsAtomicUploadSupported returns true if atomic upload is supported. +// Azure Blob uploads are already atomic, we don't need to upload to a temporary +// file +func (*AzureBlobFs) IsAtomicUploadSupported() bool { + return false +} + +// IsNotExist returns a boolean indicating whether the error is known to +// report that a file or directory does not exist +func (*AzureBlobFs) IsNotExist(err error) bool { + if err == nil { + return false + } + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + return respErr.StatusCode == http.StatusNotFound + } + // os.ErrNotExist can be returned internally by fs.Stat + return errors.Is(err, os.ErrNotExist) +} + +// IsPermission returns a boolean indicating whether the error is known to +// report that permission is denied. 
+func (*AzureBlobFs) IsPermission(err error) bool { + if err == nil { + return false + } + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + return respErr.StatusCode == http.StatusForbidden || respErr.StatusCode == http.StatusUnauthorized + } + return false +} + +// IsNotSupported returns true if the error indicate an unsupported operation +func (*AzureBlobFs) IsNotSupported(err error) bool { + if err == nil { + return false + } + return errors.Is(err, ErrVfsUnsupported) +} + +func (*AzureBlobFs) isBadRequestError(err error) bool { + if err == nil { + return false + } + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + return respErr.StatusCode == http.StatusBadRequest + } + return false +} + +// CheckRootPath creates the specified local root directory if it does not exists +func (fs *AzureBlobFs) CheckRootPath(username string, uid int, gid int) bool { + // we need a local directory for temporary files + osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil) + return osFs.CheckRootPath(username, uid, gid) +} + +// ScanRootDirContents returns the number of files contained in the bucket, +// and their size +func (fs *AzureBlobFs) ScanRootDirContents() (int, int64, error) { + return fs.GetDirSize(fs.config.KeyPrefix) +} + +// GetDirSize returns the number of files and the size for a folder +// including any subfolders +func (fs *AzureBlobFs) GetDirSize(dirname string) (int, int64, error) { + numFiles := 0 + size := int64(0) + prefix := fs.getPrefix(dirname) + + pager := fs.containerClient.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ + Include: container.ListBlobsInclude{ + Metadata: true, + }, + Prefix: &prefix, + MaxResults: &azureBlobDefaultPageSize, + }) + + for pager.More() { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := pager.NextPage(ctx) + if err != nil { + metric.AZListObjectsCompleted(err) + return numFiles, size, err + } + 
for _, blobItem := range resp.Segment.BlobItems { + if blobItem.Properties != nil { + contentType := util.GetStringFromPointer(blobItem.Properties.ContentType) + isDir := checkDirectoryMarkers(contentType, blobItem.Metadata) + blobSize := util.GetIntFromPointer(blobItem.Properties.ContentLength) + if isDir && blobSize == 0 { + continue + } + numFiles++ + size += blobSize + } + } + fsLog(fs, logger.LevelDebug, "scan in progress for %q, files: %d, size: %d", dirname, numFiles, size) + } + metric.AZListObjectsCompleted(nil) + + return numFiles, size, nil +} + +// GetAtomicUploadPath returns the path to use for an atomic upload. +// Azure Blob Storage uploads are already atomic, we never call this method +func (*AzureBlobFs) GetAtomicUploadPath(_ string) string { + return "" +} + +// GetRelativePath returns the path for a file relative to the user's home dir. +// This is the path as seen by SFTPGo users +func (fs *AzureBlobFs) GetRelativePath(name string) string { + rel := path.Clean(name) + if rel == "." 
{ + rel = "" + } + if !path.IsAbs(rel) { + rel = "/" + rel + } + if fs.config.KeyPrefix != "" { + if !strings.HasPrefix(rel, "/"+fs.config.KeyPrefix) { + rel = "/" + } + rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix)) + } + if fs.mountPath != "" { + rel = path.Join(fs.mountPath, rel) + } + return rel +} + +// Walk walks the file tree rooted at root, calling walkFn for each file or +// directory in the tree, including root +func (fs *AzureBlobFs) Walk(root string, walkFn filepath.WalkFunc) error { + prefix := fs.getPrefix(root) + pager := fs.containerClient.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ + Include: container.ListBlobsInclude{ + Metadata: true, + }, + Prefix: &prefix, + MaxResults: &azureBlobDefaultPageSize, + }) + + for pager.More() { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := pager.NextPage(ctx) + if err != nil { + metric.AZListObjectsCompleted(err) + return err + } + for _, blobItem := range resp.Segment.BlobItems { + name := util.GetStringFromPointer(blobItem.Name) + if fs.isEqual(name, prefix) { + continue + } + blobSize := int64(0) + lastModified := time.Unix(0, 0) + isDir := false + if blobItem.Properties != nil { + contentType := util.GetStringFromPointer(blobItem.Properties.ContentType) + isDir = checkDirectoryMarkers(contentType, blobItem.Metadata) + blobSize = util.GetIntFromPointer(blobItem.Properties.ContentLength) + lastModified = util.GetTimeFromPointer(blobItem.Properties.LastModified) + if val := getAzureLastModified(blobItem.Metadata); val > 0 { + lastModified = util.GetTimeFromMsecSinceEpoch(val) + } + } + err := walkFn(name, NewFileInfo(name, isDir, blobSize, lastModified, false), nil) + if err != nil { + return err + } + } + } + + metric.AZListObjectsCompleted(nil) + return walkFn(root, NewFileInfo(root, true, 0, time.Unix(0, 0), false), nil) +} + +// Join joins any number of path elements into a single path +func 
(*AzureBlobFs) Join(elem ...string) string { + return strings.TrimPrefix(path.Join(elem...), "/") +} + +// HasVirtualFolders returns true if folders are emulated +func (*AzureBlobFs) HasVirtualFolders() bool { + return true +} + +// ResolvePath returns the matching filesystem path for the specified sftp path +func (fs *AzureBlobFs) ResolvePath(virtualPath string) (string, error) { + if fs.mountPath != "" { + if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { + virtualPath = after + } + } + virtualPath = path.Clean("/" + virtualPath) + return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil +} + +// CopyFile implements the FsFileCopier interface +func (fs *AzureBlobFs) CopyFile(source, target string, srcInfo os.FileInfo) (int, int64, error) { + numFiles := 1 + sizeDiff := srcInfo.Size() + attrs, err := fs.headObject(target) + if err == nil { + sizeDiff -= util.GetIntFromPointer(attrs.ContentLength) + numFiles = 0 + } else { + if !fs.IsNotExist(err) { + return 0, 0, err + } + } + if err := fs.copyFileInternal(source, target, srcInfo, true); err != nil { + return 0, 0, err + } + return numFiles, sizeDiff, nil +} + +func (fs *AzureBlobFs) headObject(name string) (blob.GetPropertiesResponse, error) { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := fs.containerClient.NewBlockBlobClient(name).GetProperties(ctx, &blob.GetPropertiesOptions{}) + + metric.AZHeadObjectCompleted(err) + return resp, err +} + +// GetMimeType returns the content type +func (fs *AzureBlobFs) GetMimeType(name string) (string, error) { + response, err := fs.headObject(name) + if err != nil { + return "", err + } + return util.GetStringFromPointer(response.ContentType), nil +} + +// Close closes the fs +func (*AzureBlobFs) Close() error { + return nil +} + +// GetAvailableDiskSize returns the available size for the specified path +func (*AzureBlobFs) GetAvailableDiskSize(_ 
string) (*sftp.StatVFS, error) { + return nil, ErrStorageSizeUnavailable +} + +func (*AzureBlobFs) getPrefix(name string) string { + prefix := "" + if name != "" && name != "." { + prefix = strings.TrimPrefix(name, "/") + if !strings.HasSuffix(prefix, "/") { + prefix += "/" + } + } + return prefix +} + +func (fs *AzureBlobFs) isEqual(key string, virtualName string) bool { + if key == virtualName { + return true + } + if key == virtualName+"/" { + return true + } + if key+"/" == virtualName { + return true + } + return false +} + +func (fs *AzureBlobFs) setConfigDefaults() { + if fs.config.Endpoint == "" { + fs.config.Endpoint = azureDefaultEndpoint + } + if fs.config.UploadPartSize == 0 { + fs.config.UploadPartSize = 5 + } + if fs.config.UploadPartSize < 1024*1024 { + fs.config.UploadPartSize *= 1024 * 1024 + } + if fs.config.UploadConcurrency == 0 { + fs.config.UploadConcurrency = 5 + } + if fs.config.DownloadPartSize == 0 { + fs.config.DownloadPartSize = 5 + } + if fs.config.DownloadPartSize < 1024*1024 { + fs.config.DownloadPartSize *= 1024 * 1024 + } + if fs.config.DownloadConcurrency == 0 { + fs.config.DownloadConcurrency = 5 + } +} + +func (fs *AzureBlobFs) copyFileInternal(source, target string, srcInfo os.FileInfo, updateModTime bool) error { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxLongTimeout)) + defer cancelFn() + + srcBlob := fs.containerClient.NewBlockBlobClient(source) + dstBlob := fs.containerClient.NewBlockBlobClient(target) + resp, err := dstBlob.StartCopyFromURL(ctx, srcBlob.URL(), fs.getCopyOptions(srcInfo, updateModTime)) + if err != nil { + metric.AZCopyObjectCompleted(err) + return err + } + copyStatus := blob.CopyStatusType(util.GetStringFromPointer((*string)(resp.CopyStatus))) + nErrors := 0 + for copyStatus == blob.CopyStatusTypePending { + // Poll until the copy is complete. 
+ time.Sleep(500 * time.Millisecond) + resp, err := dstBlob.GetProperties(ctx, &blob.GetPropertiesOptions{}) + if err != nil { + // A GetProperties failure may be transient, so allow a couple + // of them before giving up. + nErrors++ + if ctx.Err() != nil || nErrors == 3 { + metric.AZCopyObjectCompleted(err) + return err + } + } else { + copyStatus = blob.CopyStatusType(util.GetStringFromPointer((*string)(resp.CopyStatus))) + } + } + if copyStatus != blob.CopyStatusTypeSuccess { + err := fmt.Errorf("copy failed with status: %s", copyStatus) + metric.AZCopyObjectCompleted(err) + return err + } + + metric.AZCopyObjectCompleted(nil) + return nil +} + +func (fs *AzureBlobFs) renameInternal(source, target string, srcInfo os.FileInfo, recursion int, + updateModTime bool, +) (int, int64, error) { + var numFiles int + var filesSize int64 + + if srcInfo.IsDir() { + if renameMode == 0 { + hasContents, err := fs.hasContents(source) + if err != nil { + return numFiles, filesSize, err + } + if hasContents { + return numFiles, filesSize, fmt.Errorf("%w: cannot rename non empty directory: %q", ErrVfsUnsupported, source) + } + } + if err := fs.mkdirInternal(target); err != nil { + return numFiles, filesSize, err + } + if renameMode == 1 { + files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion, updateModTime) + numFiles += files + filesSize += size + if err != nil { + return numFiles, filesSize, err + } + } + } else { + if err := fs.copyFileInternal(source, target, srcInfo, updateModTime); err != nil { + return numFiles, filesSize, err + } + numFiles++ + filesSize += srcInfo.Size() + } + err := fs.skipNotExistErr(fs.Remove(source, srcInfo.IsDir())) + return numFiles, filesSize, err +} + +func (fs *AzureBlobFs) skipNotExistErr(err error) error { + if fs.IsNotExist(err) { + return nil + } + return err +} + +func (fs *AzureBlobFs) mkdirInternal(name string) error { + _, w, _, err := fs.Create(name, -1, 0) + if err != nil { + return err + } + return 
w.Close() +} + +func (fs *AzureBlobFs) hasContents(name string) (bool, error) { + result := false + prefix := fs.getPrefix(name) + + maxResults := int32(1) + pager := fs.containerClient.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{ + MaxResults: &maxResults, + Prefix: &prefix, + }) + + if pager.More() { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := pager.NextPage(ctx) + if err != nil { + metric.AZListObjectsCompleted(err) + return result, err + } + + result = len(resp.Segment.BlobItems) > 0 + } + + metric.AZListObjectsCompleted(nil) + return result, nil +} + +func (fs *AzureBlobFs) downloadPart(ctx context.Context, blockBlob *blockblob.Client, buf []byte, + w io.WriterAt, offset, count, writeOffset int64, +) error { + if count == 0 { + return nil + } + + resp, err := blockBlob.DownloadStream(ctx, &blob.DownloadStreamOptions{ + Range: blob.HTTPRange{ + Offset: offset, + Count: count, + }, + }) + if err != nil { + return err + } + defer resp.Body.Close() + + _, err = io.ReadAtLeast(resp.Body, buf, int(count)) + if err != nil { + return err + } + + return writeAtFull(w, buf, writeOffset, int(count)) +} + +func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *blockblob.Client, + offset int64, writer io.WriterAt, pipeReader PipeReader, +) error { + props, err := blockBlob.GetProperties(ctx, &blob.GetPropertiesOptions{}) + metric.AZHeadObjectCompleted(err) + if err != nil { + fsLog(fs, logger.LevelError, "unable to get blob properties, download aborted: %+v", err) + return err + } + if readMetadata > 0 && pipeReader != nil { + pipeReader.setMetadataFromPointerVal(props.Metadata) + } + contentLength := util.GetIntFromPointer(props.ContentLength) + sizeToDownload := contentLength - offset + if sizeToDownload < 0 { + fsLog(fs, logger.LevelError, "invalid multipart download size or offset, size: %v, offset: %v, size to download: %v", + contentLength, offset, 
sizeToDownload) + return errors.New("the requested offset exceeds the file size") + } + if sizeToDownload == 0 { + fsLog(fs, logger.LevelDebug, "nothing to download, offset %v, content length %v", offset, contentLength) + return nil + } + partSize := fs.config.DownloadPartSize + guard := make(chan struct{}, fs.config.DownloadConcurrency) + blockCtxTimeout := time.Duration(fs.config.DownloadPartSize/(1024*1024)) * time.Minute + pool := newBufferAllocator(int(partSize)) + defer pool.free() + + finished := false + var wg sync.WaitGroup + var errOnce sync.Once + var hasError atomic.Bool + var poolError error + + poolCtx, poolCancel := context.WithCancel(ctx) + defer poolCancel() + + for part := 0; !finished; part++ { + start := offset + end := offset + partSize + if end >= contentLength { + end = contentLength + finished = true + } + writeOffset := int64(part) * partSize + offset = end + + guard <- struct{}{} + if hasError.Load() { + fsLog(fs, logger.LevelDebug, "pool error, download for part %v not started", part) + break + } + + buf := pool.getBuffer() + wg.Add(1) + go func(start, end, writeOffset int64, buf []byte) { + defer func() { + pool.releaseBuffer(buf) + <-guard + wg.Done() + }() + + innerCtx, cancelFn := context.WithDeadline(poolCtx, time.Now().Add(blockCtxTimeout)) + defer cancelFn() + + count := end - start + + err := fs.downloadPart(innerCtx, blockBlob, buf, writer, start, count, writeOffset) + if err != nil { + errOnce.Do(func() { + fsLog(fs, logger.LevelError, "multipart download error: %+v", err) + hasError.Store(true) + poolError = fmt.Errorf("multipart download error: %w", err) + poolCancel() + }) + } + }(start, end, writeOffset, buf) + } + + wg.Wait() + close(guard) + + return poolError +} + +func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Reader, + blockBlob *blockblob.Client, httpHeaders *blob.HTTPHeaders, metadata map[string]*string, +) error { + partSize := fs.config.UploadPartSize + guard := make(chan struct{}, 
fs.config.UploadConcurrency) + blockCtxTimeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute + + // sync.Pool seems to use a lot of memory so prefer our own, very simple, allocator + // we only need to recycle few byte slices + pool := newBufferAllocator(int(partSize)) + defer pool.free() + + finished := false + var blocks []string + var wg sync.WaitGroup + var errOnce sync.Once + var hasError atomic.Bool + var poolError error + + poolCtx, poolCancel := context.WithCancel(ctx) + defer poolCancel() + + finalizeFailedUpload := func(err error) { + fsLog(fs, logger.LevelDebug, "multipart upload error: %+v", err) + hasError.Store(true) + poolError = fmt.Errorf("multipart upload error: %w", err) + poolCancel() + } + + for part := 0; !finished; part++ { + buf := pool.getBuffer() + + n, err := readFill(reader, buf) + if err == io.EOF { + // read finished, if n > 0 we need to process the last data chunck + if n == 0 { + pool.releaseBuffer(buf) + break + } + finished = true + } else if err != nil { + pool.releaseBuffer(buf) + errOnce.Do(func() { + finalizeFailedUpload(err) + }) + break + } + + // Block IDs are unique values to avoid issue if 2+ clients are uploading blocks + // at the same time causing CommitBlockList to get a mix of blocks from all the clients. 
+ generatedUUID, err := uuid.NewRandom() + if err != nil { + pool.releaseBuffer(buf) + errOnce.Do(func() { + finalizeFailedUpload(err) + }) + break + } + blockID := base64.StdEncoding.EncodeToString([]byte(generatedUUID.String())) + blocks = append(blocks, blockID) + + guard <- struct{}{} + if hasError.Load() { + fsLog(fs, logger.LevelError, "pool error, upload for part %d not started", part) + pool.releaseBuffer(buf) + break + } + + wg.Add(1) + go func(blockID string, buf []byte, bufSize int) { + defer func() { + pool.releaseBuffer(buf) + <-guard + wg.Done() + }() + + bufferReader := &bytesReaderWrapper{ + Reader: bytes.NewReader(buf[:bufSize]), + } + innerCtx, cancelFn := context.WithDeadline(poolCtx, time.Now().Add(blockCtxTimeout)) + defer cancelFn() + + _, err := blockBlob.StageBlock(innerCtx, blockID, bufferReader, &blockblob.StageBlockOptions{}) + if err != nil { + errOnce.Do(func() { + fsLog(fs, logger.LevelDebug, "multipart upload error: %+v", err) + finalizeFailedUpload(err) + }) + } + }(blockID, buf, n) + } + + wg.Wait() + close(guard) + + if poolError != nil { + return poolError + } + + commitOptions := blockblob.CommitBlockListOptions{ + HTTPHeaders: httpHeaders, + Metadata: metadata, + } + if fs.config.AccessTier != "" { + commitOptions.Tier = (*blob.AccessTier)(&fs.config.AccessTier) + } + + _, err := blockBlob.CommitBlockList(ctx, blocks, &commitOptions) + return err +} + +func (fs *AzureBlobFs) getCopyOptions(srcInfo os.FileInfo, updateModTime bool) *blob.StartCopyFromURLOptions { + copyOptions := &blob.StartCopyFromURLOptions{} + if fs.config.AccessTier != "" { + copyOptions.Tier = (*blob.AccessTier)(&fs.config.AccessTier) + } + if updateModTime { + metadata := make(map[string]*string) + for k, v := range getMetadata(srcInfo) { + if v != "" { + if strings.EqualFold(k, lastModifiedField) { + metadata[k] = to.Ptr("0") + } else { + metadata[k] = to.Ptr(v) + } + } + } + if len(metadata) > 0 { + copyOptions.Metadata = metadata + } + } + + return 
copyOptions +} + +func (fs *AzureBlobFs) downloadToWriter(name string, w PipeWriter) (int64, error) { + fsLog(fs, logger.LevelDebug, "starting download before resuming upload, path %q", name) + ctx, cancelFn := context.WithTimeout(context.Background(), preResumeTimeout) + defer cancelFn() + + blockBlob := fs.containerClient.NewBlockBlobClient(name) + err := fs.handleMultipartDownload(ctx, blockBlob, 0, w, nil) + n := w.GetWrittenBytes() + fsLog(fs, logger.LevelDebug, "download before resuming upload completed, path %q size: %d, err: %+v", + name, n, err) + metric.AZTransferCompleted(n, 1, err) + return n, err +} + +func checkDirectoryMarkers(contentType string, metadata map[string]*string) bool { + if contentType == dirMimeType { + return true + } + for k, v := range metadata { + if strings.EqualFold(k, azFolderKey) { + return strings.EqualFold(util.GetStringFromPointer(v), "true") + } + } + return false +} + +func getAzContainerClientOptions() *container.ClientOptions { + return &container.ClientOptions{ + ClientOptions: azcore.ClientOptions{ + Telemetry: policy.TelemetryOptions{ + ApplicationID: version.GetVersionHash(), + }, + }, + } +} + +type azureBlobDirLister struct { + baseDirLister + paginator *runtime.Pager[container.ListBlobsHierarchyResponse] + timeout time.Duration + prefix string + prefixes map[string]bool + metricUpdated bool +} + +func (l *azureBlobDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + if len(l.cache) >= limit { + return l.returnFromCache(limit), nil + } + if !l.paginator.More() { + if !l.metricUpdated { + l.metricUpdated = true + metric.AZListObjectsCompleted(nil) + } + return l.returnFromCache(limit), io.EOF + } + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout)) + defer cancelFn() + + page, err := l.paginator.NextPage(ctx) + if err != nil { + metric.AZListObjectsCompleted(err) + return l.cache, err + } + + for _, blobPrefix := range 
page.Segment.BlobPrefixes { + name := util.GetStringFromPointer(blobPrefix.Name) + // we don't support prefixes == "/" this will be sent if a key starts with "/" + if name == "" || name == "/" { + continue + } + // sometime we have duplicate prefixes, maybe an Azurite bug + name = strings.TrimPrefix(name, l.prefix) + if _, ok := l.prefixes[strings.TrimSuffix(name, "/")]; ok { + continue + } + l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) + l.prefixes[strings.TrimSuffix(name, "/")] = true + } + + for _, blobItem := range page.Segment.BlobItems { + name := util.GetStringFromPointer(blobItem.Name) + name = strings.TrimPrefix(name, l.prefix) + size := int64(0) + isDir := false + var metadata map[string]*string + modTime := time.Unix(0, 0) + if blobItem.Properties != nil { + size = util.GetIntFromPointer(blobItem.Properties.ContentLength) + modTime = util.GetTimeFromPointer(blobItem.Properties.LastModified) + contentType := util.GetStringFromPointer(blobItem.Properties.ContentType) + isDir = checkDirectoryMarkers(contentType, blobItem.Metadata) + if isDir { + // check if the dir is already included, it will be sent as blob prefix if it contains at least one item + if _, ok := l.prefixes[name]; ok { + continue + } + l.prefixes[name] = true + } else { + metadata = blobItem.Metadata + } + if val := getAzureLastModified(blobItem.Metadata); val > 0 { + modTime = util.GetTimeFromMsecSinceEpoch(val) + } + } + info := NewFileInfo(name, isDir, size, modTime, false) + info.setMetadataFromPointerVal(metadata) + l.cache = append(l.cache, info) + } + + return l.returnFromCache(limit), nil +} + +func (l *azureBlobDirLister) Close() error { + clear(l.prefixes) + return l.baseDirLister.Close() +} diff --git a/internal/vfs/azblobfs_disabled.go b/internal/vfs/azblobfs_disabled.go new file mode 100644 index 00000000..35b83e2b --- /dev/null +++ b/internal/vfs/azblobfs_disabled.go @@ -0,0 +1,32 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is 
free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build noazblob + +package vfs + +import ( + "errors" + + "github.com/drakkan/sftpgo/v2/internal/version" +) + +func init() { + version.AddFeature("-azblob") +} + +// NewAzBlobFs returns an error, Azure Blob storage is disabled +func NewAzBlobFs(_, _, _ string, _ AzBlobFsConfig) (Fs, error) { + return nil, errors.New("Azure Blob Storage disabled at build time") +} diff --git a/internal/vfs/cryptfs.go b/internal/vfs/cryptfs.go new file mode 100644 index 00000000..8e76da30 --- /dev/null +++ b/internal/vfs/cryptfs.go @@ -0,0 +1,419 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package vfs + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + "net/http" + "os" + + "github.com/minio/sio" + "golang.org/x/crypto/hkdf" + + "github.com/drakkan/sftpgo/v2/internal/logger" +) + +const ( + // cryptFsName is the name for the local Fs implementation with encryption support + cryptFsName = "cryptfs" + version10 byte = 0x10 + nonceV10Size int = 32 + headerV10Size int64 = 33 // 1 (version byte) + 32 (nonce size) +) + +// CryptFs is a Fs implementation that allows to encrypts/decrypts local files +type CryptFs struct { + *OsFs + localTempDir string + masterKey []byte +} + +// NewCryptFs returns a CryptFs object +func NewCryptFs(connectionID, rootDir, mountPath string, config CryptFsConfig) (Fs, error) { + if err := config.validate(); err != nil { + return nil, err + } + if err := config.Passphrase.TryDecrypt(); err != nil { + return nil, err + } + fs := &CryptFs{ + OsFs: &OsFs{ + name: cryptFsName, + connectionID: connectionID, + rootDir: rootDir, + mountPath: getMountPath(mountPath), + readBufferSize: config.ReadBufferSize * 1024 * 1024, + writeBufferSize: config.WriteBufferSize * 1024 * 1024, + }, + masterKey: []byte(config.Passphrase.GetPayload()), + } + if tempPath == "" { + fs.localTempDir = rootDir + } else { + fs.localTempDir = tempPath + } + return fs, nil +} + +// Name returns the name for the Fs implementation +func (fs *CryptFs) Name() string { + return fs.name +} + +// Open opens the named file for reading +func (fs *CryptFs) Open(name string, offset int64) (File, PipeReader, func(), error) { + f, key, err := fs.getFileAndEncryptionKey(name) + if err != nil { + return nil, nil, nil, err + } + isZeroDownload, err := isZeroBytesDownload(f, offset) + if err != nil { + f.Close() + return nil, nil, nil, err + } + r, w, err := createPipeFn(fs.localTempDir, 0) + if err != nil { + f.Close() + return nil, nil, nil, err + } + p := NewPipeReader(r) + + go func() { + if isZeroDownload { + w.CloseWithError(err) 
//nolint:errcheck + f.Close() + fsLog(fs, logger.LevelDebug, "zero bytes download completed, path: %q", name) + return + } + var n int64 + var err error + + if offset == 0 { + n, err = fs.decryptWrapper(w, f, fs.getSIOConfig(key)) + } else { + var readerAt io.ReaderAt + var readed, written int + buf := make([]byte, 65568) + wrapper := &cryptedFileWrapper{ + File: f, + } + readerAt, err = sio.DecryptReaderAt(wrapper, fs.getSIOConfig(key)) + if err == nil { + finished := false + for !finished { + readed, err = readerAt.ReadAt(buf, offset) + offset += int64(readed) + if err != nil && err != io.EOF { + break + } + if err == io.EOF { + finished = true + err = nil + } + if readed > 0 { + written, err = w.Write(buf[:readed]) + n += int64(written) + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + break + } + if readed != written { + err = io.ErrShortWrite + break + } + } + } + } + } + w.CloseWithError(err) //nolint:errcheck + f.Close() + fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %v", name, n, err) + }() + + return nil, p, nil, nil +} + +// Create creates or opens the named file for writing +func (fs *CryptFs) Create(name string, _, _ int) (File, PipeWriter, func(), error) { + f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return nil, nil, nil, err + } + header := encryptedFileHeader{ + version: version10, + nonce: make([]byte, 32), + } + _, err = io.ReadFull(rand.Reader, header.nonce) + if err != nil { + f.Close() + return nil, nil, nil, err + } + var key [32]byte + kdf := hkdf.New(sha256.New, fs.masterKey, header.nonce, nil) + _, err = io.ReadFull(kdf, key[:]) + if err != nil { + f.Close() + return nil, nil, nil, err + } + r, w, err := createPipeFn(fs.localTempDir, 0) + if err != nil { + f.Close() + return nil, nil, nil, err + } + err = header.Store(f) + if err != nil { + r.Close() + w.Close() + f.Close() + return nil, nil, nil, err + } + p := NewPipeWriter(w) + + go func() { 
+ var n int64 + var err error + if fs.writeBufferSize <= 0 { + n, err = sio.Encrypt(f, r, fs.getSIOConfig(key)) + } else { + bw := bufio.NewWriterSize(f, fs.writeBufferSize) + n, err = fs.encryptWrapper(bw, r, fs.getSIOConfig(key)) + errFlush := bw.Flush() + if err == nil && errFlush != nil { + err = errFlush + } + } + errClose := f.Close() + if err == nil && errClose != nil { + err = errClose + } + r.CloseWithError(err) //nolint:errcheck + p.Done(err) + fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %v, err: %v", name, n, err) + }() + + return nil, p, nil, nil +} + +// Truncate changes the size of the named file +func (*CryptFs) Truncate(_ string, _ int64) error { + return ErrVfsUnsupported +} + +// ReadDir reads the directory named by dirname and returns +// a list of directory entries. +func (fs *CryptFs) ReadDir(dirname string) (DirLister, error) { + f, err := os.Open(dirname) + if err != nil { + if isInvalidNameError(err) { + err = os.ErrNotExist + } + return nil, err + } + + return &cryptFsDirLister{f}, nil +} + +// IsUploadResumeSupported returns false sio does not support random access writes +func (*CryptFs) IsUploadResumeSupported() bool { + return false +} + +// IsConditionalUploadResumeSupported returns if resuming uploads is supported +// for the specified size +func (*CryptFs) IsConditionalUploadResumeSupported(_ int64) bool { + return false +} + +// GetMimeType returns the content type +func (fs *CryptFs) GetMimeType(name string) (string, error) { + f, key, err := fs.getFileAndEncryptionKey(name) + if err != nil { + return "", err + } + defer f.Close() + + readSize, err := sio.DecryptedSize(512) + if err != nil { + return "", err + } + buf := make([]byte, readSize) + n, err := io.ReadFull(f, buf) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return "", err + } + + decrypted := bytes.NewBuffer(nil) + _, err = sio.Decrypt(decrypted, bytes.NewBuffer(buf[:n]), fs.getSIOConfig(key)) + if err != nil { + return 
"", err + } + + ctype := http.DetectContentType(decrypted.Bytes()) + // Rewind file. + _, err = f.Seek(0, io.SeekStart) + return ctype, err +} + +func (fs *CryptFs) getSIOConfig(key [32]byte) sio.Config { + return sio.Config{ + MinVersion: sio.Version20, + MaxVersion: sio.Version20, + Key: key[:], + } +} + +// ConvertFileInfo returns a FileInfo with the decrypted size +func (fs *CryptFs) ConvertFileInfo(info os.FileInfo) os.FileInfo { + return convertCryptFsInfo(info) +} + +func (fs *CryptFs) getFileAndEncryptionKey(name string) (*os.File, [32]byte, error) { + var key [32]byte + f, err := os.Open(name) + if err != nil { + return nil, key, err + } + header := encryptedFileHeader{} + err = header.Load(f) + if err != nil { + f.Close() + return nil, key, err + } + kdf := hkdf.New(sha256.New, fs.masterKey, header.nonce, nil) + _, err = io.ReadFull(kdf, key[:]) + if err != nil { + f.Close() + return nil, key, err + } + return f, key, err +} + +func (*CryptFs) encryptWrapper(dst io.Writer, src io.Reader, config sio.Config) (int64, error) { + encReader, err := sio.EncryptReader(src, config) + if err != nil { + return 0, err + } + return doCopy(dst, encReader, make([]byte, 65568)) +} + +func (fs *CryptFs) decryptWrapper(dst io.Writer, src io.Reader, config sio.Config) (int64, error) { + if fs.readBufferSize <= 0 { + return sio.Decrypt(dst, src, config) + } + br := bufio.NewReaderSize(src, fs.readBufferSize) + decReader, err := sio.DecryptReader(br, config) + if err != nil { + return 0, err + } + return doCopy(dst, decReader, make([]byte, 65568)) +} + +func isZeroBytesDownload(f *os.File, offset int64) (bool, error) { + info, err := f.Stat() + if err != nil { + return false, err + } + if info.Size() == headerV10Size { + return true, nil + } + if info.Size() > headerV10Size { + decSize, err := sio.DecryptedSize(uint64(info.Size() - headerV10Size)) + if err != nil { + return false, err + } + if int64(decSize) == offset { + return true, nil + } + } + return false, nil +} + 
+func convertCryptFsInfo(info os.FileInfo) os.FileInfo { + if !info.Mode().IsRegular() { + return info + } + size := info.Size() + if size >= headerV10Size { + size -= headerV10Size + decryptedSize, err := sio.DecryptedSize(uint64(size)) + if err == nil { + size = int64(decryptedSize) + } + } else { + size = 0 + } + return NewFileInfo(info.Name(), info.IsDir(), size, info.ModTime(), false) +} + +type encryptedFileHeader struct { + version byte + nonce []byte +} + +func (h *encryptedFileHeader) Store(f *os.File) error { + buf := make([]byte, 0, headerV10Size) + buf = append(buf, version10) + buf = append(buf, h.nonce...) + _, err := f.Write(buf) + return err +} + +func (h *encryptedFileHeader) Load(f *os.File) error { + header := make([]byte, 1+nonceV10Size) + _, err := io.ReadFull(f, header) + if err != nil { + return err + } + h.version = header[0] + if h.version == version10 { + h.nonce = header[1:] + return nil + } + return fmt.Errorf("unsupported encryption version: %v", h.version) +} + +type cryptedFileWrapper struct { + *os.File +} + +func (w *cryptedFileWrapper) ReadAt(p []byte, offset int64) (n int, err error) { + return w.File.ReadAt(p, offset+headerV10Size) +} + +type cryptFsDirLister struct { + f *os.File +} + +func (l *cryptFsDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + files, err := l.f.Readdir(limit) + for idx := range files { + files[idx] = convertCryptFsInfo(files[idx]) + } + return files, err +} + +func (l *cryptFsDirLister) Close() error { + return l.f.Close() +} diff --git a/internal/vfs/fileinfo.go b/internal/vfs/fileinfo.go new file mode 100644 index 00000000..c27fbf79 --- /dev/null +++ b/internal/vfs/fileinfo.go @@ -0,0 +1,117 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package vfs + +import ( + "os" + "path" + "time" + + "github.com/drakkan/sftpgo/v2/internal/util" +) + +// FileInfo implements os.FileInfo for a Cloud Storage file. +type FileInfo struct { + name string + sizeInBytes int64 + modTime time.Time + mode os.FileMode + metadata map[string]string +} + +// NewFileInfo creates file info. +func NewFileInfo(name string, isDirectory bool, sizeInBytes int64, modTime time.Time, fullName bool) *FileInfo { + mode := os.FileMode(0644) + if isDirectory { + mode = os.FileMode(0755) | os.ModeDir + } + if !fullName { + // we have always Unix style paths here + name = path.Base(name) + } + + return &FileInfo{ + name: name, + sizeInBytes: sizeInBytes, + modTime: modTime, + mode: mode, + } +} + +// Name provides the base name of the file. +func (fi *FileInfo) Name() string { + return fi.name +} + +// Size provides the length in bytes for a file. +func (fi *FileInfo) Size() int64 { + return fi.sizeInBytes +} + +// Mode provides the file mode bits +func (fi *FileInfo) Mode() os.FileMode { + return fi.mode +} + +// ModTime provides the last modification time. 
+func (fi *FileInfo) ModTime() time.Time { + return fi.modTime +} + +// IsDir provides the abbreviation for Mode().IsDir() +func (fi *FileInfo) IsDir() bool { + return fi.mode&os.ModeDir != 0 +} + +// SetMode sets the file mode +func (fi *FileInfo) SetMode(mode os.FileMode) { + fi.mode = mode +} + +// Sys provides the underlying data source (can return nil) +func (fi *FileInfo) Sys() any { + return fi.metadata +} + +func (fi *FileInfo) setMetadata(value map[string]string) { + fi.metadata = value +} + +func (fi *FileInfo) setMetadataFromPointerVal(value map[string]*string) { + if len(value) == 0 { + fi.metadata = nil + return + } + + fi.metadata = map[string]string{} + for k, v := range value { + val := util.GetStringFromPointer(v) + if val != "" { + fi.metadata[k] = val + } + } +} + +func getMetadata(fi os.FileInfo) map[string]string { + if fi.Sys() == nil { + return nil + } + if val, ok := fi.Sys().(map[string]string); ok { + if len(val) > 0 { + return val + } + } + return nil +} diff --git a/internal/vfs/filesystem.go b/internal/vfs/filesystem.go new file mode 100644 index 00000000..59ce47dd --- /dev/null +++ b/internal/vfs/filesystem.go @@ -0,0 +1,408 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package vfs + +import ( + "os" + + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +// Filesystem defines filesystem details +type Filesystem struct { + RedactedSecret string `json:"-"` + Provider sdk.FilesystemProvider `json:"provider"` + OSConfig sdk.OSFsConfig `json:"osconfig,omitempty"` + S3Config S3FsConfig `json:"s3config,omitempty"` + GCSConfig GCSFsConfig `json:"gcsconfig,omitempty"` + AzBlobConfig AzBlobFsConfig `json:"azblobconfig,omitempty"` + CryptConfig CryptFsConfig `json:"cryptconfig,omitempty"` + SFTPConfig SFTPFsConfig `json:"sftpconfig,omitempty"` + HTTPConfig HTTPFsConfig `json:"httpconfig,omitempty"` +} + +// SetEmptySecrets sets the secrets to empty +func (f *Filesystem) SetEmptySecrets() { + f.S3Config.AccessSecret = kms.NewEmptySecret() + f.S3Config.SSECustomerKey = kms.NewEmptySecret() + f.GCSConfig.Credentials = kms.NewEmptySecret() + f.AzBlobConfig.AccountKey = kms.NewEmptySecret() + f.AzBlobConfig.SASURL = kms.NewEmptySecret() + f.CryptConfig.Passphrase = kms.NewEmptySecret() + f.SFTPConfig.Password = kms.NewEmptySecret() + f.SFTPConfig.PrivateKey = kms.NewEmptySecret() + f.SFTPConfig.KeyPassphrase = kms.NewEmptySecret() + f.HTTPConfig.Password = kms.NewEmptySecret() + f.HTTPConfig.APIKey = kms.NewEmptySecret() +} + +// SetEmptySecretsIfNil sets the secrets to empty if nil +func (f *Filesystem) SetEmptySecretsIfNil() { + if f.S3Config.AccessSecret == nil { + f.S3Config.AccessSecret = kms.NewEmptySecret() + } + if f.S3Config.SSECustomerKey == nil { + f.S3Config.SSECustomerKey = kms.NewEmptySecret() + } + if f.GCSConfig.Credentials == nil { + f.GCSConfig.Credentials = kms.NewEmptySecret() + } + if f.AzBlobConfig.AccountKey == nil { + f.AzBlobConfig.AccountKey = kms.NewEmptySecret() + } + if f.AzBlobConfig.SASURL == nil { + f.AzBlobConfig.SASURL = kms.NewEmptySecret() + } + if f.CryptConfig.Passphrase == nil { + f.CryptConfig.Passphrase = kms.NewEmptySecret() + } 
+ if f.SFTPConfig.Password == nil { + f.SFTPConfig.Password = kms.NewEmptySecret() + } + if f.SFTPConfig.PrivateKey == nil { + f.SFTPConfig.PrivateKey = kms.NewEmptySecret() + } + if f.SFTPConfig.KeyPassphrase == nil { + f.SFTPConfig.KeyPassphrase = kms.NewEmptySecret() + } + if f.HTTPConfig.Password == nil { + f.HTTPConfig.Password = kms.NewEmptySecret() + } + if f.HTTPConfig.APIKey == nil { + f.HTTPConfig.APIKey = kms.NewEmptySecret() + } +} + +// SetNilSecretsIfEmpty set the secrets to nil if empty. +// This is useful before rendering as JSON so the empty fields +// will not be serialized. +func (f *Filesystem) SetNilSecretsIfEmpty() { + if f.S3Config.AccessSecret != nil && f.S3Config.AccessSecret.IsEmpty() { + f.S3Config.AccessSecret = nil + } + if f.S3Config.SSECustomerKey != nil && f.S3Config.SSECustomerKey.IsEmpty() { + f.S3Config.SSECustomerKey = nil + } + if f.GCSConfig.Credentials != nil && f.GCSConfig.Credentials.IsEmpty() { + f.GCSConfig.Credentials = nil + } + if f.AzBlobConfig.AccountKey != nil && f.AzBlobConfig.AccountKey.IsEmpty() { + f.AzBlobConfig.AccountKey = nil + } + if f.AzBlobConfig.SASURL != nil && f.AzBlobConfig.SASURL.IsEmpty() { + f.AzBlobConfig.SASURL = nil + } + if f.CryptConfig.Passphrase != nil && f.CryptConfig.Passphrase.IsEmpty() { + f.CryptConfig.Passphrase = nil + } + f.SFTPConfig.setNilSecretsIfEmpty() + f.HTTPConfig.setNilSecretsIfEmpty() +} + +// IsEqual returns true if the fs is equal to other +func (f *Filesystem) IsEqual(other Filesystem) bool { + if f.Provider != other.Provider { + return false + } + switch f.Provider { + case sdk.S3FilesystemProvider: + return f.S3Config.isEqual(other.S3Config) + case sdk.GCSFilesystemProvider: + return f.GCSConfig.isEqual(other.GCSConfig) + case sdk.AzureBlobFilesystemProvider: + return f.AzBlobConfig.isEqual(other.AzBlobConfig) + case sdk.CryptedFilesystemProvider: + return f.CryptConfig.isEqual(other.CryptConfig) + case sdk.SFTPFilesystemProvider: + return 
f.SFTPConfig.isEqual(other.SFTPConfig) + case sdk.HTTPFilesystemProvider: + return f.HTTPConfig.isEqual(other.HTTPConfig) + default: + return true + } +} + +// IsSameResource returns true if fs point to the same resource as other +func (f *Filesystem) IsSameResource(other Filesystem) bool { + if f.Provider != other.Provider { + return false + } + switch f.Provider { + case sdk.S3FilesystemProvider: + return f.S3Config.isSameResource(other.S3Config) + case sdk.GCSFilesystemProvider: + return f.GCSConfig.isSameResource(other.GCSConfig) + case sdk.AzureBlobFilesystemProvider: + return f.AzBlobConfig.isSameResource(other.AzBlobConfig) + case sdk.CryptedFilesystemProvider: + return f.CryptConfig.isSameResource(other.CryptConfig) + case sdk.SFTPFilesystemProvider: + return f.SFTPConfig.isSameResource(other.SFTPConfig) + case sdk.HTTPFilesystemProvider: + return f.HTTPConfig.isSameResource(other.HTTPConfig) + default: + return true + } +} + +// GetPathSeparator returns the path separator +func (f *Filesystem) GetPathSeparator() string { + switch f.Provider { + case sdk.LocalFilesystemProvider, sdk.CryptedFilesystemProvider: + return string(os.PathSeparator) + default: + return "/" + } +} + +// Validate verifies the FsConfig matching the configured provider and sets all other +// Filesystem.*Config to their zero value if successful +func (f *Filesystem) Validate(additionalData string) error { + switch f.Provider { + case sdk.S3FilesystemProvider: + if err := f.S3Config.ValidateAndEncryptCredentials(additionalData); err != nil { + return err + } + f.OSConfig = sdk.OSFsConfig{} + f.GCSConfig = GCSFsConfig{} + f.AzBlobConfig = AzBlobFsConfig{} + f.CryptConfig = CryptFsConfig{} + f.SFTPConfig = SFTPFsConfig{} + f.HTTPConfig = HTTPFsConfig{} + return nil + case sdk.GCSFilesystemProvider: + if err := f.GCSConfig.ValidateAndEncryptCredentials(additionalData); err != nil { + return err + } + f.OSConfig = sdk.OSFsConfig{} + f.S3Config = S3FsConfig{} + f.AzBlobConfig = 
AzBlobFsConfig{} + f.CryptConfig = CryptFsConfig{} + f.SFTPConfig = SFTPFsConfig{} + f.HTTPConfig = HTTPFsConfig{} + return nil + case sdk.AzureBlobFilesystemProvider: + if err := f.AzBlobConfig.ValidateAndEncryptCredentials(additionalData); err != nil { + return err + } + f.OSConfig = sdk.OSFsConfig{} + f.S3Config = S3FsConfig{} + f.GCSConfig = GCSFsConfig{} + f.CryptConfig = CryptFsConfig{} + f.SFTPConfig = SFTPFsConfig{} + f.HTTPConfig = HTTPFsConfig{} + return nil + case sdk.CryptedFilesystemProvider: + if err := f.CryptConfig.ValidateAndEncryptCredentials(additionalData); err != nil { + return err + } + f.OSConfig = sdk.OSFsConfig{} + f.S3Config = S3FsConfig{} + f.GCSConfig = GCSFsConfig{} + f.AzBlobConfig = AzBlobFsConfig{} + f.SFTPConfig = SFTPFsConfig{} + f.HTTPConfig = HTTPFsConfig{} + return validateOSFsConfig(&f.CryptConfig.OSFsConfig) + case sdk.SFTPFilesystemProvider: + if err := f.SFTPConfig.ValidateAndEncryptCredentials(additionalData); err != nil { + return err + } + f.OSConfig = sdk.OSFsConfig{} + f.S3Config = S3FsConfig{} + f.GCSConfig = GCSFsConfig{} + f.AzBlobConfig = AzBlobFsConfig{} + f.CryptConfig = CryptFsConfig{} + f.HTTPConfig = HTTPFsConfig{} + return nil + case sdk.HTTPFilesystemProvider: + if err := f.HTTPConfig.ValidateAndEncryptCredentials(additionalData); err != nil { + return err + } + f.OSConfig = sdk.OSFsConfig{} + f.S3Config = S3FsConfig{} + f.GCSConfig = GCSFsConfig{} + f.AzBlobConfig = AzBlobFsConfig{} + f.CryptConfig = CryptFsConfig{} + f.SFTPConfig = SFTPFsConfig{} + return nil + case sdk.LocalFilesystemProvider: + f.S3Config = S3FsConfig{} + f.GCSConfig = GCSFsConfig{} + f.AzBlobConfig = AzBlobFsConfig{} + f.CryptConfig = CryptFsConfig{} + f.SFTPConfig = SFTPFsConfig{} + f.HTTPConfig = HTTPFsConfig{} + return validateOSFsConfig(&f.OSConfig) + default: + return util.NewI18nError( + util.NewValidationError("invalid filesystem provider"), + util.I18nErrorFsValidation, + ) + } +} + +// HasRedactedSecret returns true if 
the filesystem configuration has a redacted secret +func (f *Filesystem) HasRedactedSecret() bool { + // TODO move vfs specific code into each *FsConfig struct + switch f.Provider { + case sdk.S3FilesystemProvider: + if f.S3Config.SSECustomerKey.IsRedacted() { + return true + } + return f.S3Config.AccessSecret.IsRedacted() + case sdk.GCSFilesystemProvider: + return f.GCSConfig.Credentials.IsRedacted() + case sdk.AzureBlobFilesystemProvider: + if f.AzBlobConfig.AccountKey.IsRedacted() { + return true + } + return f.AzBlobConfig.SASURL.IsRedacted() + case sdk.CryptedFilesystemProvider: + return f.CryptConfig.Passphrase.IsRedacted() + case sdk.SFTPFilesystemProvider: + if f.SFTPConfig.Password.IsRedacted() { + return true + } + if f.SFTPConfig.PrivateKey.IsRedacted() { + return true + } + return f.SFTPConfig.KeyPassphrase.IsRedacted() + case sdk.HTTPFilesystemProvider: + if f.HTTPConfig.Password.IsRedacted() { + return true + } + return f.HTTPConfig.APIKey.IsRedacted() + } + + return false +} + +// HideConfidentialData hides filesystem confidential data +func (f *Filesystem) HideConfidentialData() { + switch f.Provider { + case sdk.S3FilesystemProvider: + f.S3Config.HideConfidentialData() + case sdk.GCSFilesystemProvider: + f.GCSConfig.HideConfidentialData() + case sdk.AzureBlobFilesystemProvider: + f.AzBlobConfig.HideConfidentialData() + case sdk.CryptedFilesystemProvider: + f.CryptConfig.HideConfidentialData() + case sdk.SFTPFilesystemProvider: + f.SFTPConfig.HideConfidentialData() + case sdk.HTTPFilesystemProvider: + f.HTTPConfig.HideConfidentialData() + } +} + +// GetACopy returns a filesystem copy +func (f *Filesystem) GetACopy() Filesystem { + f.SetEmptySecretsIfNil() + fs := Filesystem{ + Provider: f.Provider, + OSConfig: sdk.OSFsConfig{ + ReadBufferSize: f.OSConfig.ReadBufferSize, + WriteBufferSize: f.OSConfig.WriteBufferSize, + }, + S3Config: S3FsConfig{ + BaseS3FsConfig: sdk.BaseS3FsConfig{ + Bucket: f.S3Config.Bucket, + Region: f.S3Config.Region, 
+ AccessKey: f.S3Config.AccessKey, + RoleARN: f.S3Config.RoleARN, + Endpoint: f.S3Config.Endpoint, + StorageClass: f.S3Config.StorageClass, + ACL: f.S3Config.ACL, + KeyPrefix: f.S3Config.KeyPrefix, + UploadPartSize: f.S3Config.UploadPartSize, + UploadConcurrency: f.S3Config.UploadConcurrency, + DownloadPartSize: f.S3Config.DownloadPartSize, + DownloadConcurrency: f.S3Config.DownloadConcurrency, + DownloadPartMaxTime: f.S3Config.DownloadPartMaxTime, + UploadPartMaxTime: f.S3Config.UploadPartMaxTime, + ForcePathStyle: f.S3Config.ForcePathStyle, + SkipTLSVerify: f.S3Config.SkipTLSVerify, + }, + AccessSecret: f.S3Config.AccessSecret.Clone(), + SSECustomerKey: f.S3Config.SSECustomerKey.Clone(), + }, + GCSConfig: GCSFsConfig{ + BaseGCSFsConfig: sdk.BaseGCSFsConfig{ + Bucket: f.GCSConfig.Bucket, + AutomaticCredentials: f.GCSConfig.AutomaticCredentials, + StorageClass: f.GCSConfig.StorageClass, + ACL: f.GCSConfig.ACL, + KeyPrefix: f.GCSConfig.KeyPrefix, + UploadPartSize: f.GCSConfig.UploadPartSize, + UploadPartMaxTime: f.GCSConfig.UploadPartMaxTime, + }, + Credentials: f.GCSConfig.Credentials.Clone(), + }, + AzBlobConfig: AzBlobFsConfig{ + BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{ + Container: f.AzBlobConfig.Container, + AccountName: f.AzBlobConfig.AccountName, + Endpoint: f.AzBlobConfig.Endpoint, + KeyPrefix: f.AzBlobConfig.KeyPrefix, + UploadPartSize: f.AzBlobConfig.UploadPartSize, + UploadConcurrency: f.AzBlobConfig.UploadConcurrency, + DownloadPartSize: f.AzBlobConfig.DownloadPartSize, + DownloadConcurrency: f.AzBlobConfig.DownloadConcurrency, + UseEmulator: f.AzBlobConfig.UseEmulator, + AccessTier: f.AzBlobConfig.AccessTier, + }, + AccountKey: f.AzBlobConfig.AccountKey.Clone(), + SASURL: f.AzBlobConfig.SASURL.Clone(), + }, + CryptConfig: CryptFsConfig{ + OSFsConfig: sdk.OSFsConfig{ + ReadBufferSize: f.CryptConfig.ReadBufferSize, + WriteBufferSize: f.CryptConfig.WriteBufferSize, + }, + Passphrase: f.CryptConfig.Passphrase.Clone(), + }, + SFTPConfig: SFTPFsConfig{ + 
BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: f.SFTPConfig.Endpoint, + Username: f.SFTPConfig.Username, + Prefix: f.SFTPConfig.Prefix, + DisableCouncurrentReads: f.SFTPConfig.DisableCouncurrentReads, + BufferSize: f.SFTPConfig.BufferSize, + EqualityCheckMode: f.SFTPConfig.EqualityCheckMode, + }, + Password: f.SFTPConfig.Password.Clone(), + PrivateKey: f.SFTPConfig.PrivateKey.Clone(), + KeyPassphrase: f.SFTPConfig.KeyPassphrase.Clone(), + }, + HTTPConfig: HTTPFsConfig{ + BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{ + Endpoint: f.HTTPConfig.Endpoint, + Username: f.HTTPConfig.Username, + SkipTLSVerify: f.HTTPConfig.SkipTLSVerify, + EqualityCheckMode: f.HTTPConfig.EqualityCheckMode, + }, + Password: f.HTTPConfig.Password.Clone(), + APIKey: f.HTTPConfig.APIKey.Clone(), + }, + } + if len(f.SFTPConfig.Fingerprints) > 0 { + fs.SFTPConfig.Fingerprints = make([]string, len(f.SFTPConfig.Fingerprints)) + copy(fs.SFTPConfig.Fingerprints, f.SFTPConfig.Fingerprints) + } + return fs +} diff --git a/internal/vfs/folder.go b/internal/vfs/folder.go new file mode 100644 index 00000000..5fcee524 --- /dev/null +++ b/internal/vfs/folder.go @@ -0,0 +1,202 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package vfs + +import ( + "errors" + "fmt" + "strings" + + "github.com/rs/xid" + "github.com/sftpgo/sdk" +) + +// BaseVirtualFolder defines the path for the virtual folder and the used quota limits. 
+// The same folder can be shared among multiple users and each user can have different +// quota limits or a different virtual path. +type BaseVirtualFolder struct { + ID int64 `json:"id"` + Name string `json:"name"` + MappedPath string `json:"mapped_path,omitempty"` + Description string `json:"description,omitempty"` + UsedQuotaSize int64 `json:"used_quota_size"` + // Used quota as number of files + UsedQuotaFiles int `json:"used_quota_files"` + // Last quota update as unix timestamp in milliseconds + LastQuotaUpdate int64 `json:"last_quota_update"` + // list of usernames associated with this virtual folder + Users []string `json:"users,omitempty"` + // list of group names associated with this virtual folder + Groups []string `json:"groups,omitempty"` + // Filesystem configuration details + FsConfig Filesystem `json:"filesystem"` +} + +// GetEncryptionAdditionalData returns the additional data to use for AEAD +func (v *BaseVirtualFolder) GetEncryptionAdditionalData() string { + return fmt.Sprintf("folder_%v", v.Name) +} + +// GetACopy returns a copy; the Users and Groups slices are duplicated so the copy shares no memory with v +func (v *BaseVirtualFolder) GetACopy() BaseVirtualFolder { + users := make([]string, len(v.Users)) + copy(users, v.Users) + groups := make([]string, len(v.Groups)) + copy(groups, v.Groups) + return BaseVirtualFolder{ + ID: v.ID, + Name: v.Name, + Description: v.Description, + MappedPath: v.MappedPath, + UsedQuotaSize: v.UsedQuotaSize, + UsedQuotaFiles: v.UsedQuotaFiles, + LastQuotaUpdate: v.LastQuotaUpdate, + Users: users, + Groups: groups, + FsConfig: v.FsConfig.GetACopy(), + } +} + +// IsLocalOrLocalCrypted returns true if the folder provider is local or local encrypted +func (v *BaseVirtualFolder) IsLocalOrLocalCrypted() bool { + return v.FsConfig.Provider == sdk.LocalFilesystemProvider || v.FsConfig.Provider == sdk.CryptedFilesystemProvider +} + +// hideConfidentialData hides folder confidential data +func (v *BaseVirtualFolder) hideConfidentialData() { + switch v.FsConfig.Provider { + case 
sdk.S3FilesystemProvider: + v.FsConfig.S3Config.HideConfidentialData() + case sdk.GCSFilesystemProvider: + v.FsConfig.GCSConfig.HideConfidentialData() + case sdk.AzureBlobFilesystemProvider: + v.FsConfig.AzBlobConfig.HideConfidentialData() + case sdk.CryptedFilesystemProvider: + v.FsConfig.CryptConfig.HideConfidentialData() + case sdk.SFTPFilesystemProvider: + v.FsConfig.SFTPConfig.HideConfidentialData() + case sdk.HTTPFilesystemProvider: + v.FsConfig.HTTPConfig.HideConfidentialData() + } +} + +// PrepareForRendering prepares a folder for rendering. +// It hides confidential data and set to nil the empty secrets +// so they are not serialized +func (v *BaseVirtualFolder) PrepareForRendering() { + v.hideConfidentialData() + v.FsConfig.SetEmptySecretsIfNil() +} + +// HasRedactedSecret returns true if the folder has a redacted secret +func (v *BaseVirtualFolder) HasRedactedSecret() bool { + return v.FsConfig.HasRedactedSecret() +} + +// hasPathPlaceholder returns true if the folder has a path placeholder +func (v *BaseVirtualFolder) hasPathPlaceholder() bool { + placeholders := []string{"%username%", "%role%"} + var config string + switch v.FsConfig.Provider { + case sdk.S3FilesystemProvider: + config = v.FsConfig.S3Config.KeyPrefix + case sdk.GCSFilesystemProvider: + config = v.FsConfig.GCSConfig.KeyPrefix + case sdk.AzureBlobFilesystemProvider: + config = v.FsConfig.AzBlobConfig.KeyPrefix + case sdk.SFTPFilesystemProvider: + config = v.FsConfig.SFTPConfig.Prefix + case sdk.LocalFilesystemProvider, sdk.CryptedFilesystemProvider: + config = v.MappedPath + } + for _, placeholder := range placeholders { + if strings.Contains(config, placeholder) { + return true + } + } + return false +} + +// VirtualFolder defines a mapping between an SFTPGo virtual path and a +// filesystem path outside the user home directory. +// The specified paths must be absolute and the virtual path cannot be "/", +// it must be a sub directory. 
The parent directory for the specified virtual +// path must exist. SFTPGo will try to automatically create any missing +// parent directory for the configured virtual folders at user login. +type VirtualFolder struct { + BaseVirtualFolder + VirtualPath string `json:"virtual_path"` + // Maximum size allowed as bytes. 0 means unlimited, -1 included in user quota + QuotaSize int64 `json:"quota_size"` + // Maximum number of files allowed. 0 means unlimited, -1 included in user quota + QuotaFiles int `json:"quota_files"` +} + +// GetFilesystem returns the filesystem for this folder +func (v *VirtualFolder) GetFilesystem(connectionID string, forbiddenSelfUsers []string) (Fs, error) { + switch v.FsConfig.Provider { + case sdk.S3FilesystemProvider: + return NewS3Fs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.S3Config) + case sdk.GCSFilesystemProvider: + return NewGCSFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.GCSConfig) + case sdk.AzureBlobFilesystemProvider: + return NewAzBlobFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.AzBlobConfig) + case sdk.CryptedFilesystemProvider: + return NewCryptFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.CryptConfig) + case sdk.SFTPFilesystemProvider: + return NewSFTPFs(connectionID, v.VirtualPath, v.MappedPath, forbiddenSelfUsers, v.FsConfig.SFTPConfig) + case sdk.HTTPFilesystemProvider: + return NewHTTPFs(connectionID, v.MappedPath, v.VirtualPath, v.FsConfig.HTTPConfig) + default: + return NewOsFs(connectionID, v.MappedPath, v.VirtualPath, &v.FsConfig.OSConfig), nil + } +} + +// ScanQuota scans the folder and returns the number of files and their size +func (v *VirtualFolder) ScanQuota() (int, int64, error) { + if v.hasPathPlaceholder() { + return 0, 0, errors.New("cannot scan quota: this folder has a path placeholder") + } + fs, err := v.GetFilesystem(xid.New().String(), nil) + if err != nil { + return 0, 0, err + } + defer fs.Close() + + return fs.ScanRootDirContents() +} + +// 
IsIncludedInUserQuota returns true if the virtual folder is included in user quota +func (v *VirtualFolder) IsIncludedInUserQuota() bool { + return v.QuotaFiles == -1 && v.QuotaSize == -1 +} + +// HasNoQuotaRestrictions returns true if no quota restrictions need to be applyed +func (v *VirtualFolder) HasNoQuotaRestrictions(checkFiles bool) bool { + if v.QuotaSize == 0 && (!checkFiles || v.QuotaFiles == 0) { + return true + } + return false +} + +// GetACopy returns a copy +func (v *VirtualFolder) GetACopy() VirtualFolder { + return VirtualFolder{ + BaseVirtualFolder: v.BaseVirtualFolder.GetACopy(), + VirtualPath: v.VirtualPath, + QuotaSize: v.QuotaSize, + QuotaFiles: v.QuotaFiles, + } +} diff --git a/internal/vfs/gcsfs.go b/internal/vfs/gcsfs.go new file mode 100644 index 00000000..b3cfa6c2 --- /dev/null +++ b/internal/vfs/gcsfs.go @@ -0,0 +1,1046 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +//go:build !nogcs + +package vfs + +import ( + "context" + "errors" + "fmt" + "io" + "mime" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "cloud.google.com/go/storage" + "github.com/pkg/sftp" + "github.com/rs/xid" + "google.golang.org/api/googleapi" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/metric" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" +) + +const ( + defaultGCSPageSize = 5000 +) + +var ( + gcsDefaultFieldsSelection = []string{"Name", "Size", "Deleted", "Updated", "ContentType", "Metadata"} +) + +// GCSFs is a Fs implementation for Google Cloud Storage. +type GCSFs struct { + connectionID string + localTempDir string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string + config *GCSFsConfig + svc *storage.Client + ctxTimeout time.Duration + ctxLongTimeout time.Duration +} + +func init() { + version.AddFeature("+gcs") +} + +// NewGCSFs returns an GCSFs object that allows to interact with Google Cloud Storage +func NewGCSFs(connectionID, localTempDir, mountPath string, config GCSFsConfig) (Fs, error) { + if localTempDir == "" { + localTempDir = getLocalTempDir() + } + + var err error + fs := &GCSFs{ + connectionID: connectionID, + localTempDir: localTempDir, + mountPath: getMountPath(mountPath), + config: &config, + ctxTimeout: 30 * time.Second, + ctxLongTimeout: 300 * time.Second, + } + if err = fs.config.validate(); err != nil { + return fs, err + } + ctx := context.Background() + if fs.config.AutomaticCredentials > 0 { + fs.svc, err = storage.NewClient(ctx, + storage.WithJSONReads(), + option.WithUserAgent(version.GetVersionHash()), + ) + } else { + err = fs.config.Credentials.TryDecrypt() + if err != nil { + return fs, err + } + fs.svc, err = storage.NewClient(ctx, + storage.WithJSONReads(), + 
option.WithUserAgent(version.GetVersionHash()), + option.WithAuthCredentialsJSON(option.ServiceAccount, []byte(fs.config.Credentials.GetPayload())), + ) + } + return fs, err +} + +// Name returns the name for the Fs implementation +func (fs *GCSFs) Name() string { + return fmt.Sprintf("%s bucket %q", gcsfsName, fs.config.Bucket) +} + +// ConnectionID returns the connection ID associated to this Fs implementation +func (fs *GCSFs) ConnectionID() string { + return fs.connectionID +} + +// Stat returns a FileInfo describing the named file +func (fs *GCSFs) Stat(name string) (os.FileInfo, error) { + if name == "" || name == "/" || name == "." { + return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } + if fs.config.KeyPrefix == name+"/" { + return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } + return fs.getObjectStat(name) +} + +// Lstat returns a FileInfo describing the named file +func (fs *GCSFs) Lstat(name string) (os.FileInfo, error) { + return fs.Stat(name) +} + +// Open opens the named file for reading +func (fs *GCSFs) Open(name string, offset int64) (File, PipeReader, func(), error) { + r, w, err := createPipeFn(fs.localTempDir, 0) + if err != nil { + return nil, nil, nil, err + } + p := NewPipeReader(r) + if readMetadata > 0 { + attrs, err := fs.headObject(name) + if err != nil { + r.Close() + w.Close() + return nil, nil, nil, err + } + p.setMetadata(attrs.Metadata) + } + bkt := fs.svc.Bucket(fs.config.Bucket) + obj := bkt.Object(name) + ctx, cancelFn := context.WithCancel(context.Background()) + objectReader, err := obj.NewRangeReader(ctx, offset, -1) + if err == nil && offset > 0 && objectReader.Attrs.ContentEncoding == "gzip" { + err = fmt.Errorf("range request is not possible for gzip content encoding, requested offset %d", offset) + objectReader.Close() + } + if err != nil { + r.Close() + w.Close() + cancelFn() + return nil, nil, nil, err + } + go func() { + defer cancelFn() + defer objectReader.Close() + + n, err := io.Copy(w, 
objectReader) + w.CloseWithError(err) //nolint:errcheck + fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %+v", name, n, err) + metric.GCSTransferCompleted(n, 1, err) + }() + return nil, p, cancelFn, nil +} + +// Create creates or opens the named file for writing +func (fs *GCSFs) Create(name string, flag, checks int) (File, PipeWriter, func(), error) { + if checks&CheckParentDir != 0 { + _, err := fs.Stat(path.Dir(name)) + if err != nil { + return nil, nil, nil, err + } + } + chunkSize := googleapi.DefaultUploadChunkSize + if fs.config.UploadPartSize > 0 { + chunkSize = int(fs.config.UploadPartSize) * 1024 * 1024 + } + r, w, err := createPipeFn(fs.localTempDir, int64(chunkSize+1024*1024)) + if err != nil { + return nil, nil, nil, err + } + var partialFileName string + var attrs *storage.ObjectAttrs + var statErr error + + bkt := fs.svc.Bucket(fs.config.Bucket) + obj := bkt.Object(name) + + if flag == -1 { + obj = obj.If(storage.Conditions{DoesNotExist: true}) + } else { + attrs, statErr = fs.headObject(name) + if statErr == nil { + obj = obj.If(storage.Conditions{GenerationMatch: attrs.Generation}) + } else if fs.IsNotExist(statErr) { + obj = obj.If(storage.Conditions{DoesNotExist: true}) + } else { + fsLog(fs, logger.LevelWarn, "unable to set precondition for %q, stat err: %v", name, statErr) + } + } + ctx, cancelFn := context.WithCancel(context.Background()) + + var p PipeWriter + var objectWriter *storage.Writer + if checks&CheckResume != 0 { + if statErr != nil { + cancelFn() + r.Close() + w.Close() + return nil, nil, nil, fmt.Errorf("unable to resume %q stat error: %w", name, statErr) + } + p = newPipeWriterAtOffset(w, attrs.Size) + partialFileName = fs.getTempObject(name) + partialObj := bkt.Object(partialFileName) + partialObj = partialObj.If(storage.Conditions{DoesNotExist: true}) + objectWriter = partialObj.NewWriter(ctx) + } else { + p = NewPipeWriter(w) + objectWriter = obj.NewWriter(ctx) + } + + objectWriter.ChunkSize = 
chunkSize + if fs.config.UploadPartMaxTime > 0 { + objectWriter.ChunkRetryDeadline = time.Duration(fs.config.UploadPartMaxTime) * time.Second + } + fs.setWriterAttrs(objectWriter, flag, name) + + go func() { + defer cancelFn() + + n, err := io.Copy(objectWriter, r) + closeErr := objectWriter.Close() + if err == nil { + err = closeErr + } + if err == nil && partialFileName != "" { + partialObject := bkt.Object(partialFileName) + partialObject = partialObject.If(storage.Conditions{GenerationMatch: objectWriter.Attrs().Generation}) + err = fs.composeObjects(ctx, obj, partialObject) + } + r.CloseWithError(err) //nolint:errcheck + p.Done(err) + fsLog(fs, logger.LevelDebug, "upload completed, path: %q, acl: %q, readed bytes: %v, err: %+v", + name, fs.config.ACL, n, err) + metric.GCSTransferCompleted(n, 0, err) + }() + + if uploadMode&8 != 0 { + return nil, p, nil, nil + } + return nil, p, cancelFn, nil +} + +// Rename renames (moves) source to target. +func (fs *GCSFs) Rename(source, target string, checks int) (int, int64, error) { + if source == target { + return -1, -1, nil + } + if checks&CheckParentDir != 0 { + _, err := fs.Stat(path.Dir(target)) + if err != nil { + return -1, -1, err + } + } + fi, err := fs.getObjectStat(source) + if err != nil { + return -1, -1, err + } + return fs.renameInternal(source, target, fi, 0, checks&CheckUpdateModTime != 0) +} + +// Remove removes the named file or (empty) directory. 
+func (fs *GCSFs) Remove(name string, isDir bool) error { + if isDir { + hasContents, err := fs.hasContents(name) + if err != nil { + return err + } + if hasContents { + return fmt.Errorf("cannot remove non empty directory: %q", name) + } + if !strings.HasSuffix(name, "/") { + name += "/" + } + } + obj := fs.svc.Bucket(fs.config.Bucket).Object(name) + attrs, statErr := fs.headObject(name) + if statErr == nil { + obj = obj.If(storage.Conditions{GenerationMatch: attrs.Generation}) + } else { + fsLog(fs, logger.LevelWarn, "unable to set precondition for deleting %q, stat err: %v", + name, statErr) + } + + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + err := obj.Delete(ctx) + if isDir && fs.IsNotExist(err) { + // we can have directories without a trailing "/" (created using v2.1.0 and before) + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + err = fs.svc.Bucket(fs.config.Bucket).Object(strings.TrimSuffix(name, "/")).Delete(ctx) + } + metric.GCSDeleteObjectCompleted(err) + return err +} + +// Mkdir creates a new directory with the specified name and default permissions +func (fs *GCSFs) Mkdir(name string) error { + _, err := fs.Stat(name) + if !fs.IsNotExist(err) { + return err + } + return fs.mkdirInternal(name) +} + +// Symlink creates source as a symbolic link to target. +func (*GCSFs) Symlink(_, _ string) error { + return ErrVfsUnsupported +} + +// Readlink returns the destination of the named symbolic link +func (*GCSFs) Readlink(_ string) (string, error) { + return "", ErrVfsUnsupported +} + +// Chown changes the numeric uid and gid of the named file. +func (*GCSFs) Chown(_ string, _ int, _ int) error { + return ErrVfsUnsupported +} + +// Chmod changes the mode of the named file to mode. 
+func (*GCSFs) Chmod(_ string, _ os.FileMode) error { + return ErrVfsUnsupported +} + +// Chtimes changes the access and modification times of the named file. +func (fs *GCSFs) Chtimes(name string, _, mtime time.Time, isUploading bool) error { + if isUploading { + return nil + } + obj := fs.svc.Bucket(fs.config.Bucket).Object(name) + attrs, err := fs.headObject(name) + if err != nil { + return err + } + obj = obj.If(storage.Conditions{MetagenerationMatch: attrs.Metageneration}) + + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + metadata := attrs.Metadata + if metadata == nil { + metadata = make(map[string]string) + } + metadata[lastModifiedField] = strconv.FormatInt(mtime.UnixMilli(), 10) + + objectAttrsToUpdate := storage.ObjectAttrsToUpdate{ + Metadata: metadata, + } + _, err = obj.Update(ctx, objectAttrsToUpdate) + + return err +} + +// Truncate changes the size of the named file. +// Truncate by path is not supported, while truncating an opened +// file is handled inside base transfer +func (*GCSFs) Truncate(_ string, _ int64) error { + return ErrVfsUnsupported +} + +// ReadDir reads the directory named by dirname and returns +// a list of directory entries. +func (fs *GCSFs) ReadDir(dirname string) (DirLister, error) { + // dirname must be already cleaned + prefix := fs.getPrefix(dirname) + query := &storage.Query{Prefix: prefix, Delimiter: "/"} + err := query.SetAttrSelection(gcsDefaultFieldsSelection) + if err != nil { + return nil, err + } + bkt := fs.svc.Bucket(fs.config.Bucket) + + return &gcsDirLister{ + bucket: bkt, + query: query, + timeout: fs.ctxTimeout, + prefix: prefix, + prefixes: make(map[string]bool), + }, nil +} + +// IsUploadResumeSupported returns true if resuming uploads is supported. 
+// Resuming uploads is not supported on GCS +func (*GCSFs) IsUploadResumeSupported() bool { + return false +} + +// IsConditionalUploadResumeSupported returns true if resuming uploads is supported +// for the specified size +func (*GCSFs) IsConditionalUploadResumeSupported(_ int64) bool { + return true +} + +// IsAtomicUploadSupported returns true if atomic upload is supported. +// GCS uploads are already atomic, we don't need to upload to a temporary +// file +func (*GCSFs) IsAtomicUploadSupported() bool { + return false +} + +// IsNotExist returns a boolean indicating whether the error is known to +// report that a file or directory does not exist +func (*GCSFs) IsNotExist(err error) bool { + if err == nil { + return false + } + if errors.Is(err, storage.ErrObjectNotExist) { + return true + } + var apiErr *googleapi.Error + if errors.As(err, &apiErr) { + if apiErr.Code == http.StatusNotFound { + return true + } + } + return false +} + +// IsPermission returns a boolean indicating whether the error is known to +// report that permission is denied. 
+func (*GCSFs) IsPermission(err error) bool { + if err == nil { + return false + } + var apiErr *googleapi.Error + if errors.As(err, &apiErr) { + if apiErr.Code == http.StatusForbidden || apiErr.Code == http.StatusUnauthorized { + return true + } + } + return false +} + +// IsNotSupported returns true if the error indicate an unsupported operation +func (*GCSFs) IsNotSupported(err error) bool { + if err == nil { + return false + } + return errors.Is(err, ErrVfsUnsupported) +} + +// CheckRootPath creates the specified local root directory if it does not exists +func (fs *GCSFs) CheckRootPath(username string, uid int, gid int) bool { + // we need a local directory for temporary files + osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil) + return osFs.CheckRootPath(username, uid, gid) +} + +// ScanRootDirContents returns the number of files contained in the bucket, +// and their size +func (fs *GCSFs) ScanRootDirContents() (int, int64, error) { + return fs.GetDirSize(fs.config.KeyPrefix) +} + +// GetDirSize returns the number of files and the size for a folder +// including any subfolders +func (fs *GCSFs) GetDirSize(dirname string) (int, int64, error) { + prefix := fs.getPrefix(dirname) + numFiles := 0 + size := int64(0) + + query := &storage.Query{Prefix: prefix} + err := query.SetAttrSelection(gcsDefaultFieldsSelection) + if err != nil { + return numFiles, size, err + } + + iteratePage := func(nextPageToken string) (string, error) { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + bkt := fs.svc.Bucket(fs.config.Bucket) + it := bkt.Objects(ctx, query) + pager := iterator.NewPager(it, defaultGCSPageSize, nextPageToken) + + var objects []*storage.ObjectAttrs + pageToken, err := pager.NextPage(&objects) + if err != nil { + return pageToken, err + } + for _, attrs := range objects { + if !attrs.Deleted.IsZero() { + continue + } + isDir := strings.HasSuffix(attrs.Name, "/") || attrs.ContentType 
== dirMimeType + if isDir && attrs.Size == 0 { + continue + } + numFiles++ + size += attrs.Size + } + return pageToken, nil + } + + pageToken := "" + for { + pageToken, err = iteratePage(pageToken) + if err != nil { + metric.GCSListObjectsCompleted(err) + return numFiles, size, err + } + fsLog(fs, logger.LevelDebug, "scan in progress for %q, files: %d, size: %d", dirname, numFiles, size) + if pageToken == "" { + break + } + } + + metric.GCSListObjectsCompleted(nil) + return numFiles, size, err +} + +// GetAtomicUploadPath returns the path to use for an atomic upload. +// GCS uploads are already atomic, we never call this method for GCS +func (*GCSFs) GetAtomicUploadPath(_ string) string { + return "" +} + +// GetRelativePath returns the path for a file relative to the user's home dir. +// This is the path as seen by SFTPGo users +func (fs *GCSFs) GetRelativePath(name string) string { + rel := path.Clean(name) + if rel == "." { + rel = "" + } + if !path.IsAbs(rel) { + rel = "/" + rel + } + if fs.config.KeyPrefix != "" { + if !strings.HasPrefix(rel, "/"+fs.config.KeyPrefix) { + rel = "/" + } + rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix)) + } + if fs.mountPath != "" { + rel = path.Join(fs.mountPath, rel) + } + return rel +} + +// Walk walks the file tree rooted at root, calling walkFn for each file or +// directory in the tree, including root +func (fs *GCSFs) Walk(root string, walkFn filepath.WalkFunc) error { + prefix := fs.getPrefix(root) + + query := &storage.Query{Prefix: prefix} + err := query.SetAttrSelection(gcsDefaultFieldsSelection) + if err != nil { + walkFn(root, nil, err) //nolint:errcheck + return err + } + + iteratePage := func(nextPageToken string) (string, error) { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + bkt := fs.svc.Bucket(fs.config.Bucket) + it := bkt.Objects(ctx, query) + pager := iterator.NewPager(it, defaultGCSPageSize, nextPageToken) + + var 
objects []*storage.ObjectAttrs + pageToken, err := pager.NextPage(&objects) + if err != nil { + walkFn(root, nil, err) //nolint:errcheck + return pageToken, err + } + for _, attrs := range objects { + if !attrs.Deleted.IsZero() { + continue + } + name, isDir := fs.resolve(attrs.Name, prefix, attrs.ContentType) + if name == "" { + continue + } + objectModTime := attrs.Updated + if val := getLastModified(attrs.Metadata); val > 0 { + objectModTime = util.GetTimeFromMsecSinceEpoch(val) + } + err = walkFn(attrs.Name, NewFileInfo(name, isDir, attrs.Size, objectModTime, false), nil) + if err != nil { + return pageToken, err + } + } + + return pageToken, nil + } + + pageToken := "" + for { + pageToken, err = iteratePage(pageToken) + if err != nil { + metric.GCSListObjectsCompleted(err) + return err + } + if pageToken == "" { + break + } + } + + walkFn(root, NewFileInfo(root, true, 0, time.Unix(0, 0), false), err) //nolint:errcheck + metric.GCSListObjectsCompleted(err) + return err +} + +// Join joins any number of path elements into a single path +func (*GCSFs) Join(elem ...string) string { + return strings.TrimPrefix(path.Join(elem...), "/") +} + +// HasVirtualFolders returns true if folders are emulated +func (*GCSFs) HasVirtualFolders() bool { + return true +} + +// ResolvePath returns the matching filesystem path for the specified virtual path +func (fs *GCSFs) ResolvePath(virtualPath string) (string, error) { + if fs.mountPath != "" { + if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { + virtualPath = after + } + } + virtualPath = path.Clean("/" + virtualPath) + return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil +} + +// CopyFile implements the FsFileCopier interface +func (fs *GCSFs) CopyFile(source, target string, srcInfo os.FileInfo) (int, int64, error) { + numFiles := 1 + sizeDiff := srcInfo.Size() + var conditions *storage.Conditions + attrs, err := fs.headObject(target) + if err == nil { + sizeDiff -= attrs.Size 
+ numFiles = 0 + conditions = &storage.Conditions{GenerationMatch: attrs.Generation} + } else { + if !fs.IsNotExist(err) { + return 0, 0, err + } + conditions = &storage.Conditions{DoesNotExist: true} + } + if err := fs.copyFileInternal(source, target, conditions, srcInfo, true); err != nil { + return 0, 0, err + } + return numFiles, sizeDiff, nil +} + +func (fs *GCSFs) resolve(name, prefix, contentType string) (string, bool) { + result := strings.TrimPrefix(name, prefix) + isDir := strings.HasSuffix(result, "/") + if isDir { + result = strings.TrimSuffix(result, "/") + } + if contentType == dirMimeType { + isDir = true + } + return result, isDir +} + +// getObjectStat returns the stat result +func (fs *GCSFs) getObjectStat(name string) (os.FileInfo, error) { + attrs, err := fs.headObject(name) + if err == nil { + objSize := attrs.Size + objectModTime := attrs.Updated + if val := getLastModified(attrs.Metadata); val > 0 { + objectModTime = util.GetTimeFromMsecSinceEpoch(val) + } + isDir := attrs.ContentType == dirMimeType || strings.HasSuffix(attrs.Name, "/") + info := NewFileInfo(name, isDir, objSize, objectModTime, false) + if !isDir { + info.setMetadata(attrs.Metadata) + } + return info, nil + } + if !fs.IsNotExist(err) { + return nil, err + } + // now check if this is a prefix (virtual directory) + hasContents, err := fs.hasContents(name) + if err != nil { + return nil, err + } + if hasContents { + return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } + // finally check if this is an object with a trailing / + attrs, err = fs.headObject(name + "/") + if err != nil { + return nil, err + } + objectModTime := attrs.Updated + if val := getLastModified(attrs.Metadata); val > 0 { + objectModTime = util.GetTimeFromMsecSinceEpoch(val) + } + return NewFileInfo(name, true, attrs.Size, objectModTime, false), nil +} + +func (fs *GCSFs) setWriterAttrs(objectWriter *storage.Writer, flag int, name string) { + var contentType string + if flag == -1 { + contentType 
= dirMimeType + } else { + contentType = mime.TypeByExtension(path.Ext(name)) + } + if contentType != "" { + objectWriter.ContentType = contentType + } + if fs.config.StorageClass != "" { + objectWriter.StorageClass = fs.config.StorageClass + } + if fs.config.ACL != "" { + objectWriter.PredefinedACL = fs.config.ACL + } +} + +func (fs *GCSFs) composeObjects(ctx context.Context, dst, partialObject *storage.ObjectHandle) error { + fsLog(fs, logger.LevelDebug, "start object compose for partial file %q, destination %q", + partialObject.ObjectName(), dst.ObjectName()) + composer := dst.ComposerFrom(dst, partialObject) + if fs.config.StorageClass != "" { + composer.StorageClass = fs.config.StorageClass + } + if fs.config.ACL != "" { + composer.PredefinedACL = fs.config.ACL + } + contentType := mime.TypeByExtension(path.Ext(dst.ObjectName())) + if contentType != "" { + composer.ContentType = contentType + } + _, err := composer.Run(ctx) + fsLog(fs, logger.LevelDebug, "object compose for %q finished, err: %v", dst.ObjectName(), err) + + delCtx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + errDelete := partialObject.Delete(delCtx) + metric.GCSDeleteObjectCompleted(errDelete) + fsLog(fs, logger.LevelDebug, "deleted partial file %q after composing with %q, err: %v", + partialObject.ObjectName(), dst.ObjectName(), errDelete) + return err +} + +func (fs *GCSFs) copyFileInternal(source, target string, conditions *storage.Conditions, + srcInfo os.FileInfo, updateModTime bool, +) error { + src := fs.svc.Bucket(fs.config.Bucket).Object(source) + dst := fs.svc.Bucket(fs.config.Bucket).Object(target) + if conditions != nil { + dst = dst.If(*conditions) + } else { + attrs, err := fs.headObject(target) + if err == nil { + dst = dst.If(storage.Conditions{GenerationMatch: attrs.Generation}) + } else if fs.IsNotExist(err) { + dst = dst.If(storage.Conditions{DoesNotExist: true}) + } else { + fsLog(fs, logger.LevelWarn, "unable 
to set precondition for copy, target %q, stat err: %v", + target, err) + } + } + + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxLongTimeout)) + defer cancelFn() + + copier := dst.CopierFrom(src) + if fs.config.StorageClass != "" { + copier.StorageClass = fs.config.StorageClass + } + if fs.config.ACL != "" { + copier.PredefinedACL = fs.config.ACL + } + contentType := mime.TypeByExtension(path.Ext(source)) + if contentType != "" { + copier.ContentType = contentType + } + metadata := getMetadata(srcInfo) + if updateModTime && len(metadata) > 0 { + delete(metadata, lastModifiedField) + } + if len(metadata) > 0 { + copier.Metadata = metadata + } + _, err := copier.Run(ctx) + metric.GCSCopyObjectCompleted(err) + return err +} + +func (fs *GCSFs) renameInternal(source, target string, srcInfo os.FileInfo, recursion int, + updateModTime bool, +) (int, int64, error) { + var numFiles int + var filesSize int64 + + if srcInfo.IsDir() { + if renameMode == 0 { + hasContents, err := fs.hasContents(source) + if err != nil { + return numFiles, filesSize, err + } + if hasContents { + return numFiles, filesSize, fmt.Errorf("%w: cannot rename non empty directory: %q", ErrVfsUnsupported, source) + } + } + if err := fs.mkdirInternal(target); err != nil { + return numFiles, filesSize, err + } + if renameMode == 1 { + files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion, updateModTime) + numFiles += files + filesSize += size + if err != nil { + return numFiles, filesSize, err + } + } + } else { + if err := fs.copyFileInternal(source, target, nil, srcInfo, updateModTime); err != nil { + return numFiles, filesSize, err + } + numFiles++ + filesSize += srcInfo.Size() + } + err := fs.Remove(source, srcInfo.IsDir()) + if fs.IsNotExist(err) { + err = nil + } + return numFiles, filesSize, err +} + +func (fs *GCSFs) mkdirInternal(name string) error { + if !strings.HasSuffix(name, "/") { + name += "/" + } + _, w, _, err := 
fs.Create(name, -1, 0) + if err != nil { + return err + } + return w.Close() +} + +func (fs *GCSFs) hasContents(name string) (bool, error) { + result := false + prefix := fs.getPrefix(name) + query := &storage.Query{Prefix: prefix} + err := query.SetAttrSelection(gcsDefaultFieldsSelection) + if err != nil { + return result, err + } + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + bkt := fs.svc.Bucket(fs.config.Bucket) + it := bkt.Objects(ctx, query) + // if we have a dir object with a trailing slash it will be returned so we set the size to 2 + pager := iterator.NewPager(it, 2, "") + + var objects []*storage.ObjectAttrs + _, err = pager.NextPage(&objects) + if err != nil { + metric.GCSListObjectsCompleted(err) + return result, err + } + + for _, attrs := range objects { + name, _ := fs.resolve(attrs.Name, prefix, attrs.ContentType) + // a dir object with a trailing slash will result in an empty name + if name == "/" || name == "" { + continue + } + result = true + break + } + + metric.GCSListObjectsCompleted(nil) + return result, nil +} + +func (fs *GCSFs) getPrefix(name string) string { + prefix := "" + if name != "" && name != "." 
&& name != "/" { + prefix = strings.TrimPrefix(name, "/") + if !strings.HasSuffix(prefix, "/") { + prefix += "/" + } + } + return prefix +} + +func (fs *GCSFs) headObject(name string) (*storage.ObjectAttrs, error) { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + bkt := fs.svc.Bucket(fs.config.Bucket) + obj := bkt.Object(name) + attrs, err := obj.Attrs(ctx) + metric.GCSHeadObjectCompleted(err) + return attrs, err +} + +// GetMimeType returns the content type +func (fs *GCSFs) GetMimeType(name string) (string, error) { + attrs, err := fs.headObject(name) + if err != nil { + return "", err + } + return attrs.ContentType, nil +} + +// Close closes the fs +func (fs *GCSFs) Close() error { + return nil +} + +// GetAvailableDiskSize returns the available size for the specified path +func (*GCSFs) GetAvailableDiskSize(_ string) (*sftp.StatVFS, error) { + return nil, ErrStorageSizeUnavailable +} + +func (*GCSFs) getTempObject(name string) string { + dir := filepath.Dir(name) + guid := xid.New().String() + return filepath.Join(dir, ".sftpgo-partial."+guid+"."+filepath.Base(name)) +} + +type gcsDirLister struct { + baseDirLister + bucket *storage.BucketHandle + query *storage.Query + timeout time.Duration + nextPageToken string + noMorePages bool + prefix string + prefixes map[string]bool + metricUpdated bool +} + +func (l *gcsDirLister) resolve(name, contentType string) (string, bool) { + result := strings.TrimPrefix(name, l.prefix) + isDir := strings.HasSuffix(result, "/") + if isDir { + result = strings.TrimSuffix(result, "/") + } + if contentType == dirMimeType { + isDir = true + } + return result, isDir +} + +func (l *gcsDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + if len(l.cache) >= limit { + return l.returnFromCache(limit), nil + } + + if l.noMorePages { + if !l.metricUpdated { + l.metricUpdated = true + 
metric.GCSListObjectsCompleted(nil) + } + return l.returnFromCache(limit), io.EOF + } + + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout)) + defer cancelFn() + + it := l.bucket.Objects(ctx, l.query) + paginator := iterator.NewPager(it, defaultGCSPageSize, l.nextPageToken) + var objects []*storage.ObjectAttrs + + pageToken, err := paginator.NextPage(&objects) + if err != nil { + metric.GCSListObjectsCompleted(err) + return l.cache, err + } + + for _, attrs := range objects { + if attrs.Prefix != "" { + name, _ := l.resolve(attrs.Prefix, attrs.ContentType) + if name == "" { + continue + } + if _, ok := l.prefixes[name]; ok { + continue + } + l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false)) + l.prefixes[name] = true + } else { + name, isDir := l.resolve(attrs.Name, attrs.ContentType) + if name == "" { + continue + } + if !attrs.Deleted.IsZero() { + continue + } + if isDir { + // check if the dir is already included, it will be sent as blob prefix if it contains at least one item + if _, ok := l.prefixes[name]; ok { + continue + } + l.prefixes[name] = true + } + modTime := attrs.Updated + if val := getLastModified(attrs.Metadata); val > 0 { + modTime = util.GetTimeFromMsecSinceEpoch(val) + } + info := NewFileInfo(name, isDir, attrs.Size, modTime, false) + info.setMetadata(attrs.Metadata) + l.cache = append(l.cache, info) + } + } + + l.nextPageToken = pageToken + l.noMorePages = (l.nextPageToken == "") + + return l.returnFromCache(limit), nil +} + +func (l *gcsDirLister) Close() error { + clear(l.prefixes) + return l.baseDirLister.Close() +} diff --git a/internal/vfs/gcsfs_disabled.go b/internal/vfs/gcsfs_disabled.go new file mode 100644 index 00000000..cac24e60 --- /dev/null +++ b/internal/vfs/gcsfs_disabled.go @@ -0,0 +1,32 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License 
as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build nogcs + +package vfs + +import ( + "errors" + + "github.com/drakkan/sftpgo/v2/internal/version" +) + +func init() { + version.AddFeature("-gcs") +} + +// NewGCSFs returns an error, GCS is disabled +func NewGCSFs(_, _, _ string, _ GCSFsConfig) (Fs, error) { + return nil, errors.New("Google Cloud Storage disabled at build time") +} diff --git a/internal/vfs/httpfs.go b/internal/vfs/httpfs.go new file mode 100644 index 00000000..76aa940c --- /dev/null +++ b/internal/vfs/httpfs.go @@ -0,0 +1,843 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package vfs + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "mime" + "net" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "strings" + "time" + + "github.com/pkg/sftp" + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/metric" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +const ( + // httpFsName is the name for the HTTP Fs implementation + httpFsName = "httpfs" + maxHTTPFsResponseSize = 1048576 +) + +var ( + supportedEndpointSchema = []string{"http://", "https://"} +) + +// HTTPFsConfig defines the configuration for HTTP based filesystem +type HTTPFsConfig struct { + sdk.BaseHTTPFsConfig + Password *kms.Secret `json:"password,omitempty"` + APIKey *kms.Secret `json:"api_key,omitempty"` +} + +func (c *HTTPFsConfig) isUnixDomainSocket() bool { + return strings.HasPrefix(c.Endpoint, "http://unix") || strings.HasPrefix(c.Endpoint, "https://unix") +} + +// HideConfidentialData hides confidential data +func (c *HTTPFsConfig) HideConfidentialData() { + if c.Password != nil { + c.Password.Hide() + } + if c.APIKey != nil { + c.APIKey.Hide() + } +} + +func (c *HTTPFsConfig) setNilSecretsIfEmpty() { + if c.Password != nil && c.Password.IsEmpty() { + c.Password = nil + } + if c.APIKey != nil && c.APIKey.IsEmpty() { + c.APIKey = nil + } +} + +func (c *HTTPFsConfig) setEmptyCredentialsIfNil() { + if c.Password == nil { + c.Password = kms.NewEmptySecret() + } + if c.APIKey == nil { + c.APIKey = kms.NewEmptySecret() + } +} + +func (c *HTTPFsConfig) isEqual(other HTTPFsConfig) bool { + if c.Endpoint != other.Endpoint { + return false + } + if c.Username != other.Username { + return false + } + if c.SkipTLSVerify != other.SkipTLSVerify { + return false + } + c.setEmptyCredentialsIfNil() + other.setEmptyCredentialsIfNil() + if !c.Password.IsEqual(other.Password) { + return false + } + return 
c.APIKey.IsEqual(other.APIKey) +} + +func (c *HTTPFsConfig) isSameResource(other HTTPFsConfig) bool { + if c.EqualityCheckMode > 0 || other.EqualityCheckMode > 0 { + if c.Username != other.Username { + return false + } + } + return c.Endpoint == other.Endpoint +} + +// validate returns an error if the configuration is not valid +func (c *HTTPFsConfig) validate() error { + c.setEmptyCredentialsIfNil() + if c.Endpoint == "" { + return util.NewI18nError(errors.New("httpfs: endpoint cannot be empty"), util.I18nErrorEndpointRequired) + } + c.Endpoint = strings.TrimRight(c.Endpoint, "/") + endpointURL, err := url.Parse(c.Endpoint) + if err != nil { + return util.NewI18nError(fmt.Errorf("httpfs: invalid endpoint: %w", err), util.I18nErrorEndpointInvalid) + } + if !util.IsStringPrefixInSlice(c.Endpoint, supportedEndpointSchema) { + return util.NewI18nError( + errors.New("httpfs: invalid endpoint schema: http and https are supported"), + util.I18nErrorEndpointInvalid, + ) + } + if endpointURL.Host == "unix" { + socketPath := endpointURL.Query().Get("socket_path") + if !filepath.IsAbs(socketPath) { + return util.NewI18nError( + fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath), + util.I18nErrorEndpointInvalid, + ) + } + } + if !isEqualityCheckModeValid(c.EqualityCheckMode) { + return errors.New("invalid equality_check_mode") + } + if c.Password.IsEncrypted() && !c.Password.IsValid() { + return errors.New("httpfs: invalid encrypted password") + } + if !c.Password.IsEmpty() && !c.Password.IsValidInput() { + return errors.New("httpfs: invalid password") + } + if c.APIKey.IsEncrypted() && !c.APIKey.IsValid() { + return errors.New("httpfs: invalid encrypted API key") + } + if !c.APIKey.IsEmpty() && !c.APIKey.IsValidInput() { + return errors.New("httpfs: invalid API key") + } + return nil +} + +// ValidateAndEncryptCredentials validates the config and encrypts credentials if they are in plain text +func (c *HTTPFsConfig) 
ValidateAndEncryptCredentials(additionalData string) error { + err := c.validate() + if err != nil { + var errI18n *util.I18nError + errValidation := util.NewValidationError(fmt.Sprintf("could not validate HTTP fs config: %v", err)) + if errors.As(err, &errI18n) { + return util.NewI18nError(errValidation, errI18n.Message) + } + return util.NewI18nError(errValidation, util.I18nErrorFsValidation) + } + if c.Password.IsPlain() { + c.Password.SetAdditionalData(additionalData) + if err := c.Password.Encrypt(); err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("could not encrypt HTTP fs password: %v", err)), + util.I18nErrorFsValidation, + ) + } + } + if c.APIKey.IsPlain() { + c.APIKey.SetAdditionalData(additionalData) + if err := c.APIKey.Encrypt(); err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("could not encrypt HTTP fs API key: %v", err)), + util.I18nErrorFsValidation, + ) + } + } + return nil +} + +// HTTPFs is a Fs implementation for the SFTPGo HTTP filesystem backend +type HTTPFs struct { + connectionID string + localTempDir string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string + config *HTTPFsConfig + client *http.Client + ctxTimeout time.Duration +} + +// NewHTTPFs returns an HTTPFs object that allows to interact with SFTPGo HTTP filesystem backends +func NewHTTPFs(connectionID, localTempDir, mountPath string, config HTTPFsConfig) (Fs, error) { + if localTempDir == "" { + localTempDir = getLocalTempDir() + } + config.setEmptyCredentialsIfNil() + if !config.Password.IsEmpty() { + if err := config.Password.TryDecrypt(); err != nil { + return nil, err + } + } + if !config.APIKey.IsEmpty() { + if err := config.APIKey.TryDecrypt(); err != nil { + return nil, err + } + } + fs := &HTTPFs{ + connectionID: connectionID, + localTempDir: localTempDir, + mountPath: mountPath, + config: &config, + ctxTimeout: 30 * time.Second, + } + transport := 
http.DefaultTransport.(*http.Transport).Clone() + transport.MaxResponseHeaderBytes = 1 << 16 + transport.WriteBufferSize = 1 << 16 + transport.ReadBufferSize = 1 << 16 + if fs.config.isUnixDomainSocket() { + endpointURL, err := url.Parse(fs.config.Endpoint) + if err != nil { + return nil, err + } + if endpointURL.Host == "unix" { + socketPath := endpointURL.Query().Get("socket_path") + if !filepath.IsAbs(socketPath) { + return nil, fmt.Errorf("httpfs: invalid unix domain socket path: %q", socketPath) + } + if endpointURL.Scheme == "https" { + transport.DialTLSContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + var tlsConfig *tls.Config + var d tls.Dialer + if config.SkipTLSVerify { + tlsConfig = getInsecureTLSConfig() + } + d.Config = tlsConfig + return d.DialContext(ctx, "unix", socketPath) + } + } else { + transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath) + } + } + endpointURL.Path = path.Join(endpointURL.Path, endpointURL.Query().Get("api_prefix")) + endpointURL.RawQuery = "" + endpointURL.RawFragment = "" + fs.config.Endpoint = endpointURL.String() + } + } + if config.SkipTLSVerify { + if transport.TLSClientConfig != nil { + transport.TLSClientConfig.InsecureSkipVerify = true + } else { + transport.TLSClientConfig = getInsecureTLSConfig() + } + } + fs.client = &http.Client{ + Transport: transport, + } + return fs, nil +} + +// Name returns the name for the Fs implementation +func (fs *HTTPFs) Name() string { + return fmt.Sprintf("%v %q", httpFsName, fs.config.Endpoint) +} + +// ConnectionID returns the connection ID associated to this Fs implementation +func (fs *HTTPFs) ConnectionID() string { + return fs.connectionID +} + +// Stat returns a FileInfo describing the named file +func (fs *HTTPFs) Stat(name string) (os.FileInfo, error) { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer 
cancelFn() + + resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "stat", name, "", "", nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize)) + if err != nil { + return nil, err + } + var response statResponse + err = json.Unmarshal(respBody, &response) + if err != nil { + return nil, err + } + return response.getFileInfo(), nil +} + +// Lstat returns a FileInfo describing the named file +func (fs *HTTPFs) Lstat(name string) (os.FileInfo, error) { + return fs.Stat(name) +} + +// Open opens the named file for reading +func (fs *HTTPFs) Open(name string, offset int64) (File, PipeReader, func(), error) { + r, w, err := createPipeFn(fs.localTempDir, 0) + if err != nil { + return nil, nil, nil, err + } + p := NewPipeReader(r) + ctx, cancelFn := context.WithCancel(context.Background()) + + var queryString string + if offset > 0 { + queryString = fmt.Sprintf("?offset=%d", offset) + } + + go func() { + defer cancelFn() + + resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "open", name, queryString, "", nil) + if err != nil { + fsLog(fs, logger.LevelError, "download error, path %q, err: %v", name, err) + w.CloseWithError(err) //nolint:errcheck + metric.HTTPFsTransferCompleted(0, 1, err) + return + } + defer resp.Body.Close() + n, err := io.Copy(w, resp.Body) + w.CloseWithError(err) //nolint:errcheck + fsLog(fs, logger.LevelDebug, "download completed, path %q size: %v, err: %+v", name, n, err) + metric.HTTPFsTransferCompleted(n, 1, err) + }() + + return nil, p, cancelFn, nil +} + +// Create creates or opens the named file for writing +func (fs *HTTPFs) Create(name string, flag, checks int) (File, PipeWriter, func(), error) { + r, w, err := createPipeFn(fs.localTempDir, 0) + if err != nil { + return nil, nil, nil, err + } + p := NewPipeWriter(w) + ctx, cancelFn := context.WithCancel(context.Background()) + + go func() { + defer cancelFn() + + contentType := 
mime.TypeByExtension(path.Ext(name)) + queryString := fmt.Sprintf("?flags=%d&checks=%d", flag, checks) + resp, err := fs.sendHTTPRequest(ctx, http.MethodPost, "create", name, queryString, contentType, + &wrapReader{reader: r}) + if err != nil { + fsLog(fs, logger.LevelError, "upload error, path %q, err: %v", name, err) + r.CloseWithError(err) //nolint:errcheck + p.Done(err) + metric.HTTPFsTransferCompleted(0, 0, err) + return + } + defer resp.Body.Close() + + r.CloseWithError(err) //nolint:errcheck + p.Done(err) + fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %d", name, r.GetReadedBytes()) + metric.HTTPFsTransferCompleted(r.GetReadedBytes(), 0, err) + }() + + return nil, p, cancelFn, nil +} + +// Rename renames (moves) source to target. +func (fs *HTTPFs) Rename(source, target string, checks int) (int, int64, error) { + if source == target { + return -1, -1, nil + } + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + queryString := fmt.Sprintf("?target=%s", url.QueryEscape(target)) + resp, err := fs.sendHTTPRequest(ctx, http.MethodPatch, "rename", source, queryString, "", nil) + if err != nil { + return -1, -1, err + } + defer resp.Body.Close() + if checks&CheckUpdateModTime != 0 { + fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck + } + return -1, -1, nil +} + +// Remove removes the named file or (empty) directory. 
+func (fs *HTTPFs) Remove(name string, _ bool) error { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := fs.sendHTTPRequest(ctx, http.MethodDelete, "remove", name, "", "", nil) + if err != nil { + return err + } + defer resp.Body.Close() + return nil +} + +// Mkdir creates a new directory with the specified name and default permissions +func (fs *HTTPFs) Mkdir(name string) error { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := fs.sendHTTPRequest(ctx, http.MethodPost, "mkdir", name, "", "", nil) + if err != nil { + return err + } + defer resp.Body.Close() + return nil +} + +// Symlink creates source as a symbolic link to target. +func (*HTTPFs) Symlink(_, _ string) error { + return ErrVfsUnsupported +} + +// Readlink returns the destination of the named symbolic link +func (*HTTPFs) Readlink(_ string) (string, error) { + return "", ErrVfsUnsupported +} + +// Chown changes the numeric uid and gid of the named file. +func (fs *HTTPFs) Chown(_ string, _ int, _ int) error { + return ErrVfsUnsupported +} + +// Chmod changes the mode of the named file to mode. +func (fs *HTTPFs) Chmod(name string, mode os.FileMode) error { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + queryString := fmt.Sprintf("?mode=%d", mode) + resp, err := fs.sendHTTPRequest(ctx, http.MethodPatch, "chmod", name, queryString, "", nil) + if err != nil { + return err + } + defer resp.Body.Close() + return nil +} + +// Chtimes changes the access and modification times of the named file. 
+func (fs *HTTPFs) Chtimes(name string, atime, mtime time.Time, _ bool) error { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + queryString := fmt.Sprintf("?access_time=%s&modification_time=%s", atime.UTC().Format(time.RFC3339), + mtime.UTC().Format(time.RFC3339)) + resp, err := fs.sendHTTPRequest(ctx, http.MethodPatch, "chtimes", name, queryString, "", nil) + if err != nil { + return err + } + defer resp.Body.Close() + return nil +} + +// Truncate changes the size of the named file. +// Truncate by path is not supported, while truncating an opened +// file is handled inside base transfer +func (fs *HTTPFs) Truncate(name string, size int64) error { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + queryString := fmt.Sprintf("?size=%d", size) + resp, err := fs.sendHTTPRequest(ctx, http.MethodPatch, "truncate", name, queryString, "", nil) + if err != nil { + return err + } + defer resp.Body.Close() + return nil +} + +// ReadDir reads the directory named by dirname and returns +// a list of directory entries. +func (fs *HTTPFs) ReadDir(dirname string) (DirLister, error) { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "readdir", dirname, "", "", nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize*10)) + if err != nil { + return nil, err + } + var response []statResponse + err = json.Unmarshal(respBody, &response) + if err != nil { + return nil, err + } + result := make([]os.FileInfo, 0, len(response)) + for _, stat := range response { + result = append(result, stat.getFileInfo()) + } + return &baseDirLister{result}, nil +} + +// IsUploadResumeSupported returns true if resuming uploads is supported. 
+func (*HTTPFs) IsUploadResumeSupported() bool { + return false +} + +// IsConditionalUploadResumeSupported returns if resuming uploads is supported +// for the specified size +func (*HTTPFs) IsConditionalUploadResumeSupported(_ int64) bool { + return false +} + +// IsAtomicUploadSupported returns true if atomic upload is supported. +func (*HTTPFs) IsAtomicUploadSupported() bool { + return false +} + +// IsNotExist returns a boolean indicating whether the error is known to +// report that a file or directory does not exist +func (*HTTPFs) IsNotExist(err error) bool { + return errors.Is(err, fs.ErrNotExist) +} + +// IsPermission returns a boolean indicating whether the error is known to +// report that permission is denied. +func (*HTTPFs) IsPermission(err error) bool { + return errors.Is(err, fs.ErrPermission) +} + +// IsNotSupported returns true if the error indicate an unsupported operation +func (*HTTPFs) IsNotSupported(err error) bool { + if err == nil { + return false + } + return err == ErrVfsUnsupported +} + +// CheckRootPath creates the specified local root directory if it does not exists +func (fs *HTTPFs) CheckRootPath(username string, uid int, gid int) bool { + // we need a local directory for temporary files + osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil) + return osFs.CheckRootPath(username, uid, gid) +} + +// ScanRootDirContents returns the number of files and their size +func (fs *HTTPFs) ScanRootDirContents() (int, int64, error) { + return fs.GetDirSize("/") +} + +// CheckMetadata checks the metadata consistency +func (*HTTPFs) CheckMetadata() error { + return nil +} + +// GetDirSize returns the number of files and the size for a folder +// including any subfolders +func (fs *HTTPFs) GetDirSize(dirname string) (int, int64, error) { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "dirsize", dirname, "", "", nil) + if 
err != nil { + return 0, 0, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize)) + if err != nil { + return 0, 0, err + } + + var response dirSizeResponse + err = json.Unmarshal(respBody, &response) + if err != nil { + return 0, 0, err + } + return response.Files, response.Size, nil +} + +// GetAtomicUploadPath returns the path to use for an atomic upload. +func (*HTTPFs) GetAtomicUploadPath(_ string) string { + return "" +} + +// GetRelativePath returns the path for a file relative to the user's home dir. +// This is the path as seen by SFTPGo users +func (fs *HTTPFs) GetRelativePath(name string) string { + rel := path.Clean(name) + if rel == "." { + rel = "" + } + if !path.IsAbs(rel) { + rel = "/" + rel + } + if fs.mountPath != "" { + rel = path.Join(fs.mountPath, rel) + } + return rel +} + +// Walk walks the file tree rooted at root, calling walkFn for each file or +// directory in the tree, including root. The result are unordered +func (fs *HTTPFs) Walk(root string, walkFn filepath.WalkFunc) error { + info, err := fs.Lstat(root) + if err != nil { + return walkFn(root, nil, err) + } + return fs.walk(root, info, walkFn) +} + +// Join joins any number of path elements into a single path +func (*HTTPFs) Join(elem ...string) string { + return strings.TrimPrefix(path.Join(elem...), "/") +} + +// HasVirtualFolders returns true if folders are emulated +func (*HTTPFs) HasVirtualFolders() bool { + return false +} + +// ResolvePath returns the matching filesystem path for the specified virtual path +func (fs *HTTPFs) ResolvePath(virtualPath string) (string, error) { + if fs.mountPath != "" { + if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { + virtualPath = after + } + } + return path.Clean("/" + virtualPath), nil +} + +// GetMimeType returns the content type +func (fs *HTTPFs) GetMimeType(name string) (string, error) { + ctx, cancelFn := context.WithDeadline(context.Background(), 
time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "stat", name, "", "", nil) + if err != nil { + return "", err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize)) + if err != nil { + return "", err + } + + var response mimeTypeResponse + err = json.Unmarshal(respBody, &response) + if err != nil { + return "", err + } + return response.Mime, nil +} + +// Close closes the fs +func (fs *HTTPFs) Close() error { + fs.client.CloseIdleConnections() + return nil +} + +// GetAvailableDiskSize returns the available size for the specified path +func (fs *HTTPFs) GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error) { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + resp, err := fs.sendHTTPRequest(ctx, http.MethodGet, "statvfs", dirName, "", "", nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxHTTPFsResponseSize)) + if err != nil { + return nil, err + } + + var response statVFSResponse + err = json.Unmarshal(respBody, &response) + if err != nil { + return nil, err + } + return response.toSFTPStatVFS(), nil +} + +func (fs *HTTPFs) sendHTTPRequest(ctx context.Context, method, base, name, queryString, contentType string, + body io.Reader, +) (*http.Response, error) { + url := fmt.Sprintf("%s/%s/%s%s", fs.config.Endpoint, base, url.PathEscape(name), queryString) + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + if fs.config.APIKey.GetPayload() != "" { + req.Header.Set("X-API-KEY", fs.config.APIKey.GetPayload()) + } + if fs.config.Username != "" || fs.config.Password.GetPayload() != "" { + req.SetBasicAuth(fs.config.Username, fs.config.Password.GetPayload()) + } + resp, err := 
fs.client.Do(req.WithContext(ctx)) + if err != nil { + return nil, fmt.Errorf("unable to send HTTP request to URL %v: %w", url, err) + } + if err = getErrorFromResponseCode(resp.StatusCode); err != nil { + resp.Body.Close() + return nil, err + } + return resp, nil +} + +// walk recursively descends path, calling walkFn. +func (fs *HTTPFs) walk(filePath string, info fs.FileInfo, walkFn filepath.WalkFunc) error { + if !info.IsDir() { + return walkFn(filePath, info, nil) + } + lister, err := fs.ReadDir(filePath) + err1 := walkFn(filePath, info, err) + if err != nil || err1 != nil { + if err == nil { + lister.Close() + } + return err1 + } + defer lister.Close() + + for { + files, err := lister.Next(ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + return err + } + for _, fi := range files { + objName := path.Join(filePath, fi.Name()) + err = fs.walk(objName, fi, walkFn) + if err != nil { + return err + } + } + if finished { + return nil + } + } +} + +func getErrorFromResponseCode(code int) error { + switch code { + case 401, 403: + return os.ErrPermission + case 404: + return os.ErrNotExist + case 501: + return ErrVfsUnsupported + case 200, 201: + return nil + default: + return fmt.Errorf("unexpected response code: %v", code) + } +} + +func getInsecureTLSConfig() *tls.Config { + return &tls.Config{ + InsecureSkipVerify: true, + } +} + +type wrapReader struct { + reader io.Reader +} + +func (r *wrapReader) Read(p []byte) (n int, err error) { + return r.reader.Read(p) +} + +type statResponse struct { + Name string `json:"name"` + Size int64 `json:"size"` + Mode uint32 `json:"mode"` + LastModified time.Time `json:"last_modified"` +} + +func (s *statResponse) getFileInfo() os.FileInfo { + info := NewFileInfo(s.Name, false, s.Size, s.LastModified, false) + info.SetMode(fs.FileMode(s.Mode)) + return info +} + +type dirSizeResponse struct { + Files int `json:"files"` + Size int64 `json:"size"` +} + +type mimeTypeResponse struct { + Mime 
string `json:"mime"` +} + +type statVFSResponse struct { + ID uint32 `json:"-"` + Bsize uint64 `json:"bsize"` + Frsize uint64 `json:"frsize"` + Blocks uint64 `json:"blocks"` + Bfree uint64 `json:"bfree"` + Bavail uint64 `json:"bavail"` + Files uint64 `json:"files"` + Ffree uint64 `json:"ffree"` + Favail uint64 `json:"favail"` + Fsid uint64 `json:"fsid"` + Flag uint64 `json:"flag"` + Namemax uint64 `json:"namemax"` +} + +func (s *statVFSResponse) toSFTPStatVFS() *sftp.StatVFS { + return &sftp.StatVFS{ + Bsize: s.Bsize, + Frsize: s.Frsize, + Blocks: s.Blocks, + Bfree: s.Bfree, + Bavail: s.Bavail, + Files: s.Files, + Ffree: s.Ffree, + Favail: s.Ffree, + Flag: s.Flag, + Namemax: s.Namemax, + } +} diff --git a/internal/vfs/osfs.go b/internal/vfs/osfs.go new file mode 100644 index 00000000..1f0a502f --- /dev/null +++ b/internal/vfs/osfs.go @@ -0,0 +1,623 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +package vfs + +import ( + "bufio" + "errors" + "fmt" + "io" + "io/fs" + "net/http" + "os" + "path" + "path/filepath" + "slices" + "strings" + "time" + + fscopy "github.com/otiai10/copy" + "github.com/pkg/sftp" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + + "github.com/drakkan/sftpgo/v2/internal/logger" +) + +const ( + // osFsName is the name for the local Fs implementation + osFsName = "osfs" +) + +type pathResolutionError struct { + err string +} + +func (e *pathResolutionError) Error() string { + return fmt.Sprintf("Path resolution error: %s", e.err) +} + +// OsFs is a Fs implementation that uses functions provided by the os package. +type OsFs struct { + name string + connectionID string + rootDir string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string + localTempDir string + readBufferSize int + writeBufferSize int +} + +// NewOsFs returns an OsFs object that allows to interact with local Os filesystem +func NewOsFs(connectionID, rootDir, mountPath string, config *sdk.OSFsConfig) Fs { + var readBufferSize, writeBufferSize int + if config != nil { + readBufferSize = config.ReadBufferSize * 1024 * 1024 + writeBufferSize = config.WriteBufferSize * 1024 * 1024 + } + return &OsFs{ + name: osFsName, + connectionID: connectionID, + rootDir: rootDir, + mountPath: getMountPath(mountPath), + localTempDir: getLocalTempDir(), + readBufferSize: readBufferSize, + writeBufferSize: writeBufferSize, + } +} + +// Name returns the name for the Fs implementation +func (fs *OsFs) Name() string { + return fs.name +} + +// ConnectionID returns the SSH connection ID associated to this Fs implementation +func (fs *OsFs) ConnectionID() string { + return fs.connectionID +} + +// Stat returns a FileInfo describing the named file +func (fs *OsFs) Stat(name string) (os.FileInfo, error) { + return os.Stat(name) +} + +// Lstat returns a FileInfo describing the named file +func (fs *OsFs) Lstat(name string) (os.FileInfo, error) { + return 
os.Lstat(name) +} + +// Open opens the named file for reading +func (fs *OsFs) Open(name string, offset int64) (File, PipeReader, func(), error) { + f, err := os.Open(name) + if err != nil { + return nil, nil, nil, err + } + if offset > 0 { + _, err = f.Seek(offset, io.SeekStart) + if err != nil { + f.Close() + return nil, nil, nil, err + } + } + if fs.readBufferSize <= 0 { + return f, nil, nil, err + } + r, w, err := createPipeFn(fs.localTempDir, 0) + if err != nil { + f.Close() + return nil, nil, nil, err + } + p := NewPipeReader(r) + go func() { + br := bufio.NewReaderSize(f, fs.readBufferSize) + n, err := doCopy(w, br, nil) + w.CloseWithError(err) //nolint:errcheck + f.Close() + fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %v", name, n, err) + }() + + return nil, p, nil, nil +} + +// Create creates or opens the named file for writing +func (fs *OsFs) Create(name string, flag, _ int) (File, PipeWriter, func(), error) { + if !fs.useWriteBuffering(flag) { + var err error + var f *os.File + if flag == 0 { + f, err = os.Create(name) + } else { + f, err = os.OpenFile(name, flag, 0666) + } + return f, nil, nil, err + } + f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return nil, nil, nil, err + } + r, w, err := createPipeFn(fs.localTempDir, 0) + if err != nil { + f.Close() + return nil, nil, nil, err + } + p := NewPipeWriter(w) + + go func() { + bw := bufio.NewWriterSize(f, fs.writeBufferSize) + n, err := doCopy(bw, r, nil) + errFlush := bw.Flush() + if err == nil && errFlush != nil { + err = errFlush + } + errClose := f.Close() + if err == nil && errClose != nil { + err = errClose + } + r.CloseWithError(err) //nolint:errcheck + p.Done(err) + fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %v, err: %v", name, n, err) + }() + + return nil, p, nil, nil +} + +// Rename renames (moves) source to target +func (fs *OsFs) Rename(source, target string, checks int) (int, int64, 
error) { + if source == target { + return -1, -1, nil + } + err := os.Rename(source, target) + if err != nil && isCrossDeviceError(err) { + fsLog(fs, logger.LevelError, "cross device error detected while renaming %q -> %q. Trying a copy and remove, this could take a long time", + source, target) + var readBufferSize uint + if fs.readBufferSize > 0 { + readBufferSize = uint(fs.readBufferSize) + } + + err = fscopy.Copy(source, target, fscopy.Options{ + OnSymlink: func(_ string) fscopy.SymlinkAction { + return fscopy.Skip + }, + CopyBufferSize: readBufferSize, + }) + if err != nil { + fsLog(fs, logger.LevelError, "cross device copy error: %v", err) + return -1, -1, err + } + if checks&CheckUpdateModTime != 0 { + fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck + } + err = os.RemoveAll(source) + return -1, -1, err + } + if checks&CheckUpdateModTime != 0 && err == nil { + fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck + } + return -1, -1, err +} + +// Remove removes the named file or (empty) directory. +func (*OsFs) Remove(name string, _ bool) error { + return os.Remove(name) +} + +// Mkdir creates a new directory with the specified name and default permissions +func (*OsFs) Mkdir(name string) error { + return os.Mkdir(name, os.ModePerm) +} + +// Symlink creates source as a symbolic link to target. 
+func (*OsFs) Symlink(source, target string) error { + return os.Symlink(source, target) +} + +// Readlink returns the destination of the named symbolic link +// as absolute virtual path +func (fs *OsFs) Readlink(name string) (string, error) { + // we don't have to follow multiple links: + // https://github.com/openssh/openssh-portable/blob/7bf2eb958fbb551e7d61e75c176bb3200383285d/sftp-server.c#L1329 + resolved, err := os.Readlink(name) + if err != nil { + return "", err + } + resolved = filepath.Clean(resolved) + if !filepath.IsAbs(resolved) { + resolved = filepath.Join(filepath.Dir(name), resolved) + } + return fs.GetRelativePath(resolved), nil +} + +// Chown changes the numeric uid and gid of the named file. +func (*OsFs) Chown(name string, uid int, gid int) error { + return os.Chown(name, uid, gid) +} + +// Chmod changes the mode of the named file to mode +func (*OsFs) Chmod(name string, mode os.FileMode) error { + return os.Chmod(name, mode) +} + +// Chtimes changes the access and modification times of the named file +func (*OsFs) Chtimes(name string, atime, mtime time.Time, _ bool) error { + return os.Chtimes(name, atime, mtime) +} + +// Truncate changes the size of the named file +func (*OsFs) Truncate(name string, size int64) error { + return os.Truncate(name, size) +} + +// ReadDir reads the directory named by dirname and returns +// a list of directory entries. 
+func (*OsFs) ReadDir(dirname string) (DirLister, error) { + f, err := os.Open(dirname) + if err != nil { + if isInvalidNameError(err) { + err = os.ErrNotExist + } + return nil, err + } + return &osFsDirLister{f}, nil +} + +// IsUploadResumeSupported returns true if resuming uploads is supported +func (*OsFs) IsUploadResumeSupported() bool { + return true +} + +// IsConditionalUploadResumeSupported returns if resuming uploads is supported +// for the specified size +func (*OsFs) IsConditionalUploadResumeSupported(_ int64) bool { + return true +} + +// IsAtomicUploadSupported returns true if atomic upload is supported +func (*OsFs) IsAtomicUploadSupported() bool { + return true +} + +// IsNotExist returns a boolean indicating whether the error is known to +// report that a file or directory does not exist +func (*OsFs) IsNotExist(err error) bool { + return errors.Is(err, fs.ErrNotExist) +} + +// IsPermission returns a boolean indicating whether the error is known to +// report that permission is denied. 
+func (*OsFs) IsPermission(err error) bool { + if _, ok := err.(*pathResolutionError); ok { + return true + } + return errors.Is(err, fs.ErrPermission) +} + +// IsNotSupported returns true if the error indicate an unsupported operation +func (*OsFs) IsNotSupported(err error) bool { + if err == nil { + return false + } + return err == ErrVfsUnsupported +} + +// CheckRootPath creates the root directory if it does not exists +func (fs *OsFs) CheckRootPath(username string, uid int, gid int) bool { + var err error + if _, err = fs.Stat(fs.rootDir); fs.IsNotExist(err) { + err = os.MkdirAll(fs.rootDir, os.ModePerm) + if err == nil { + SetPathPermissions(fs, fs.rootDir, uid, gid) + } else { + fsLog(fs, logger.LevelError, "error creating root directory %q for user %q: %v", fs.rootDir, username, err) + } + } + return err == nil +} + +// ScanRootDirContents returns the number of files contained in the root +// directory and their size +func (fs *OsFs) ScanRootDirContents() (int, int64, error) { + return fs.GetDirSize(fs.rootDir) +} + +// CheckMetadata checks the metadata consistency +func (*OsFs) CheckMetadata() error { + return nil +} + +// GetAtomicUploadPath returns the path to use for an atomic upload +func (*OsFs) GetAtomicUploadPath(name string) string { + dir := filepath.Dir(name) + if tempPath != "" { + dir = tempPath + } + guid := xid.New().String() + return filepath.Join(dir, ".sftpgo-upload."+guid+"."+filepath.Base(name)) +} + +// GetRelativePath returns the path for a file relative to the user's home dir. +// This is the path as seen by SFTPGo users +func (fs *OsFs) GetRelativePath(name string) string { + virtualPath := "/" + if fs.mountPath != "" { + virtualPath = fs.mountPath + } + rel, err := filepath.Rel(fs.rootDir, filepath.Clean(name)) + if err != nil { + return virtualPath + } + rel = filepath.ToSlash(rel) + if rel == ".." || strings.HasPrefix(rel, "../") { + return virtualPath + } + if rel == "." 
{ + rel = "" + } + return path.Join(virtualPath, rel) +} + +// Walk walks the file tree rooted at root, calling walkFn for each file or +// directory in the tree, including root +func (*OsFs) Walk(root string, walkFn filepath.WalkFunc) error { + return filepath.Walk(root, walkFn) +} + +// Join joins any number of path elements into a single path +func (*OsFs) Join(elem ...string) string { + return filepath.Join(elem...) +} + +// ResolvePath returns the matching filesystem path for the specified sftp path +func (fs *OsFs) ResolvePath(virtualPath string) (string, error) { + if !filepath.IsAbs(fs.rootDir) { + return "", fmt.Errorf("invalid root path %q", fs.rootDir) + } + if fs.mountPath != "" { + if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { + virtualPath = after + } + } + virtualPath = path.Clean("/" + virtualPath) + r := filepath.Clean(filepath.Join(fs.rootDir, virtualPath)) + p, err := filepath.EvalSymlinks(r) + if isInvalidNameError(err) { + err = os.ErrNotExist + } + isNotExist := fs.IsNotExist(err) + if err != nil && !isNotExist { + return "", err + } else if isNotExist { + // The requested path doesn't exist, so at this point we need to iterate up the + // path chain until we hit a directory that _does_ exist and can be validated. 
+ _, err = fs.findFirstExistingDir(r) + if err != nil { + fsLog(fs, logger.LevelError, "error resolving non-existent path %q", err) + } + return r, err + } + + err = fs.isSubDir(p) + if err != nil { + fsLog(fs, logger.LevelError, "Invalid path resolution, path %q original path %q resolved %q err: %v", + p, virtualPath, r, err) + } + return r, err +} + +// RealPath implements the FsRealPather interface +func (fs *OsFs) RealPath(p string) (string, error) { + linksWalked := 0 + for { + info, err := os.Lstat(p) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return fs.GetRelativePath(p), nil + } + return "", err + } + if info.Mode()&os.ModeSymlink == 0 { + return fs.GetRelativePath(p), nil + } + resolvedLink, err := os.Readlink(p) + if err != nil { + return "", err + } + resolvedLink = filepath.Clean(resolvedLink) + if filepath.IsAbs(resolvedLink) { + p = resolvedLink + } else { + p = filepath.Join(filepath.Dir(p), resolvedLink) + } + linksWalked++ + if linksWalked > 10 { + fsLog(fs, logger.LevelError, "unable to get real path, too many links: %d", linksWalked) + return "", &pathResolutionError{err: "too many links"} + } + } +} + +// GetDirSize returns the number of files and the size for a folder +// including any subfolders +func (fs *OsFs) GetDirSize(dirname string) (int, int64, error) { + numFiles := 0 + size := int64(0) + isDir, err := isDirectory(fs, dirname) + if err == nil && isDir { + err = filepath.Walk(dirname, func(_ string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info != nil && info.Mode().IsRegular() { + size += info.Size() + numFiles++ + if numFiles%1000 == 0 { + fsLog(fs, logger.LevelDebug, "dirname %q scan in progress, files: %d, size: %d", dirname, numFiles, size) + } + } + return err + }) + } + return numFiles, size, err +} + +// HasVirtualFolders returns true if folders are emulated +func (*OsFs) HasVirtualFolders() bool { + return false +} + +func (fs *OsFs) findNonexistentDirs(filePath string) 
([]string, error) { + results := []string{} + cleanPath := filepath.Clean(filePath) + parent := filepath.Dir(cleanPath) + _, err := os.Stat(parent) + + for fs.IsNotExist(err) { + results = append(results, parent) + parent = filepath.Dir(parent) + if slices.Contains(results, parent) { + break + } + _, err = os.Stat(parent) + } + if err != nil { + return results, err + } + p, err := filepath.EvalSymlinks(parent) + if err != nil { + return results, err + } + err = fs.isSubDir(p) + if err != nil { + fsLog(fs, logger.LevelError, "error finding non existing dir: %v", err) + } + return results, err +} + +func (fs *OsFs) findFirstExistingDir(path string) (string, error) { + results, err := fs.findNonexistentDirs(path) + if err != nil { + fsLog(fs, logger.LevelError, "unable to find non existent dirs: %v", err) + return "", err + } + var parent string + if len(results) > 0 { + lastMissingDir := results[len(results)-1] + parent = filepath.Dir(lastMissingDir) + } else { + parent = fs.rootDir + } + p, err := filepath.EvalSymlinks(parent) + if err != nil { + return "", err + } + fileInfo, err := os.Stat(p) + if err != nil { + return "", err + } + if !fileInfo.IsDir() { + return "", fmt.Errorf("resolved path is not a dir: %q", p) + } + err = fs.isSubDir(p) + return p, err +} + +func (fs *OsFs) isSubDir(sub string) error { + // fs.rootDir must exist and it is already a validated absolute path + parent, err := filepath.EvalSymlinks(fs.rootDir) + if err != nil { + fsLog(fs, logger.LevelError, "invalid root path %q: %v", fs.rootDir, err) + return err + } + if parent == sub { + return nil + } + if len(sub) < len(parent) { + err = fmt.Errorf("path %q is not inside %q", sub, parent) + return &pathResolutionError{err: err.Error()} + } + separator := string(os.PathSeparator) + if parent == filepath.Dir(parent) { + // parent is the root dir, on Windows we can have C:\, D:\ and so on here + // so we still need the prefix check + separator = "" + } + if !strings.HasPrefix(sub, 
parent+separator) { + err = fmt.Errorf("path %q is not inside %q", sub, parent) + return &pathResolutionError{err: err.Error()} + } + return nil +} + +// GetMimeType returns the content type +func (fs *OsFs) GetMimeType(name string) (string, error) { + f, err := os.OpenFile(name, os.O_RDONLY, 0) + if err != nil { + return "", err + } + defer f.Close() + var buf [512]byte + n, err := io.ReadFull(f, buf[:]) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return "", err + } + ctype := http.DetectContentType(buf[:n]) + // Rewind file. + _, err = f.Seek(0, io.SeekStart) + return ctype, err +} + +// Close closes the fs +func (*OsFs) Close() error { + return nil +} + +// GetAvailableDiskSize returns the available size for the specified path +func (*OsFs) GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error) { + return getStatFS(dirName) +} + +func (fs *OsFs) useWriteBuffering(flag int) bool { + if fs.writeBufferSize <= 0 { + return false + } + if flag == 0 { + return true + } + if flag&os.O_TRUNC == 0 { + fsLog(fs, logger.LevelDebug, "truncate flag missing, buffering write not possible") + return false + } + if flag&os.O_RDWR != 0 { + fsLog(fs, logger.LevelDebug, "read and write flag found, buffering write not possible") + return false + } + return true +} + +type osFsDirLister struct { + f *os.File +} + +func (l *osFsDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + return l.f.Readdir(limit) +} + +func (l *osFsDirLister) Close() error { + return l.f.Close() +} diff --git a/internal/vfs/s3fs.go b/internal/vfs/s3fs.go new file mode 100644 index 00000000..f0e9c093 --- /dev/null +++ b/internal/vfs/s3fs.go @@ -0,0 +1,1425 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build !nos3 + +package vfs + +import ( + "bytes" + "context" + "crypto/md5" + "crypto/sha256" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "mime" + "net" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "slices" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/pkg/sftp" + + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/metric" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" +) + +const ( + // using this mime type for directories improves compatibility with s3fs-fuse + s3DirMimeType = "application/x-directory" + s3TransferBufferSize = 256 * 1024 + s3CopyObjectThreshold = 500 * 1024 * 1024 +) + +var ( + s3DirMimeTypes = []string{s3DirMimeType, "httpd/unix-directory"} + s3DefaultPageSize = int32(5000) +) + +// S3Fs is a Fs implementation for AWS S3 compatible object storages +type S3Fs struct { + connectionID string + localTempDir string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string + config *S3FsConfig + svc *s3.Client + ctxTimeout time.Duration + sseCustomerKey string + sseCustomerKeyMD5 string + sseCustomerAlgo string +} + 
+func init() { + version.AddFeature("+s3") +} + +// NewS3Fs returns an S3Fs object that allows to interact with an s3 compatible +// object storage +func NewS3Fs(connectionID, localTempDir, mountPath string, s3Config S3FsConfig) (Fs, error) { + if localTempDir == "" { + localTempDir = getLocalTempDir() + } + fs := &S3Fs{ + connectionID: connectionID, + localTempDir: localTempDir, + mountPath: getMountPath(mountPath), + config: &s3Config, + ctxTimeout: 30 * time.Second, + } + if err := fs.config.validate(); err != nil { + return fs, err + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + awsConfig, err := config.LoadDefaultConfig(ctx, config.WithHTTPClient( + getAWSHTTPClient(0, 30*time.Second, fs.config.SkipTLSVerify)), + ) + if err != nil { + return fs, fmt.Errorf("unable to get AWS config: %w", err) + } + if fs.config.Region != "" { + awsConfig.Region = fs.config.Region + } + if !fs.config.AccessSecret.IsEmpty() { + if err := fs.config.AccessSecret.TryDecrypt(); err != nil { + return fs, err + } + awsConfig.Credentials = aws.NewCredentialsCache( + credentials.NewStaticCredentialsProvider( + fs.config.AccessKey, + fs.config.AccessSecret.GetPayload(), + fs.config.SessionToken), + ) + } + if !fs.config.SSECustomerKey.IsEmpty() { + if err := fs.config.SSECustomerKey.TryDecrypt(); err != nil { + return fs, err + } + key := fs.config.SSECustomerKey.GetPayload() + if len(key) == 32 { + md5sumBinary := md5.Sum([]byte(key)) + fs.sseCustomerKey = base64.StdEncoding.EncodeToString([]byte(key)) + fs.sseCustomerKeyMD5 = base64.StdEncoding.EncodeToString(md5sumBinary[:]) + } else { + keyHash := sha256.Sum256([]byte(key)) + md5sumBinary := md5.Sum(keyHash[:]) + fs.sseCustomerKey = base64.StdEncoding.EncodeToString(keyHash[:]) + fs.sseCustomerKeyMD5 = base64.StdEncoding.EncodeToString(md5sumBinary[:]) + } + fs.sseCustomerAlgo = "AES256" + } + + fs.setConfigDefaults() + + if fs.config.RoleARN != "" { + client := 
sts.NewFromConfig(awsConfig) + creds := stscreds.NewAssumeRoleProvider(client, fs.config.RoleARN) + awsConfig.Credentials = creds + } + fs.svc = s3.NewFromConfig(awsConfig, func(o *s3.Options) { + o.AppID = version.GetVersionHash() + o.UsePathStyle = fs.config.ForcePathStyle + o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired + o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired + if fs.config.Endpoint != "" { + o.BaseEndpoint = aws.String(fs.config.Endpoint) + } + }) + return fs, nil +} + +// Name returns the name for the Fs implementation +func (fs *S3Fs) Name() string { + return fmt.Sprintf("%s bucket %q", s3fsName, fs.config.Bucket) +} + +// ConnectionID returns the connection ID associated to this Fs implementation +func (fs *S3Fs) ConnectionID() string { + return fs.connectionID +} + +// Stat returns a FileInfo describing the named file +func (fs *S3Fs) Stat(name string) (os.FileInfo, error) { + var result *FileInfo + if name == "" || name == "/" || name == "." { + return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } + if fs.config.KeyPrefix == name+"/" { + return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } + obj, err := fs.headObject(name) + if err == nil { + // Some S3 providers (like SeaweedFS) remove the trailing '/' from object keys. + // So we check some common content types to detect if this is a "directory". 
+ isDir := slices.Contains(s3DirMimeTypes, util.GetStringFromPointer(obj.ContentType)) + if util.GetIntFromPointer(obj.ContentLength) == 0 && !isDir { + _, err = fs.headObject(name + "/") + isDir = err == nil + } + info := NewFileInfo(name, isDir, util.GetIntFromPointer(obj.ContentLength), util.GetTimeFromPointer(obj.LastModified), false) + return info, nil + } + if !fs.IsNotExist(err) { + return result, err + } + // now check if this is a prefix (virtual directory) + hasContents, err := fs.hasContents(name) + if err == nil && hasContents { + return NewFileInfo(name, true, 0, time.Unix(0, 0), false), nil + } else if err != nil { + return nil, err + } + // the requested file may still be a directory as a zero bytes key + // with a trailing forward slash (created using mkdir). + // S3 doesn't return content type when listing objects, so we have + // create "dirs" adding a trailing "/" to the key + return fs.getStatForDir(name) +} + +func (fs *S3Fs) getStatForDir(name string) (os.FileInfo, error) { + var result *FileInfo + obj, err := fs.headObject(name + "/") + if err != nil { + return result, err + } + return NewFileInfo(name, true, util.GetIntFromPointer(obj.ContentLength), util.GetTimeFromPointer(obj.LastModified), false), nil +} + +// Lstat returns a FileInfo describing the named file +func (fs *S3Fs) Lstat(name string) (os.FileInfo, error) { + return fs.Stat(name) +} + +// Open opens the named file for reading +func (fs *S3Fs) Open(name string, offset int64) (File, PipeReader, func(), error) { + attrs, err := fs.headObject(name) + if err != nil { + return nil, nil, nil, err + } + r, w, err := createPipeFn(fs.localTempDir, fs.config.DownloadPartSize*int64(fs.config.DownloadConcurrency)+1) + if err != nil { + return nil, nil, nil, err + } + p := NewPipeReader(r) + if readMetadata > 0 { + p.setMetadata(attrs.Metadata) + } + ctx, cancelFn := context.WithCancel(context.Background()) + + go func() { + defer cancelFn() + + err := fs.handleDownload(ctx, name, offset, w, 
attrs) + w.CloseWithError(err) //nolint:errcheck + fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %d, err: %+v", name, w.GetWrittenBytes(), err) + metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err) + }() + + return nil, p, cancelFn, nil +} + +// Create creates or opens the named file for writing +func (fs *S3Fs) Create(name string, flag, checks int) (File, PipeWriter, func(), error) { + if checks&CheckParentDir != 0 { + _, err := fs.Stat(path.Dir(name)) + if err != nil { + return nil, nil, nil, err + } + } + r, w, err := createPipeFn(fs.localTempDir, fs.config.UploadPartSize+1024*1024) + if err != nil { + return nil, nil, nil, err + } + var p PipeWriter + if checks&CheckResume != 0 { + p = newPipeWriterAtOffset(w, 0) + } else { + p = NewPipeWriter(w) + } + ctx, cancelFn := context.WithCancel(context.Background()) + + go func() { + defer cancelFn() + + var contentType string + if flag == -1 { + contentType = s3DirMimeType + } else { + contentType = mime.TypeByExtension(path.Ext(name)) + } + err := fs.handleUpload(ctx, r, name, contentType) + r.CloseWithError(err) //nolint:errcheck + p.Done(err) + fsLog(fs, logger.LevelDebug, "upload completed, path: %q, acl: %q, readed bytes: %d, err: %+v", + name, fs.config.ACL, r.GetReadedBytes(), err) + metric.S3TransferCompleted(r.GetReadedBytes(), 0, err) + }() + + if checks&CheckResume != 0 { + readCh := make(chan error, 1) + + go func() { + n, err := fs.downloadToWriter(name, p) + pw := p.(*pipeWriterAtOffset) + pw.offset = 0 + pw.writeOffset = n + readCh <- err + }() + + err = <-readCh + if err != nil { + cancelFn() + p.Close() + fsLog(fs, logger.LevelDebug, "download before resume failed, writer closed and read cancelled") + return nil, nil, nil, err + } + } + + if uploadMode&4 != 0 { + return nil, p, nil, nil + } + return nil, p, cancelFn, nil +} + +// Rename renames (moves) source to target. 
+func (fs *S3Fs) Rename(source, target string, checks int) (int, int64, error) { + if source == target { + return -1, -1, nil + } + if checks&CheckParentDir != 0 { + _, err := fs.Stat(path.Dir(target)) + if err != nil { + return -1, -1, err + } + } + fi, err := fs.Stat(source) + if err != nil { + return -1, -1, err + } + return fs.renameInternal(source, target, fi, 0, checks&CheckUpdateModTime != 0) +} + +// Remove removes the named file or (empty) directory. +func (fs *S3Fs) Remove(name string, isDir bool) error { + if isDir { + hasContents, err := fs.hasContents(name) + if err != nil { + return err + } + if hasContents { + return fmt.Errorf("cannot remove non empty directory: %q", name) + } + if !strings.HasSuffix(name, "/") { + name += "/" + } + } + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + _, err := fs.svc.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: aws.String(fs.config.Bucket), + Key: aws.String(name), + }) + metric.S3DeleteObjectCompleted(err) + return err +} + +// Mkdir creates a new directory with the specified name and default permissions +func (fs *S3Fs) Mkdir(name string) error { + _, err := fs.Stat(name) + if !fs.IsNotExist(err) { + return err + } + return fs.mkdirInternal(name) +} + +// Symlink creates source as a symbolic link to target. +func (*S3Fs) Symlink(_, _ string) error { + return ErrVfsUnsupported +} + +// Readlink returns the destination of the named symbolic link +func (*S3Fs) Readlink(_ string) (string, error) { + return "", ErrVfsUnsupported +} + +// Chown changes the numeric uid and gid of the named file. +func (*S3Fs) Chown(_ string, _ int, _ int) error { + return ErrVfsUnsupported +} + +// Chmod changes the mode of the named file to mode. +func (*S3Fs) Chmod(_ string, _ os.FileMode) error { + return ErrVfsUnsupported +} + +// Chtimes changes the access and modification times of the named file. 
+func (fs *S3Fs) Chtimes(_ string, _, _ time.Time, _ bool) error { + return ErrVfsUnsupported +} + +// Truncate changes the size of the named file. +// Truncate by path is not supported, while truncating an opened +// file is handled inside base transfer +func (*S3Fs) Truncate(_ string, _ int64) error { + return ErrVfsUnsupported +} + +// ReadDir reads the directory named by dirname and returns +// a list of directory entries. +func (fs *S3Fs) ReadDir(dirname string) (DirLister, error) { + // dirname must be already cleaned + prefix := fs.getPrefix(dirname) + paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ + Bucket: aws.String(fs.config.Bucket), + Prefix: aws.String(prefix), + Delimiter: aws.String("/"), + MaxKeys: &s3DefaultPageSize, + }) + + return &s3DirLister{ + paginator: paginator, + timeout: fs.ctxTimeout, + prefix: prefix, + prefixes: make(map[string]bool), + }, nil +} + +// IsUploadResumeSupported returns true if resuming uploads is supported. +// Resuming uploads is not supported on S3 +func (*S3Fs) IsUploadResumeSupported() bool { + return false +} + +// IsConditionalUploadResumeSupported returns if resuming uploads is supported +// for the specified size +func (*S3Fs) IsConditionalUploadResumeSupported(size int64) bool { + return size <= resumeMaxSize +} + +// IsAtomicUploadSupported returns true if atomic upload is supported. 
+// S3 uploads are already atomic, we don't need to upload to a temporary +// file +func (*S3Fs) IsAtomicUploadSupported() bool { + return false +} + +// IsNotExist returns a boolean indicating whether the error is known to +// report that a file or directory does not exist +func (*S3Fs) IsNotExist(err error) bool { + if err == nil { + return false + } + + var re *awshttp.ResponseError + if errors.As(err, &re) { + if re.Response != nil { + return re.Response.StatusCode == http.StatusNotFound + } + } + return false +} + +// IsPermission returns a boolean indicating whether the error is known to +// report that permission is denied. +func (*S3Fs) IsPermission(err error) bool { + if err == nil { + return false + } + + var re *awshttp.ResponseError + if errors.As(err, &re) { + if re.Response != nil { + return re.Response.StatusCode == http.StatusForbidden || + re.Response.StatusCode == http.StatusUnauthorized + } + } + return false +} + +// IsNotSupported returns true if the error indicate an unsupported operation +func (*S3Fs) IsNotSupported(err error) bool { + if err == nil { + return false + } + return errors.Is(err, ErrVfsUnsupported) +} + +// CheckRootPath creates the specified local root directory if it does not exists +func (fs *S3Fs) CheckRootPath(username string, uid int, gid int) bool { + // we need a local directory for temporary files + osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil) + return osFs.CheckRootPath(username, uid, gid) +} + +// ScanRootDirContents returns the number of files contained in the bucket, +// and their size +func (fs *S3Fs) ScanRootDirContents() (int, int64, error) { + return fs.GetDirSize(fs.config.KeyPrefix) +} + +// GetDirSize returns the number of files and the size for a folder +// including any subfolders +func (fs *S3Fs) GetDirSize(dirname string) (int, int64, error) { + prefix := fs.getPrefix(dirname) + numFiles := 0 + size := int64(0) + + paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ + 
Bucket: aws.String(fs.config.Bucket), + Prefix: aws.String(prefix), + MaxKeys: &s3DefaultPageSize, + }) + + for paginator.HasMorePages() { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + page, err := paginator.NextPage(ctx) + if err != nil { + metric.S3ListObjectsCompleted(err) + return numFiles, size, err + } + for _, fileObject := range page.Contents { + isDir := strings.HasSuffix(util.GetStringFromPointer(fileObject.Key), "/") + objectSize := util.GetIntFromPointer(fileObject.Size) + if isDir && objectSize == 0 { + continue + } + numFiles++ + size += objectSize + } + fsLog(fs, logger.LevelDebug, "scan in progress for %q, files: %d, size: %d", dirname, numFiles, size) + } + + metric.S3ListObjectsCompleted(nil) + return numFiles, size, nil +} + +// GetAtomicUploadPath returns the path to use for an atomic upload. +// S3 uploads are already atomic, we never call this method for S3 +func (*S3Fs) GetAtomicUploadPath(_ string) string { + return "" +} + +// GetRelativePath returns the path for a file relative to the user's home dir. +// This is the path as seen by SFTPGo users +func (fs *S3Fs) GetRelativePath(name string) string { + rel := path.Clean(name) + if rel == "." { + rel = "" + } + if !path.IsAbs(rel) { + rel = "/" + rel + } + if fs.config.KeyPrefix != "" { + if !strings.HasPrefix(rel, "/"+fs.config.KeyPrefix) { + rel = "/" + } + rel = path.Clean("/" + strings.TrimPrefix(rel, "/"+fs.config.KeyPrefix)) + } + if fs.mountPath != "" { + rel = path.Join(fs.mountPath, rel) + } + return rel +} + +// Walk walks the file tree rooted at root, calling walkFn for each file or +// directory in the tree, including root. 
The result are unordered +func (fs *S3Fs) Walk(root string, walkFn filepath.WalkFunc) error { + prefix := fs.getPrefix(root) + + paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{ + Bucket: aws.String(fs.config.Bucket), + Prefix: aws.String(prefix), + MaxKeys: &s3DefaultPageSize, + }) + + for paginator.HasMorePages() { + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + page, err := paginator.NextPage(ctx) + if err != nil { + metric.S3ListObjectsCompleted(err) + walkFn(root, NewFileInfo(root, true, 0, time.Unix(0, 0), false), err) //nolint:errcheck + return err + } + for _, fileObject := range page.Contents { + name, isDir := fs.resolve(fileObject.Key, prefix) + if name == "" { + continue + } + err := walkFn(util.GetStringFromPointer(fileObject.Key), + NewFileInfo(name, isDir, util.GetIntFromPointer(fileObject.Size), + util.GetTimeFromPointer(fileObject.LastModified), false), nil) + if err != nil { + return err + } + } + } + + metric.S3ListObjectsCompleted(nil) + walkFn(root, NewFileInfo(root, true, 0, time.Unix(0, 0), false), nil) //nolint:errcheck + return nil +} + +// Join joins any number of path elements into a single path +func (*S3Fs) Join(elem ...string) string { + return strings.TrimPrefix(path.Join(elem...), "/") +} + +// HasVirtualFolders returns true if folders are emulated +func (*S3Fs) HasVirtualFolders() bool { + return true +} + +// ResolvePath returns the matching filesystem path for the specified virtual path +func (fs *S3Fs) ResolvePath(virtualPath string) (string, error) { + if fs.mountPath != "" { + if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { + virtualPath = after + } + } + virtualPath = path.Clean("/" + virtualPath) + return fs.Join(fs.config.KeyPrefix, strings.TrimPrefix(virtualPath, "/")), nil +} + +// CopyFile implements the FsFileCopier interface +func (fs *S3Fs) CopyFile(source, target string, srcInfo os.FileInfo) (int, 
int64, error) { + numFiles := 1 + sizeDiff := srcInfo.Size() + attrs, err := fs.headObject(target) + if err == nil { + sizeDiff -= util.GetIntFromPointer(attrs.ContentLength) + numFiles = 0 + } else { + if !fs.IsNotExist(err) { + return 0, 0, err + } + } + if err := fs.copyFileInternal(source, target, srcInfo); err != nil { + return 0, 0, err + } + return numFiles, sizeDiff, nil +} + +func (fs *S3Fs) resolve(name *string, prefix string) (string, bool) { + result := strings.TrimPrefix(util.GetStringFromPointer(name), prefix) + isDir := strings.HasSuffix(result, "/") + if isDir { + result = strings.TrimSuffix(result, "/") + } + return result, isDir +} + +func (fs *S3Fs) setConfigDefaults() { + const defaultPartSize = 1024 * 1024 * 5 + const defaultConcurrency = 5 + + if fs.config.UploadPartSize == 0 { + fs.config.UploadPartSize = defaultPartSize + } else { + if fs.config.UploadPartSize < 1024*1024 { + fs.config.UploadPartSize *= 1024 * 1024 + } + } + if fs.config.UploadConcurrency == 0 { + fs.config.UploadConcurrency = defaultConcurrency + } + if fs.config.DownloadPartSize == 0 { + fs.config.DownloadPartSize = defaultPartSize + } else { + if fs.config.DownloadPartSize < 1024*1024 { + fs.config.DownloadPartSize *= 1024 * 1024 + } + } + if fs.config.DownloadConcurrency == 0 { + fs.config.DownloadConcurrency = defaultConcurrency + } +} + +func (fs *S3Fs) copyFileInternal(source, target string, srcInfo os.FileInfo) error { + contentType := mime.TypeByExtension(path.Ext(source)) + copySource := pathEscape(fs.Join(fs.config.Bucket, source)) + + if srcInfo.Size() > s3CopyObjectThreshold { + fsLog(fs, logger.LevelDebug, "renaming file %q with size %d using multipart copy", + source, srcInfo.Size()) + err := fs.doMultipartCopy(copySource, target, contentType, srcInfo.Size()) + metric.S3CopyObjectCompleted(err) + return err + } + ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout)) + defer cancelFn() + + copyObject := &s3.CopyObjectInput{ 
		Bucket:                         aws.String(fs.config.Bucket),
		CopySource:                     aws.String(copySource),
		Key:                            aws.String(target),
		StorageClass:                   types.StorageClass(fs.config.StorageClass),
		ACL:                            types.ObjectCannedACL(fs.config.ACL),
		ContentType:                    util.NilIfEmpty(contentType),
		CopySourceSSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		CopySourceSSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		CopySourceSSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
		SSECustomerKey:                 util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm:           util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:              util.NilIfEmpty(fs.sseCustomerKeyMD5),
	}

	_, err := fs.svc.CopyObject(ctx, copyObject)

	metric.S3CopyObjectCompleted(err)
	return err
}

// renameInternal implements rename as copy+delete (S3 has no native rename).
// For directories: with renameMode 0 non-empty directories are rejected,
// with renameMode 1 the rename recurses into the directory contents.
// It returns the number of renamed files and their total size.
func (fs *S3Fs) renameInternal(source, target string, srcInfo os.FileInfo, recursion int,
	updateModTime bool,
) (int, int64, error) {
	var numFiles int
	var filesSize int64

	if srcInfo.IsDir() {
		if renameMode == 0 {
			hasContents, err := fs.hasContents(source)
			if err != nil {
				return numFiles, filesSize, err
			}
			if hasContents {
				return numFiles, filesSize, fmt.Errorf("%w: cannot rename non empty directory: %q", ErrVfsUnsupported, source)
			}
		}
		if err := fs.mkdirInternal(target); err != nil {
			return numFiles, filesSize, err
		}
		if renameMode == 1 {
			files, size, err := doRecursiveRename(fs, source, target, fs.renameInternal, recursion, updateModTime)
			numFiles += files
			filesSize += size
			if err != nil {
				return numFiles, filesSize, err
			}
		}
	} else {
		if err := fs.copyFileInternal(source, target, srcInfo); err != nil {
			return numFiles, filesSize, err
		}
		numFiles++
		filesSize += srcInfo.Size()
	}
	// tolerate a source that disappeared in the meantime: the copy already
	// succeeded, so treat a missing source as success
	err := fs.Remove(source, srcInfo.IsDir())
	if fs.IsNotExist(err) {
		err = nil
	}
	return numFiles, filesSize, err
}

// mkdirInternal creates the zero-byte directory marker object "name/".
func (fs *S3Fs) mkdirInternal(name string) error {
	if !strings.HasSuffix(name, "/") {
		name += "/"
	}
	// flag -1 selects the directory mime type, see Create
	_, w, _, err := fs.Create(name, -1, 0)
	if err != nil {
		return err
	}
	return w.Close()
}

// hasContents returns true if the named "directory" contains at least one
// object other than the directory marker for the prefix itself.
func (fs *S3Fs) hasContents(name string) (bool, error) {
	prefix := fs.getPrefix(name)
	// two keys are enough: the marker for the prefix itself plus one entry
	maxKeys := int32(2)
	paginator := s3.NewListObjectsV2Paginator(fs.svc, &s3.ListObjectsV2Input{
		Bucket:  aws.String(fs.config.Bucket),
		Prefix:  aws.String(prefix),
		MaxKeys: &maxKeys,
	})

	if paginator.HasMorePages() {
		ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
		defer cancelFn()

		page, err := paginator.NextPage(ctx)
		metric.S3ListObjectsCompleted(err)
		if err != nil {
			return false, err
		}

		for _, obj := range page.Contents {
			name, _ := fs.resolve(obj.Key, prefix)
			if name == "" || name == "/" {
				continue
			}
			return true, nil
		}
		return false, nil
	}

	metric.S3ListObjectsCompleted(nil)
	return false, nil
}

// downloadPart fetches the byte range [start, start+count) of the object
// with a ranged GET and writes it to w at writeOffset.
func (fs *S3Fs) downloadPart(ctx context.Context, name string, buf []byte, w io.WriterAt, start, count, writeOffset int64) error {
	if count == 0 {
		return nil
	}
	rangeHeader := fmt.Sprintf("bytes=%d-%d", start, start+count-1)

	resp, err := fs.svc.GetObject(ctx, &s3.GetObjectInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		Range:                &rangeHeader,
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// the server must return exactly the requested range
	_, err = io.ReadAtLeast(resp.Body, buf, int(count))
	if err != nil {
		return err
	}

	return writeAtFull(w, buf, writeOffset, int(count))
}

// handleDownload downloads the object in concurrent parts starting at
// offset, using attrs (a prior HEAD response) for the object size.
func (fs *S3Fs) handleDownload(ctx context.Context, name string, offset int64, writer io.WriterAt, attrs *s3.HeadObjectOutput) error {
	contentLength := util.GetIntFromPointer(attrs.ContentLength)
	sizeToDownload := contentLength - offset
	if sizeToDownload < 0 {
		fsLog(fs, logger.LevelError, "invalid multipart download size or offset, size: %d, offset: %d, size to download: %d",
			contentLength,
			offset, sizeToDownload)
		return errors.New("the requested offset exceeds the file size")
	}
	if sizeToDownload == 0 {
		fsLog(fs, logger.LevelDebug, "nothing to download, offset %d, content length %d", offset, contentLength)
		return nil
	}
	partSize := fs.config.DownloadPartSize
	// guard limits the number of concurrently running part downloads
	guard := make(chan struct{}, fs.config.DownloadConcurrency)
	var blockCtxTimeout time.Duration
	if fs.config.DownloadPartMaxTime > 0 {
		blockCtxTimeout = time.Duration(fs.config.DownloadPartMaxTime) * time.Second
	} else {
		// default per-part timeout: one minute per MB of part size
		blockCtxTimeout = time.Duration(fs.config.DownloadPartSize/(1024*1024)) * time.Minute
	}
	pool := newBufferAllocator(int(partSize))
	defer pool.free()

	finished := false
	var wg sync.WaitGroup
	var errOnce sync.Once
	var hasError atomic.Bool
	// poolError is written once (inside errOnce) and read after wg.Wait,
	// so no extra synchronization is needed
	var poolError error

	poolCtx, poolCancel := context.WithCancel(ctx)
	defer poolCancel()

	for part := 0; !finished; part++ {
		start := offset
		end := offset + partSize
		if end >= contentLength {
			end = contentLength
			finished = true
		}
		writeOffset := int64(part) * partSize
		offset = end

		// acquire a concurrency slot before spawning the worker
		guard <- struct{}{}
		if hasError.Load() {
			fsLog(fs, logger.LevelDebug, "pool error, download for part %d not started", part)
			break
		}

		buf := pool.getBuffer()
		wg.Add(1)
		go func(start, end, writeOffset int64, buf []byte) {
			defer func() {
				pool.releaseBuffer(buf)
				<-guard
				wg.Done()
			}()

			innerCtx, cancelFn := context.WithDeadline(poolCtx, time.Now().Add(blockCtxTimeout))
			defer cancelFn()

			err := fs.downloadPart(innerCtx, name, buf, writer, start, end-start, writeOffset)
			if err != nil {
				// record only the first error and cancel the remaining parts
				errOnce.Do(func() {
					fsLog(fs, logger.LevelError, "multipart download error: %+v", err)
					hasError.Store(true)
					poolError = fmt.Errorf("multipart download error: %w", err)
					poolCancel()
				})
			}
		}(start, end, writeOffset, buf)
	}

	wg.Wait()
	close(guard)

	return poolError
}

// initiateMultipartUpload starts a multipart upload and returns its upload ID.
func (fs *S3Fs) initiateMultipartUpload(ctx context.Context, name, contentType string) (string, error) {
	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	res, err := fs.svc.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		StorageClass:         types.StorageClass(fs.config.StorageClass),
		ACL:                  types.ObjectCannedACL(fs.config.ACL),
		ContentType:          util.NilIfEmpty(contentType),
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	if err != nil {
		return "", fmt.Errorf("unable to create multipart upload request: %w", err)
	}
	uploadID := util.GetStringFromPointer(res.UploadId)
	if uploadID == "" {
		return "", errors.New("unable to get multipart upload ID")
	}
	return uploadID, nil
}

// uploadPart uploads one part of a multipart upload and returns its ETag.
// The per-part deadline scales with the configured part size (one minute
// per MB) unless UploadPartMaxTime overrides it.
func (fs *S3Fs) uploadPart(ctx context.Context, name, uploadID string, partNumber int32, data []byte) (*string, error) {
	timeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute
	if fs.config.UploadPartMaxTime > 0 {
		timeout = time.Duration(fs.config.UploadPartMaxTime) * time.Second
	}
	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(timeout))
	defer cancelFn()

	resp, err := fs.svc.UploadPart(ctx, &s3.UploadPartInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		PartNumber:           &partNumber,
		UploadId:             aws.String(uploadID),
		Body:                 bytes.NewReader(data),
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	if err != nil {
		return nil, fmt.Errorf("unable to upload part number %d: %w", partNumber, err)
	}
	return resp.ETag, nil
}

// completeMultipartUpload finalizes a multipart upload from its parts.
func (fs *S3Fs) completeMultipartUpload(ctx context.Context, name, uploadID string, completedParts []types.CompletedPart) error {
	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
	defer cancelFn()
	_, err := fs.svc.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
		Bucket:   aws.String(fs.config.Bucket),
		Key:      aws.String(name),
		UploadId: aws.String(uploadID),
		MultipartUpload: &types.CompletedMultipartUpload{
			Parts: completedParts,
		},
	})
	return err
}

// abortMultipartUpload cancels a pending multipart upload so S3 discards
// the parts uploaded so far.
func (fs *S3Fs) abortMultipartUpload(name, uploadID string) error {
	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	_, err := fs.svc.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
		Bucket:   aws.String(fs.config.Bucket),
		Key:      aws.String(name),
		UploadId: aws.String(uploadID),
	})
	return err
}

// singlePartUpload stores data using a single PutObject; used when the
// whole payload fits in one upload part.
func (fs *S3Fs) singlePartUpload(ctx context.Context, name, contentType string, data []byte) error {
	timeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute
	if fs.config.UploadPartMaxTime > 0 {
		timeout = time.Duration(fs.config.UploadPartMaxTime) * time.Second
	}
	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(timeout))
	defer cancelFn()

	contentLength := int64(len(data))
	_, err := fs.svc.PutObject(ctx, &s3.PutObjectInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		ACL:                  types.ObjectCannedACL(fs.config.ACL),
		Body:                 bytes.NewReader(data),
		ContentType:          util.NilIfEmpty(contentType),
		ContentLength:        &contentLength,
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
		StorageClass:         types.StorageClass(fs.config.StorageClass),
	})
	return err
}

// handleUpload streams reader to the named object: a single PutObject when
// the data fits in one part, a concurrent multipart upload otherwise.
func (fs *S3Fs) handleUpload(ctx context.Context, reader io.Reader, name, contentType string) error {
	pool := newBufferAllocator(int(fs.config.UploadPartSize))
	defer pool.free()

	firstBuf := pool.getBuffer()
	firstReadSize, err := readFill(reader, firstBuf)
	if err == io.EOF {
		// everything fits in a single part
		return fs.singlePartUpload(ctx, name, contentType, firstBuf[:firstReadSize])
	}
	if err
	!= nil {
		return err
	}

	uploadID, err := fs.initiateMultipartUpload(ctx, name, contentType)
	if err != nil {
		return err
	}
	// guard limits the number of concurrently running part uploads
	guard := make(chan struct{}, fs.config.UploadConcurrency)
	finished := false
	var partMutex sync.Mutex
	var completedParts []types.CompletedPart
	var wg sync.WaitGroup
	var hasError atomic.Bool
	// poolErr is written once (inside errOnce) and read after wg.Wait
	var poolErr error
	var errOnce sync.Once
	var partNumber int32

	poolCtx, poolCancel := context.WithCancel(ctx)
	defer poolCancel()

	// finalizeFailedUpload records the first error, cancels the remaining
	// part uploads and aborts the multipart upload on the server
	finalizeFailedUpload := func(err error) {
		fsLog(fs, logger.LevelError, "finalize failed multipart upload after error: %v", err)
		hasError.Store(true)
		poolErr = err
		poolCancel()
		if abortErr := fs.abortMultipartUpload(name, uploadID); abortErr != nil {
			fsLog(fs, logger.LevelError, "unable to abort multipart upload: %+v", abortErr)
		}
	}

	// uploadPart runs in its own goroutine; it releases the buffer and the
	// concurrency slot when done
	uploadPart := func(partNum int32, buf []byte, bytesRead int) {
		defer func() {
			pool.releaseBuffer(buf)
			<-guard
			wg.Done()
		}()

		etag, err := fs.uploadPart(poolCtx, name, uploadID, partNum, buf[:bytesRead])
		if err != nil {
			errOnce.Do(func() {
				finalizeFailedUpload(err)
			})
			return
		}
		partMutex.Lock()
		completedParts = append(completedParts, types.CompletedPart{
			PartNumber: &partNum,
			ETag:       etag,
		})
		partMutex.Unlock()
	}

	// the first part was already read above to probe for the single-part case
	partNumber = 1
	guard <- struct{}{}

	wg.Add(1)
	go uploadPart(partNumber, firstBuf, firstReadSize)

	for partNumber = 2; !finished; partNumber++ {
		buf := pool.getBuffer()

		n, err := readFill(reader, buf)
		if err == io.EOF {
			if n == 0 {
				// no trailing data: the previous part was the last one
				pool.releaseBuffer(buf)
				break
			}
			finished = true
		} else if err != nil {
			pool.releaseBuffer(buf)
			errOnce.Do(func() {
				finalizeFailedUpload(err)
			})
			break
		}
		guard <- struct{}{}
		if hasError.Load() {
			fsLog(fs, logger.LevelError, "pool error, upload for part %d not started", partNumber)
			pool.releaseBuffer(buf)
			break
		}

		wg.Add(1)
		go uploadPart(partNumber, buf, n)
	}

	wg.Wait()
	close(guard)

	if poolErr != nil {
		return
	poolErr
	}

	// CompleteMultipartUpload requires the parts in ascending part-number
	// order; the concurrent workers may have appended them out of order
	sort.Slice(completedParts, func(i, j int) bool {
		getPartNumber := func(number *int32) int32 {
			if number == nil {
				return 0
			}
			return *number
		}

		return getPartNumber(completedParts[i].PartNumber) < getPartNumber(completedParts[j].PartNumber)
	})

	return fs.completeMultipartUpload(ctx, name, uploadID, completedParts)
}

// doMultipartCopy copies a large object server-side using UploadPartCopy
// with up to 10 concurrent part copies.
func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int64) error {
	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	res, err := fs.svc.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(target),
		StorageClass:         types.StorageClass(fs.config.StorageClass),
		ACL:                  types.ObjectCannedACL(fs.config.ACL),
		ContentType:          util.NilIfEmpty(contentType),
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	if err != nil {
		return fmt.Errorf("unable to create multipart copy request: %w", err)
	}
	uploadID := util.GetStringFromPointer(res.UploadId)
	if uploadID == "" {
		return errors.New("unable to get multipart copy upload ID")
	}
	// We use 32 MB part size and copy 10 parts in parallel.
	// These values are arbitrary.
	// We don't want to start too many goroutines.
	maxPartSize := int64(32 * 1024 * 1024)
	if fileSize > int64(100*1024*1024*1024) {
		// very large source: bigger parts keep the part count low
		maxPartSize = int64(500 * 1024 * 1024)
	}
	// guard limits the number of concurrent part copies to 10
	guard := make(chan struct{}, 10)
	finished := false
	var completedParts []types.CompletedPart
	var partMutex sync.Mutex
	var wg sync.WaitGroup
	var hasError atomic.Bool
	var errOnce sync.Once
	// copyError is written once (inside errOnce) and read after wg.Wait
	var copyError error
	var partNumber int32
	var offset int64

	opCtx, opCancel := context.WithCancel(context.Background())
	defer opCancel()

	for partNumber = 1; !finished; partNumber++ {
		start := offset
		end := offset + maxPartSize
		if end >= fileSize {
			end = fileSize
			finished = true
		}
		offset = end

		// acquire a concurrency slot before spawning the worker
		guard <- struct{}{}
		if hasError.Load() {
			fsLog(fs, logger.LevelDebug, "previous multipart copy error, copy for part %d not started", partNumber)
			break
		}

		wg.Add(1)
		go func(partNum int32, partStart, partEnd int64) {
			defer func() {
				<-guard
				wg.Done()
			}()

			innerCtx, innerCancelFn := context.WithDeadline(opCtx, time.Now().Add(fs.ctxTimeout))
			defer innerCancelFn()

			partResp, err := fs.svc.UploadPartCopy(innerCtx, &s3.UploadPartCopyInput{
				Bucket:                         aws.String(fs.config.Bucket),
				CopySource:                     aws.String(source),
				Key:                            aws.String(target),
				PartNumber:                     &partNum,
				UploadId:                       aws.String(uploadID),
				CopySourceRange:                aws.String(fmt.Sprintf("bytes=%d-%d", partStart, partEnd-1)),
				CopySourceSSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
				CopySourceSSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
				CopySourceSSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
				SSECustomerKey:                 util.NilIfEmpty(fs.sseCustomerKey),
				SSECustomerAlgorithm:           util.NilIfEmpty(fs.sseCustomerAlgo),
				SSECustomerKeyMD5:              util.NilIfEmpty(fs.sseCustomerKeyMD5),
			})
			if err != nil {
				// record only the first error, cancel the remaining copies
				// and abort the multipart upload on the server
				errOnce.Do(func() {
					fsLog(fs, logger.LevelError, "unable to copy part number %d: %+v", partNum, err)
					hasError.Store(true)
					copyError = fmt.Errorf("error copying part number %d: %w",
						partNum, err)
					opCancel()

					if errAbort := fs.abortMultipartUpload(target, uploadID); errAbort != nil {
						fsLog(fs, logger.LevelError, "unable to abort multipart copy: %+v", errAbort)
					}
				})
				return
			}

			partMutex.Lock()
			completedParts = append(completedParts, types.CompletedPart{
				ETag:       partResp.CopyPartResult.ETag,
				PartNumber: &partNum,
			})
			partMutex.Unlock()
		}(partNumber, start, end)
	}

	wg.Wait()
	close(guard)

	if copyError != nil {
		return copyError
	}
	// CompleteMultipartUpload requires the parts in ascending part-number
	// order; the concurrent workers may have appended them out of order
	sort.Slice(completedParts, func(i, j int) bool {
		getPartNumber := func(number *int32) int32 {
			if number == nil {
				return 0
			}
			return *number
		}

		return getPartNumber(completedParts[i].PartNumber) < getPartNumber(completedParts[j].PartNumber)
	})

	completeCtx, completeCancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer completeCancelFn()

	_, err = fs.svc.CompleteMultipartUpload(completeCtx, &s3.CompleteMultipartUploadInput{
		Bucket:   aws.String(fs.config.Bucket),
		Key:      aws.String(target),
		UploadId: aws.String(uploadID),
		MultipartUpload: &types.CompletedMultipartUpload{
			Parts: completedParts,
		},
	})
	if err != nil {
		return fmt.Errorf("unable to complete multipart upload: %w", err)
	}
	return nil
}

// getPrefix converts a virtual directory name to the corresponding object
// key prefix: no leading slash, trailing slash added; the root maps to "".
func (fs *S3Fs) getPrefix(name string) string {
	prefix := ""
	if name != "" && name != "."
	&& name != "/" {
		prefix = strings.TrimPrefix(name, "/")
		if !strings.HasSuffix(prefix, "/") {
			prefix += "/"
		}
	}
	return prefix
}

// headObject issues a HEAD request for the given object key and records
// the result in the metrics.
func (fs *S3Fs) headObject(name string) (*s3.HeadObjectOutput, error) {
	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
	defer cancelFn()

	obj, err := fs.svc.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket:               aws.String(fs.config.Bucket),
		Key:                  aws.String(name),
		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
	})
	metric.S3HeadObjectCompleted(err)
	return obj, err
}

// GetMimeType returns the content type
func (fs *S3Fs) GetMimeType(name string) (string, error) {
	obj, err := fs.headObject(name)
	if err != nil {
		return "", err
	}
	return util.GetStringFromPointer(obj.ContentType), nil
}

// Close closes the fs
func (*S3Fs) Close() error {
	return nil
}

// GetAvailableDiskSize returns the available size for the specified path
func (*S3Fs) GetAvailableDiskSize(_ string) (*sftp.StatVFS, error) {
	return nil, ErrStorageSizeUnavailable
}

// downloadToWriter downloads the whole object into w; used to pre-populate
// the pipe before resuming an upload, see Create.
func (fs *S3Fs) downloadToWriter(name string, w PipeWriter) (int64, error) {
	fsLog(fs, logger.LevelDebug, "starting download before resuming upload, path %q", name)
	attrs, err := fs.headObject(name)
	if err != nil {
		return 0, err
	}
	ctx, cancelFn := context.WithTimeout(context.Background(), preResumeTimeout)
	defer cancelFn()

	err = fs.handleDownload(ctx, name, 0, w, attrs)
	fsLog(fs, logger.LevelDebug, "download before resuming upload completed, path %q size: %d, err: %+v",
		name, w.GetWrittenBytes(), err)
	metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err)
	return w.GetWrittenBytes(), err
}

// s3DirLister implements DirLister on top of a ListObjectsV2 paginator.
type s3DirLister struct {
	baseDirLister
	paginator *s3.ListObjectsV2Paginator
	timeout   time.Duration
	prefix    string
	// prefixes tracks directory names already returned, to avoid duplicates
	// between CommonPrefixes and zero-byte directory marker objects
	prefixes      map[string]bool
	metricUpdated bool
}

func (l
	*s3DirLister) resolve(name *string) (string, bool) {
	// strip the listing prefix; a trailing slash marks a directory entry
	result := strings.TrimPrefix(util.GetStringFromPointer(name), l.prefix)
	isDir := strings.HasSuffix(result, "/")
	if isDir {
		result = strings.TrimSuffix(result, "/")
	}
	return result, isDir
}

// Next returns up to limit entries, fetching a new listing page when the
// internal cache runs low. io.EOF signals the end of the listing.
func (l *s3DirLister) Next(limit int) ([]os.FileInfo, error) {
	if limit <= 0 {
		return nil, errInvalidDirListerLimit
	}
	if len(l.cache) >= limit {
		return l.returnFromCache(limit), nil
	}
	if !l.paginator.HasMorePages() {
		if !l.metricUpdated {
			// report the completed listing to the metrics exactly once
			l.metricUpdated = true
			metric.S3ListObjectsCompleted(nil)
		}
		return l.returnFromCache(limit), io.EOF
	}
	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(l.timeout))
	defer cancelFn()

	page, err := l.paginator.NextPage(ctx)
	if err != nil {
		metric.S3ListObjectsCompleted(err)
		return l.cache, err
	}
	for _, p := range page.CommonPrefixes {
		// prefixes have a trailing slash
		name, _ := l.resolve(p.Prefix)
		if name == "" {
			continue
		}
		if _, ok := l.prefixes[name]; ok {
			continue
		}
		l.cache = append(l.cache, NewFileInfo(name, true, 0, time.Unix(0, 0), false))
		l.prefixes[name] = true
	}
	for _, fileObject := range page.Contents {
		objectModTime := util.GetTimeFromPointer(fileObject.LastModified)
		objectSize := util.GetIntFromPointer(fileObject.Size)
		name, isDir := l.resolve(fileObject.Key)
		if name == "" || name == "/" {
			continue
		}
		if isDir {
			if _, ok := l.prefixes[name]; ok {
				continue
			}
			l.prefixes[name] = true
		}

		// only zero-byte directory markers are reported as directories
		l.cache = append(l.cache, NewFileInfo(name, (isDir && objectSize == 0), objectSize, objectModTime, false))
	}
	return l.returnFromCache(limit), nil
}

// Close releases the lister resources.
func (l *s3DirLister) Close() error {
	return l.baseDirLister.Close()
}

// getAWSHTTPClient returns an HTTP client tuned for S3 transfers.
func getAWSHTTPClient(timeout int, idleConnectionTimeout time.Duration, skipTLSVerify bool) *awshttp.BuildableClient {
	c := awshttp.NewBuildableClient().
		WithDialerOptions(func(d *net.Dialer) {
			d.Timeout = 8 * time.Second
		}).
		WithTransportOptions(func(tr *http.Transport) {
			tr.IdleConnTimeout = idleConnectionTimeout
			tr.WriteBufferSize = s3TransferBufferSize
			tr.ReadBufferSize = s3TransferBufferSize
			if skipTLSVerify {
				if tr.TLSClientConfig != nil {
					tr.TLSClientConfig.InsecureSkipVerify = skipTLSVerify
				} else {
					tr.TLSClientConfig = &tls.Config{
						MinVersion:         awshttp.DefaultHTTPTransportTLSMinVersion,
						InsecureSkipVerify: skipTLSVerify,
					}
				}
			}
		})
	if timeout > 0 {
		c = c.WithTimeout(time.Duration(timeout) * time.Second)
	}
	return c
}

// ideally we should simply use url.PathEscape:
//
// https://github.com/awsdocs/aws-doc-sdk-examples/blob/master/go/example_code/s3/s3_copy_object.go#L65
//
// but this cause issue with some vendors, see #483, the code below is copied from rclone
func pathEscape(in string) string {
	var u url.URL
	u.Path = in
	return strings.ReplaceAll(u.String(), "+", "%2B")
}
diff --git a/internal/vfs/s3fs_disabled.go b/internal/vfs/s3fs_disabled.go
new file mode 100644
index 00000000..5c1f1b53
--- /dev/null
+++ b/internal/vfs/s3fs_disabled.go
@@ -0,0 +1,32 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .

//go:build nos3

package vfs

import (
	"errors"

	"github.com/drakkan/sftpgo/v2/internal/version"
)

// init registers the disabled feature flag so the version output reports
// that S3 support was excluded at build time.
func init() {
	version.AddFeature("-s3")
}

// NewS3Fs returns an error, S3 is disabled
func NewS3Fs(_, _, _ string, _ S3FsConfig) (Fs, error) {
	return nil, errors.New("S3 disabled at build time")
}
diff --git a/internal/vfs/sftpfs.go b/internal/vfs/sftpfs.go
new file mode 100644
index 00000000..2f263dd3
--- /dev/null
+++ b/internal/vfs/sftpfs.go
@@ -0,0 +1,1232 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+ +package vfs + +import ( + "bufio" + "bytes" + "crypto/rsa" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "io/fs" + "net" + "net/http" + "os" + "path" + "path/filepath" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pkg/sftp" + "github.com/robfig/cron/v3" + "github.com/rs/xid" + "github.com/sftpgo/sdk" + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" +) + +const ( + // sftpFsName is the name for the SFTP Fs implementation + sftpFsName = "sftpfs" + logSenderSFTPCache = "sftpCache" + maxSessionsPerConnection = 5 +) + +var ( + // ErrSFTPLoop defines the error to return if an SFTP loop is detected + ErrSFTPLoop = errors.New("SFTP loop or nested local SFTP folders detected") + sftpConnsCache = newSFTPConnectionCache() +) + +// SFTPFsConfig defines the configuration for SFTP based filesystem +type SFTPFsConfig struct { + sdk.BaseSFTPFsConfig + Password *kms.Secret `json:"password,omitempty"` + PrivateKey *kms.Secret `json:"private_key,omitempty"` + KeyPassphrase *kms.Secret `json:"key_passphrase,omitempty"` + forbiddenSelfUsernames []string `json:"-"` +} + +func (c *SFTPFsConfig) getKeySigner() (ssh.Signer, error) { + privPayload := c.PrivateKey.GetPayload() + if privPayload == "" { + return nil, nil + } + if key := c.KeyPassphrase.GetPayload(); key != "" { + return ssh.ParsePrivateKeyWithPassphrase([]byte(privPayload), []byte(key)) + } + return ssh.ParsePrivateKey([]byte(privPayload)) +} + +// HideConfidentialData hides confidential data +func (c *SFTPFsConfig) HideConfidentialData() { + if c.Password != nil { + c.Password.Hide() + } + if c.PrivateKey != nil { + c.PrivateKey.Hide() + } + if c.KeyPassphrase != nil { + c.KeyPassphrase.Hide() + } +} + +func (c *SFTPFsConfig) setNilSecretsIfEmpty() { + if c.Password != nil && c.Password.IsEmpty() { 
+ c.Password = nil + } + if c.PrivateKey != nil && c.PrivateKey.IsEmpty() { + c.PrivateKey = nil + } + if c.KeyPassphrase != nil && c.KeyPassphrase.IsEmpty() { + c.KeyPassphrase = nil + } +} + +func (c *SFTPFsConfig) isEqual(other SFTPFsConfig) bool { + if c.Endpoint != other.Endpoint { + return false + } + if c.Username != other.Username { + return false + } + if c.Prefix != other.Prefix { + return false + } + if c.DisableCouncurrentReads != other.DisableCouncurrentReads { + return false + } + if c.BufferSize != other.BufferSize { + return false + } + if len(c.Fingerprints) != len(other.Fingerprints) { + return false + } + for _, fp := range c.Fingerprints { + if !slices.Contains(other.Fingerprints, fp) { + return false + } + } + c.setEmptyCredentialsIfNil() + other.setEmptyCredentialsIfNil() + if !c.Password.IsEqual(other.Password) { + return false + } + if !c.KeyPassphrase.IsEqual(other.KeyPassphrase) { + return false + } + return c.PrivateKey.IsEqual(other.PrivateKey) +} + +func (c *SFTPFsConfig) setEmptyCredentialsIfNil() { + if c.Password == nil { + c.Password = kms.NewEmptySecret() + } + if c.PrivateKey == nil { + c.PrivateKey = kms.NewEmptySecret() + } + if c.KeyPassphrase == nil { + c.KeyPassphrase = kms.NewEmptySecret() + } +} + +func (c *SFTPFsConfig) isSameResource(other SFTPFsConfig) bool { + if c.EqualityCheckMode > 0 || other.EqualityCheckMode > 0 { + if c.Username != other.Username { + return false + } + } + return c.Endpoint == other.Endpoint +} + +// validate returns an error if the configuration is not valid +func (c *SFTPFsConfig) validate() error { + c.setEmptyCredentialsIfNil() + if c.Endpoint == "" { + return util.NewI18nError(errors.New("endpoint cannot be empty"), util.I18nErrorEndpointRequired) + } + if !strings.Contains(c.Endpoint, ":") { + c.Endpoint += ":22" + } + _, _, err := net.SplitHostPort(c.Endpoint) + if err != nil { + return util.NewI18nError(fmt.Errorf("invalid endpoint: %v", err), util.I18nErrorEndpointInvalid) + } + if 
c.Username == "" { + return util.NewI18nError(errors.New("username cannot be empty"), util.I18nErrorFsUsernameRequired) + } + if c.BufferSize < 0 || c.BufferSize > 16 { + return errors.New("invalid buffer_size, valid range is 0-16") + } + if !isEqualityCheckModeValid(c.EqualityCheckMode) { + return errors.New("invalid equality_check_mode") + } + if err := c.validateCredentials(); err != nil { + return err + } + if c.Prefix != "" { + c.Prefix = util.CleanPath(c.Prefix) + } else { + c.Prefix = "/" + } + return c.validatePrivateKey() +} + +func (c *SFTPFsConfig) validatePrivateKey() error { + if c.PrivateKey.IsPlain() { + signer, err := c.getKeySigner() + if err != nil { + return util.NewI18nError(fmt.Errorf("invalid private key: %w", err), util.I18nErrorPrivKeyInvalid) + } + if signer != nil { + if key, ok := signer.PublicKey().(ssh.CryptoPublicKey); ok { + cryptoKey := key.CryptoPublicKey() + if rsaKey, ok := cryptoKey.(*rsa.PublicKey); ok { + if size := rsaKey.N.BitLen(); size < 2048 { + return util.NewI18nError( + fmt.Errorf("rsa key with size %d not accepted, minimum 2048", size), + util.I18nErrorKeySizeInvalid, + ) + } + } + } + } + } + return nil +} + +func (c *SFTPFsConfig) validateCredentials() error { + if c.Password.IsEmpty() && c.PrivateKey.IsEmpty() { + return util.NewI18nError(errors.New("credentials cannot be empty"), util.I18nErrorFsCredentialsRequired) + } + if c.Password.IsEncrypted() && !c.Password.IsValid() { + return errors.New("invalid encrypted password") + } + if !c.Password.IsEmpty() && !c.Password.IsValidInput() { + return errors.New("invalid password") + } + if c.PrivateKey.IsEncrypted() && !c.PrivateKey.IsValid() { + return errors.New("invalid encrypted private key") + } + if !c.PrivateKey.IsEmpty() && !c.PrivateKey.IsValidInput() { + return errors.New("invalid private key") + } + if c.KeyPassphrase.IsEncrypted() && !c.KeyPassphrase.IsValid() { + return errors.New("invalid encrypted private key passphrase") + } + if 
!c.KeyPassphrase.IsEmpty() && !c.KeyPassphrase.IsValidInput() { + return errors.New("invalid private key passphrase") + } + return nil +} + +// ValidateAndEncryptCredentials validates the config and encrypts credentials if they are in plain text +func (c *SFTPFsConfig) ValidateAndEncryptCredentials(additionalData string) error { + if err := c.validate(); err != nil { + var errI18n *util.I18nError + errValidation := util.NewValidationError(fmt.Sprintf("could not validate SFTP fs config: %v", err)) + if errors.As(err, &errI18n) { + return util.NewI18nError(errValidation, errI18n.Message) + } + return util.NewI18nError(errValidation, util.I18nErrorFsValidation) + } + if c.Password.IsPlain() { + c.Password.SetAdditionalData(additionalData) + if err := c.Password.Encrypt(); err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("could not encrypt SFTP fs password: %v", err)), + util.I18nErrorFsValidation, + ) + } + } + if c.PrivateKey.IsPlain() { + c.PrivateKey.SetAdditionalData(additionalData) + if err := c.PrivateKey.Encrypt(); err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("could not encrypt SFTP fs private key: %v", err)), + util.I18nErrorFsValidation, + ) + } + } + if c.KeyPassphrase.IsPlain() { + c.KeyPassphrase.SetAdditionalData(additionalData) + if err := c.KeyPassphrase.Encrypt(); err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("could not encrypt SFTP fs private key passphrase: %v", err)), + util.I18nErrorFsValidation, + ) + } + } + return nil +} + +// getUniqueID returns an hash of the settings used to connect to the SFTP server +func (c *SFTPFsConfig) getUniqueID(partition int) string { + h := sha256.New() + var b bytes.Buffer + + b.WriteString(c.Endpoint) + b.WriteString(c.Username) + b.WriteString(strings.Join(c.Fingerprints, "")) + b.WriteString(strconv.FormatBool(c.DisableCouncurrentReads)) + b.WriteString(strconv.FormatInt(c.BufferSize, 10)) + 
b.WriteString(c.Password.GetPayload()) + b.WriteString(c.PrivateKey.GetPayload()) + b.WriteString(c.KeyPassphrase.GetPayload()) + if allowSelfConnections != 0 { + b.WriteString(strings.Join(c.forbiddenSelfUsernames, "")) + } + b.WriteString(strconv.Itoa(partition)) + + h.Write(b.Bytes()) + return hex.EncodeToString(h.Sum(nil)) +} + +// SFTPFs is a Fs implementation for SFTP backends +type SFTPFs struct { + connectionID string + // if not empty this fs is mouted as virtual folder in the specified path + mountPath string + localTempDir string + config *SFTPFsConfig + conn *sftpConnection +} + +// NewSFTPFs returns an SFTPFs object that allows to interact with an SFTP server +func NewSFTPFs(connectionID, mountPath, localTempDir string, forbiddenSelfUsernames []string, config SFTPFsConfig) (Fs, error) { + if localTempDir == "" { + localTempDir = getLocalTempDir() + } + if err := config.validate(); err != nil { + return nil, err + } + if !config.Password.IsEmpty() { + if err := config.Password.TryDecrypt(); err != nil { + return nil, err + } + } + if !config.PrivateKey.IsEmpty() { + if err := config.PrivateKey.TryDecrypt(); err != nil { + return nil, err + } + } + if !config.KeyPassphrase.IsEmpty() { + if err := config.KeyPassphrase.TryDecrypt(); err != nil { + return nil, err + } + } + conn, err := sftpConnsCache.Get(&config, connectionID) + if err != nil { + return nil, err + } + config.forbiddenSelfUsernames = forbiddenSelfUsernames + sftpFs := &SFTPFs{ + connectionID: connectionID, + mountPath: getMountPath(mountPath), + localTempDir: localTempDir, + config: &config, + conn: conn, + } + err = sftpFs.createConnection() + if err != nil { + sftpFs.Close() //nolint:errcheck + } + return sftpFs, err +} + +// Name returns the name for the Fs implementation +func (fs *SFTPFs) Name() string { + return fmt.Sprintf(`%s %q@%q`, sftpFsName, fs.config.Username, fs.config.Endpoint) +} + +// ConnectionID returns the connection ID associated to this Fs implementation +func (fs 
*SFTPFs) ConnectionID() string { + return fs.connectionID +} + +// Stat returns a FileInfo describing the named file +func (fs *SFTPFs) Stat(name string) (os.FileInfo, error) { + client, err := fs.conn.getClient() + if err != nil { + return nil, err + } + return client.Stat(name) +} + +// Lstat returns a FileInfo describing the named file +func (fs *SFTPFs) Lstat(name string) (os.FileInfo, error) { + client, err := fs.conn.getClient() + if err != nil { + return nil, err + } + return client.Lstat(name) +} + +// Open opens the named file for reading +func (fs *SFTPFs) Open(name string, offset int64) (File, PipeReader, func(), error) { + client, err := fs.conn.getClient() + if err != nil { + return nil, nil, nil, err + } + f, err := client.Open(name) + if err != nil { + return nil, nil, nil, err + } + if offset > 0 { + _, err = f.Seek(offset, io.SeekStart) + if err != nil { + f.Close() + return nil, nil, nil, err + } + } + if fs.config.BufferSize == 0 { + return f, nil, nil, nil + } + r, w, err := createPipeFn(fs.localTempDir, 0) + if err != nil { + f.Close() + return nil, nil, nil, err + } + p := NewPipeReader(r) + + go func() { + // if we enable buffering the client stalls + //br := bufio.NewReaderSize(f, int(fs.config.BufferSize)*1024*1024) + //n, err := fs.copy(w, br) + n, err := io.Copy(w, f) + w.CloseWithError(err) //nolint:errcheck + f.Close() + fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %v", name, n, err) + }() + + return nil, p, nil, nil +} + +// Create creates or opens the named file for writing +func (fs *SFTPFs) Create(name string, flag, _ int) (File, PipeWriter, func(), error) { + client, err := fs.conn.getClient() + if err != nil { + return nil, nil, nil, err + } + if fs.config.BufferSize == 0 { + var f File + if flag == 0 { + f, err = client.Create(name) + } else { + f, err = client.OpenFile(name, flag) + } + return f, nil, nil, err + } + // buffering is enabled + f, err := client.OpenFile(name, 
os.O_WRONLY|os.O_CREATE|os.O_TRUNC) + if err != nil { + return nil, nil, nil, err + } + r, w, err := createPipeFn(fs.localTempDir, 0) + if err != nil { + f.Close() + return nil, nil, nil, err + } + p := NewPipeWriter(w) + + go func() { + bw := bufio.NewWriterSize(f, int(fs.config.BufferSize)*1024*1024) + // we don't use io.Copy since bufio.Writer implements io.WriterTo and + // so it calls the sftp.File WriteTo method without buffering + n, err := doCopy(bw, r, nil) + errFlush := bw.Flush() + if err == nil && errFlush != nil { + err = errFlush + } + var errTruncate error + if err != nil { + errTruncate = f.Truncate(n) + } + errClose := f.Close() + if err == nil && errClose != nil { + err = errClose + } + r.CloseWithError(err) //nolint:errcheck + p.Done(err) + fsLog(fs, logger.LevelDebug, "upload completed, path: %q, readed bytes: %v, err: %v err truncate: %v", + name, n, err, errTruncate) + }() + + return nil, p, nil, nil +} + +// Rename renames (moves) source to target. +func (fs *SFTPFs) Rename(source, target string, checks int) (int, int64, error) { + if source == target { + return -1, -1, nil + } + client, err := fs.conn.getClient() + if err != nil { + return -1, -1, err + } + if _, ok := client.HasExtension("posix-rename@openssh.com"); ok { + err := client.PosixRename(source, target) + if checks&CheckUpdateModTime != 0 && err == nil { + fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck + } + return -1, -1, err + } + err = client.Rename(source, target) + if checks&CheckUpdateModTime != 0 && err == nil { + fs.Chtimes(target, time.Now(), time.Now(), false) //nolint:errcheck + } + return -1, -1, err +} + +// Remove removes the named file or (empty) directory. 
+func (fs *SFTPFs) Remove(name string, isDir bool) error { + client, err := fs.conn.getClient() + if err != nil { + return err + } + if isDir { + return client.RemoveDirectory(name) + } + return client.Remove(name) +} + +// Mkdir creates a new directory with the specified name and default permissions +func (fs *SFTPFs) Mkdir(name string) error { + client, err := fs.conn.getClient() + if err != nil { + return err + } + return client.Mkdir(name) +} + +// Symlink creates source as a symbolic link to target. +func (fs *SFTPFs) Symlink(source, target string) error { + client, err := fs.conn.getClient() + if err != nil { + return err + } + return client.Symlink(source, target) +} + +// Readlink returns the destination of the named symbolic link +func (fs *SFTPFs) Readlink(name string) (string, error) { + client, err := fs.conn.getClient() + if err != nil { + return "", err + } + resolved, err := client.ReadLink(name) + if err != nil { + return resolved, err + } + resolved = path.Clean(strings.ReplaceAll(resolved, "\\", "/")) + if !path.IsAbs(resolved) { + // we assume that multiple links are not followed + resolved = path.Join(path.Dir(name), resolved) + } + return fs.GetRelativePath(resolved), nil +} + +// Chown changes the numeric uid and gid of the named file. +func (fs *SFTPFs) Chown(name string, uid int, gid int) error { + client, err := fs.conn.getClient() + if err != nil { + return err + } + return client.Chown(name, uid, gid) +} + +// Chmod changes the mode of the named file to mode. +func (fs *SFTPFs) Chmod(name string, mode os.FileMode) error { + client, err := fs.conn.getClient() + if err != nil { + return err + } + return client.Chmod(name, mode) +} + +// Chtimes changes the access and modification times of the named file. 
+func (fs *SFTPFs) Chtimes(name string, atime, mtime time.Time, _ bool) error { + client, err := fs.conn.getClient() + if err != nil { + return err + } + return client.Chtimes(name, atime, mtime) +} + +// Truncate changes the size of the named file. +func (fs *SFTPFs) Truncate(name string, size int64) error { + client, err := fs.conn.getClient() + if err != nil { + return err + } + return client.Truncate(name, size) +} + +// ReadDir reads the directory named by dirname and returns +// a list of directory entries. +func (fs *SFTPFs) ReadDir(dirname string) (DirLister, error) { + client, err := fs.conn.getClient() + if err != nil { + return nil, err + } + files, err := client.ReadDir(dirname) + if err != nil { + return nil, err + } + return &baseDirLister{files}, nil +} + +// IsUploadResumeSupported returns true if resuming uploads is supported. +func (fs *SFTPFs) IsUploadResumeSupported() bool { + return fs.config.BufferSize == 0 +} + +// IsConditionalUploadResumeSupported returns if resuming uploads is supported +// for the specified size +func (fs *SFTPFs) IsConditionalUploadResumeSupported(_ int64) bool { + return fs.IsUploadResumeSupported() +} + +// IsAtomicUploadSupported returns true if atomic upload is supported. +func (fs *SFTPFs) IsAtomicUploadSupported() bool { + return fs.config.BufferSize == 0 +} + +// IsNotExist returns a boolean indicating whether the error is known to +// report that a file or directory does not exist +func (*SFTPFs) IsNotExist(err error) bool { + return errors.Is(err, fs.ErrNotExist) +} + +// IsPermission returns a boolean indicating whether the error is known to +// report that permission is denied. 
+func (*SFTPFs) IsPermission(err error) bool { + if _, ok := err.(*pathResolutionError); ok { + return true + } + return errors.Is(err, fs.ErrPermission) +} + +// IsNotSupported returns true if the error indicate an unsupported operation +func (*SFTPFs) IsNotSupported(err error) bool { + if err == nil { + return false + } + return err == ErrVfsUnsupported +} + +// CheckRootPath creates the specified local root directory if it does not exists +func (fs *SFTPFs) CheckRootPath(username string, uid int, gid int) bool { + // local directory for temporary files in buffer mode + osFs := NewOsFs(fs.ConnectionID(), fs.localTempDir, "", nil) + osFs.CheckRootPath(username, uid, gid) + if fs.config.Prefix == "/" { + return true + } + client, err := fs.conn.getClient() + if err != nil { + return false + } + if err := client.MkdirAll(fs.config.Prefix); err != nil { + fsLog(fs, logger.LevelDebug, "error creating root directory %q for user %q: %v", fs.config.Prefix, username, err) + return false + } + return true +} + +// ScanRootDirContents returns the number of files contained in a directory and +// their size +func (fs *SFTPFs) ScanRootDirContents() (int, int64, error) { + return fs.GetDirSize(fs.config.Prefix) +} + +// CheckMetadata checks the metadata consistency +func (*SFTPFs) CheckMetadata() error { + return nil +} + +// GetAtomicUploadPath returns the path to use for an atomic upload +func (*SFTPFs) GetAtomicUploadPath(name string) string { + dir := path.Dir(name) + guid := xid.New().String() + return path.Join(dir, ".sftpgo-upload."+guid+"."+path.Base(name)) +} + +// GetRelativePath returns the path for a file relative to the sftp prefix if any. +// This is the path as seen by SFTPGo users +func (fs *SFTPFs) GetRelativePath(name string) string { + rel := path.Clean(name) + if rel == "." 
{ + rel = "" + } + if !path.IsAbs(rel) { + // If we have a relative path we assume it is already relative to the virtual root + rel = "/" + rel + } else if fs.config.Prefix != "/" { + prefixDir := fs.config.Prefix + if !strings.HasSuffix(prefixDir, "/") { + prefixDir += "/" + } + + if rel == fs.config.Prefix { + rel = "/" + } else if after, found := strings.CutPrefix(rel, prefixDir); found { + rel = path.Clean("/" + after) + } else { + // Absolute path outside of the configured prefix + fsLog(fs, logger.LevelWarn, "path %q is an absolute path outside %q", name, fs.config.Prefix) + rel = "/" + } + } + if fs.mountPath != "" { + rel = path.Join(fs.mountPath, rel) + } + return rel +} + +// Walk walks the file tree rooted at root, calling walkFn for each file or +// directory in the tree, including root +func (fs *SFTPFs) Walk(root string, walkFn filepath.WalkFunc) error { + client, err := fs.conn.getClient() + if err != nil { + return err + } + walker := client.Walk(root) + for walker.Step() { + err := walker.Err() + if err != nil { + return err + } + err = walkFn(walker.Path(), walker.Stat(), err) + if err != nil { + return err + } + } + return nil +} + +// Join joins any number of path elements into a single path +func (*SFTPFs) Join(elem ...string) string { + return path.Join(elem...) 
+} + +// HasVirtualFolders returns true if folders are emulated +func (*SFTPFs) HasVirtualFolders() bool { + return false +} + +// ResolvePath returns the matching filesystem path for the specified virtual path +func (fs *SFTPFs) ResolvePath(virtualPath string) (string, error) { + if fs.mountPath != "" { + if after, found := strings.CutPrefix(virtualPath, fs.mountPath); found { + virtualPath = after + } + } + virtualPath = path.Clean("/" + virtualPath) + fsPath := fs.Join(fs.config.Prefix, virtualPath) + if fs.config.Prefix != "/" && fsPath != "/" { + // we need to check if this path is a symlink outside the given prefix + // or a file/dir inside a dir symlinked outside the prefix + var validatedPath string + var err error + validatedPath, err = fs.getRealPath(fsPath) + isNotExist := fs.IsNotExist(err) + if err != nil && !isNotExist { + fsLog(fs, logger.LevelError, "Invalid path resolution, original path %v resolved %q err: %v", + virtualPath, fsPath, err) + return "", err + } else if isNotExist { + for fs.IsNotExist(err) { + validatedPath = path.Dir(validatedPath) + if validatedPath == "/" { + err = nil + break + } + validatedPath, err = fs.getRealPath(validatedPath) + } + if err != nil { + fsLog(fs, logger.LevelError, "Invalid path resolution, dir %q original path %q resolved %q err: %v", + validatedPath, virtualPath, fsPath, err) + return "", err + } + } + if err := fs.isSubDir(validatedPath); err != nil { + fsLog(fs, logger.LevelError, "Invalid path resolution, dir %q original path %q resolved %q err: %v", + validatedPath, virtualPath, fsPath, err) + return "", err + } + } + return fsPath, nil +} + +// RealPath implements the FsRealPather interface +func (fs *SFTPFs) RealPath(p string) (string, error) { + client, err := fs.conn.getClient() + if err != nil { + return "", err + } + resolved, err := client.RealPath(p) + if err != nil { + return "", err + } + resolved = path.Clean(strings.ReplaceAll(resolved, "\\", "/")) + if fs.config.Prefix != "/" { + if err := 
fs.isSubDir(resolved); err != nil { + fsLog(fs, logger.LevelError, "Invalid real path resolution, original path %q resolved %q err: %v", + p, resolved, err) + return "", err + } + } + return fs.GetRelativePath(resolved), nil +} + +// getRealPath returns the real remote path trying to resolve symbolic links if any +func (fs *SFTPFs) getRealPath(name string) (string, error) { + client, err := fs.conn.getClient() + if err != nil { + return "", err + } + linksWalked := 0 + for { + info, err := client.Lstat(name) + if err != nil { + return name, err + } + if info.Mode()&os.ModeSymlink == 0 { + return name, nil + } + resolvedLink, err := client.ReadLink(name) + if err != nil { + return name, fmt.Errorf("unable to resolve link to %q: %w", name, err) + } + resolvedLink = strings.ReplaceAll(resolvedLink, "\\", "/") + resolvedLink = path.Clean(resolvedLink) + if path.IsAbs(resolvedLink) { + name = resolvedLink + } else { + name = path.Join(path.Dir(name), resolvedLink) + } + linksWalked++ + if linksWalked > 10 { + fsLog(fs, logger.LevelError, "unable to get real path, too many links: %d", linksWalked) + return "", &pathResolutionError{err: "too many links"} + } + } +} + +func (fs *SFTPFs) isSubDir(name string) error { + if name == fs.config.Prefix { + return nil + } + if len(name) < len(fs.config.Prefix) { + err := fmt.Errorf("path %q is not inside: %q", name, fs.config.Prefix) + return &pathResolutionError{err: err.Error()} + } + if !strings.HasPrefix(name, fs.config.Prefix+"/") { + err := fmt.Errorf("path %q is not inside: %q", name, fs.config.Prefix) + return &pathResolutionError{err: err.Error()} + } + return nil +} + +// GetDirSize returns the number of files and the size for a folder +// including any subfolders +func (fs *SFTPFs) GetDirSize(dirname string) (int, int64, error) { + numFiles := 0 + size := int64(0) + client, err := fs.conn.getClient() + if err != nil { + return numFiles, size, err + } + isDir, err := isDirectory(fs, dirname) + if err == nil && isDir { + 
walker := client.Walk(dirname) + for walker.Step() { + err := walker.Err() + if err != nil { + return numFiles, size, err + } + if walker.Stat().Mode().IsRegular() { + size += walker.Stat().Size() + numFiles++ + if numFiles%1000 == 0 { + fsLog(fs, logger.LevelDebug, "dirname %q scan in progress, files: %d, size: %d", dirname, numFiles, size) + } + } + } + } + return numFiles, size, err +} + +// GetMimeType returns the content type +func (fs *SFTPFs) GetMimeType(name string) (string, error) { + client, err := fs.conn.getClient() + if err != nil { + return "", err + } + f, err := client.OpenFile(name, os.O_RDONLY) + if err != nil { + return "", err + } + defer f.Close() + var buf [512]byte + n, err := io.ReadFull(f, buf[:]) + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + return "", err + } + ctype := http.DetectContentType(buf[:n]) + // Rewind file. + _, err = f.Seek(0, io.SeekStart) + return ctype, err +} + +// GetAvailableDiskSize returns the available size for the specified path +func (fs *SFTPFs) GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error) { + client, err := fs.conn.getClient() + if err != nil { + return nil, err + } + if _, ok := client.HasExtension("statvfs@openssh.com"); !ok { + return nil, ErrStorageSizeUnavailable + } + return client.StatVFS(dirName) +} + +// Close the connection +func (fs *SFTPFs) Close() error { + fs.conn.RemoveSession(fs.connectionID) + return nil +} + +func (fs *SFTPFs) createConnection() error { + err := fs.conn.OpenConnection() + if err != nil { + fsLog(fs, logger.LevelError, "error opening connection: %v", err) + return err + } + return nil +} + +type sftpConnection struct { + config *SFTPFsConfig + logSender string + sshClient *ssh.Client + sftpClient *sftp.Client + mu sync.RWMutex + isConnected bool + sessions map[string]bool + lastActivity time.Time + signer ssh.Signer +} + +func newSFTPConnection(config *SFTPFsConfig, sessionID string) *sftpConnection { + c := &sftpConnection{ + config: config, + 
// OpenConnection opens, or reuses, the underlying SSH/SFTP connection.
// It is safe for concurrent use: the connection mutex is held for the
// whole open attempt, the unlocked logic lives in openConnNoLock.
func (c *sftpConnection) OpenConnection() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	return c.openConnNoLock()
}
all available ciphers, KEXs and MACs, they are negotiated according to the order + clientConfig.Ciphers = append(supportedAlgos.Ciphers, ssh.InsecureCipherAES128CBC) + clientConfig.KeyExchanges = append(supportedAlgos.KeyExchanges, insecureAlgos.KeyExchanges...) + clientConfig.MACs = append(supportedAlgos.MACs, insecureAlgos.MACs...) + sshClient, err := ssh.Dial("tcp", c.config.Endpoint, clientConfig) + if err != nil { + return fmt.Errorf("sftpfs: unable to connect: %w", err) + } + sftpClient, err := sftp.NewClient(sshClient, c.getClientOptions()...) + if err != nil { + sshClient.Close() + return fmt.Errorf("sftpfs: unable to create SFTP client: %w", err) + } + c.sshClient = sshClient + c.sftpClient = sftpClient + c.isConnected = true + go c.Wait() + return nil +} + +func (c *sftpConnection) getClientOptions() []sftp.ClientOption { + var options []sftp.ClientOption + if c.config.DisableCouncurrentReads { + options = append(options, sftp.UseConcurrentReads(false)) + logger.Debug(c.logSender, "", "disabling concurrent reads") + } + if c.config.BufferSize > 0 { + options = append(options, sftp.UseConcurrentWrites(true)) + logger.Debug(c.logSender, "", "enabling concurrent writes") + } + return options +} + +func (c *sftpConnection) getClient() (*sftp.Client, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.isConnected { + return c.sftpClient, nil + } + err := c.openConnNoLock() + return c.sftpClient, err +} + +func (c *sftpConnection) Wait() { + done := make(chan struct{}) + + go func() { + var watchdogInProgress atomic.Bool + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if watchdogInProgress.Load() { + logger.Error(c.logSender, "", "watchdog still in progress, closing hanging connection") + c.sshClient.Close() + return + } + go func() { + watchdogInProgress.Store(true) + defer watchdogInProgress.Store(false) + + _, err := c.sftpClient.Getwd() + if err != nil { + logger.Error(c.logSender, "", "watchdog 
error: %v", err) + } + }() + case <-done: + logger.Debug(c.logSender, "", "quitting watchdog") + return + } + } + }() + + // we wait on the sftp client otherwise if the channel is closed but not the connection + // we don't detect the event. + err := c.sftpClient.Wait() + logger.Log(logger.LevelDebug, c.logSender, "", "sftp channel closed: %v", err) + close(done) + + c.mu.Lock() + defer c.mu.Unlock() + + c.isConnected = false + if c.sshClient != nil { + c.sshClient.Close() + } +} + +func (c *sftpConnection) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + logger.Debug(c.logSender, "", "closing connection") + var sftpErr, sshErr error + if c.sftpClient != nil { + sftpErr = c.sftpClient.Close() + } + if c.sshClient != nil { + sshErr = c.sshClient.Close() + } + if sftpErr != nil { + return sftpErr + } + c.isConnected = false + return sshErr +} + +func (c *sftpConnection) AddSession(sessionID string) { + c.mu.Lock() + defer c.mu.Unlock() + + c.sessions[sessionID] = true + logger.Debug(c.logSender, "", "added session %s, active sessions: %d", sessionID, len(c.sessions)) +} + +func (c *sftpConnection) RemoveSession(sessionID string) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.sessions, sessionID) + logger.Debug(c.logSender, "", "removed session %s, active sessions: %d", sessionID, len(c.sessions)) + if len(c.sessions) == 0 { + c.lastActivity = time.Now().UTC() + } +} + +func (c *sftpConnection) ActiveSessions() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.sessions) +} + +func (c *sftpConnection) GetLastActivity() time.Time { + c.mu.RLock() + defer c.mu.RUnlock() + + if len(c.sessions) > 0 { + return time.Now().UTC() + } + logger.Debug(c.logSender, "", "last activity %s", c.lastActivity) + return c.lastActivity +} + +type sftpConnectionsCache struct { + scheduler *cron.Cron + sync.Mutex + items map[string]*sftpConnection +} + +func newSFTPConnectionCache() *sftpConnectionsCache { + c := &sftpConnectionsCache{ + scheduler: 
cron.New(cron.WithLocation(time.UTC), cron.WithLogger(cron.DiscardLogger)), + items: make(map[string]*sftpConnection), + } + _, err := c.scheduler.AddFunc("@every 1m", c.Cleanup) + util.PanicOnError(err) + c.scheduler.Start() + return c +} + +func (c *sftpConnectionsCache) Get(config *SFTPFsConfig, sessionID string) (*sftpConnection, error) { + partition := 0 + key := config.getUniqueID(partition) + + c.Lock() + defer c.Unlock() + + for { + if val, ok := c.items[key]; ok { + activeSessions := val.ActiveSessions() + if activeSessions < maxSessionsPerConnection { + logger.Debug(logSenderSFTPCache, "", + "reusing connection for session ID %q, key %s, active sessions %d, active connections: %d", + sessionID, key, activeSessions+1, len(c.items)) + val.AddSession(sessionID) + return val, nil + } + partition++ + key = config.getUniqueID(partition) + logger.Debug(logSenderSFTPCache, "", + "connection full, generated new key for partition: %d, active sessions: %d, key: %s", + partition, activeSessions, key) + } else { + conn := newSFTPConnection(config, sessionID) + signer, err := config.getKeySigner() + if err != nil { + return nil, fmt.Errorf("sftpfs: unable to parse the private key: %w", err) + } + conn.signer = signer + c.items[key] = conn + logger.Debug(logSenderSFTPCache, "", + "adding new connection for session ID %q, partition: %d, key: %s, active connections: %d", + sessionID, partition, key, len(c.items)) + return conn, nil + } + } +} + +func (c *sftpConnectionsCache) Cleanup() { + c.Lock() + + var connectionsToClose []*sftpConnection + + for k, conn := range c.items { + if val := conn.GetLastActivity(); val.Before(time.Now().Add(-30 * time.Second)) { + delete(c.items, k) + logger.Debug(logSenderSFTPCache, "", "removed connection with key %s, last activity %s, active connections: %d", + k, val, len(c.items)) + connectionsToClose = append(connectionsToClose, conn) + } + } + + c.Unlock() + + for _, conn := range connectionsToClose { + err := conn.Close() + 
logger.Debug(logSenderSFTPCache, "", "connection closed, err: %v", err) + } +} diff --git a/internal/vfs/statvfs_fallback.go b/internal/vfs/statvfs_fallback.go new file mode 100644 index 00000000..19d8e2c3 --- /dev/null +++ b/internal/vfs/statvfs_fallback.go @@ -0,0 +1,52 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build !darwin && !linux && !freebsd + +package vfs + +import ( + "github.com/pkg/sftp" + "github.com/shirou/gopsutil/v3/disk" +) + +const bsize = uint64(4096) + +func getStatFS(path string) (*sftp.StatVFS, error) { + usage, err := disk.Usage(path) + if err != nil { + return nil, err + } + // we assume block size = 4096 + blocks := usage.Total / bsize + bfree := usage.Free / bsize + files := usage.InodesTotal + ffree := usage.InodesFree + if files == 0 { + // these assumptions are wrong but still better than returning 0 + files = blocks / 4 + ffree = bfree / 4 + } + return &sftp.StatVFS{ + Bsize: bsize, + Frsize: bsize, + Blocks: blocks, + Bfree: bfree, + Bavail: bfree, + Files: files, + Ffree: ffree, + Favail: ffree, + Namemax: 255, + }, nil +} diff --git a/internal/vfs/statvfs_linux.go b/internal/vfs/statvfs_linux.go new file mode 100644 index 00000000..772265fc --- /dev/null +++ b/internal/vfs/statvfs_linux.go @@ -0,0 +1,42 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU 
Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +//go:build linux + +package vfs + +import ( + "github.com/pkg/sftp" + "golang.org/x/sys/unix" +) + +func getStatFS(path string) (*sftp.StatVFS, error) { + stat := unix.Statfs_t{} + err := unix.Statfs(path, &stat) + if err != nil { + return nil, err + } + return &sftp.StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Frsize), + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: stat.Bavail, + Files: stat.Files, + Ffree: stat.Ffree, + Favail: stat.Ffree, // not sure how to calculate Favail + Flag: uint64(stat.Flags), + Namemax: uint64(stat.Namelen), + }, nil +} diff --git a/internal/vfs/statvfs_unix.go b/internal/vfs/statvfs_unix.go new file mode 100644 index 00000000..53f43202 --- /dev/null +++ b/internal/vfs/statvfs_unix.go @@ -0,0 +1,42 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +//go:build freebsd || darwin + +package vfs + +import ( + "github.com/pkg/sftp" + "golang.org/x/sys/unix" +) + +func getStatFS(path string) (*sftp.StatVFS, error) { + stat := unix.Statfs_t{} + err := unix.Statfs(path, &stat) + if err != nil { + return nil, err + } + return &sftp.StatVFS{ + Bsize: uint64(stat.Bsize), + Frsize: uint64(stat.Bsize), + Blocks: stat.Blocks, + Bfree: stat.Bfree, + Bavail: uint64(stat.Bavail), + Files: stat.Files, + Ffree: uint64(stat.Ffree), + Favail: uint64(stat.Ffree), // not sure how to calculate Favail + Flag: uint64(stat.Flags), + Namemax: 255, // we use a conservative value here + }, nil +} diff --git a/internal/vfs/sys_unix.go b/internal/vfs/sys_unix.go new file mode 100644 index 00000000..427792f8 --- /dev/null +++ b/internal/vfs/sys_unix.go @@ -0,0 +1,31 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+ +//go:build !windows + +package vfs + +import ( + "errors" + + "golang.org/x/sys/unix" +) + +func isCrossDeviceError(err error) bool { + return errors.Is(err, unix.EXDEV) +} + +func isInvalidNameError(_ error) bool { + return false +} diff --git a/internal/vfs/sys_windows.go b/internal/vfs/sys_windows.go new file mode 100644 index 00000000..f39fd572 --- /dev/null +++ b/internal/vfs/sys_windows.go @@ -0,0 +1,32 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package vfs + +import ( + "errors" + + "golang.org/x/sys/windows" +) + +func isCrossDeviceError(err error) bool { + return errors.Is(err, windows.ERROR_NOT_SAME_DEVICE) +} + +func isInvalidNameError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, windows.ERROR_INVALID_NAME) +} diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go new file mode 100644 index 00000000..6f883e6a --- /dev/null +++ b/internal/vfs/vfs.go @@ -0,0 +1,1359 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
// Package vfs provides local and remote filesystems support
package vfs

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"runtime"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/eikenb/pipeat"
	"github.com/pkg/sftp"
	"github.com/sftpgo/sdk"

	"github.com/drakkan/sftpgo/v2/internal/kms"
	"github.com/drakkan/sftpgo/v2/internal/logger"
	"github.com/drakkan/sftpgo/v2/internal/util"
)

const (
	dirMimeType       = "inode/directory"
	s3fsName          = "S3Fs"
	gcsfsName         = "GCSFs"
	azBlobFsName      = "AzureBlobFs"
	lastModifiedField = "sftpgo_last_modified"
	preResumeTimeout  = 90 * time.Second
	// ListerBatchSize defines the default limit for DirLister implementations
	ListerBatchSize = 1000
)

// Additional checks for files, combinable as bit flags.
const (
	CheckParentDir     = 1
	CheckResume        = 2
	CheckUpdateModTime = 4
)

var (
	validAzAccessTier = []string{"", "Archive", "Hot", "Cool"}
	// ErrStorageSizeUnavailable is returned if the storage backend does not support getting the size
	ErrStorageSizeUnavailable = errors.New("unable to get available size for this storage backend")
	// ErrVfsUnsupported defines the error for an unsupported VFS operation
	ErrVfsUnsupported        = errors.New("not supported")
	errInvalidDirListerLimit = errors.New("dir lister: invalid limit, must be > 0")
	// package-level knobs set once at startup via the Set* functions below
	tempPath             string
	sftpFingerprints     []string
	allowSelfConnections int
	renameMode           int
	readMetadata         int
	resumeMaxSize        int64
	uploadMode           int
)

var (
	// createPipeFn creates the pipe used to stream uploads/downloads;
	// overridable (e.g. in tests). The int64 size hint is ignored here.
	createPipeFn = func(dirPath string, _ int64) (pipeReaderAt, pipeWriterAt, error) {
		return pipeat.PipeInDir(dirPath)
	}
)

// SetAllowSelfConnections sets the desired behaviour for self connections
func SetAllowSelfConnections(value int) {
	allowSelfConnections = value
}

// SetTempPath sets the path for temporary files
func SetTempPath(fsPath string) {
	tempPath = fsPath
}

// GetTempPath returns the path for temporary files
func GetTempPath() string {
	return tempPath
}

// SetSFTPFingerprints sets the SFTP host key fingerprints
func SetSFTPFingerprints(fp []string) {
	sftpFingerprints = fp
}

// SetRenameMode sets the rename mode
func SetRenameMode(val int) {
	renameMode = val
}

// SetReadMetadataMode sets the read metadata mode
func SetReadMetadataMode(val int) {
	readMetadata = val
}

// SetResumeMaxSize sets the max size allowed for resuming uploads for backends
// with immutable objects
func SetResumeMaxSize(val int64) {
	resumeMaxSize = val
}

// SetUploadMode sets the upload mode
func SetUploadMode(val int) {
	uploadMode = val
}
string) (int, int64, error) + GetAtomicUploadPath(name string) string + GetRelativePath(name string) string + Walk(root string, walkFn filepath.WalkFunc) error + Join(elem ...string) string + HasVirtualFolders() bool + GetMimeType(name string) (string, error) + GetAvailableDiskSize(dirName string) (*sftp.StatVFS, error) + Close() error +} + +// FsRealPather is a Fs that implements the RealPath method. +type FsRealPather interface { + Fs + RealPath(p string) (string, error) +} + +// FsFileCopier is a Fs that implements the CopyFile method. +type FsFileCopier interface { + Fs + CopyFile(source, target string, srcInfo os.FileInfo) (int, int64, error) +} + +// File defines an interface representing a SFTPGo file +type File interface { + io.Reader + io.Writer + io.Closer + io.ReaderAt + io.WriterAt + io.Seeker + Stat() (os.FileInfo, error) + Name() string + Truncate(size int64) error +} + +// PipeWriter defines an interface representing a SFTPGo pipe writer +type PipeWriter interface { + io.Writer + io.WriterAt + io.Closer + Done(err error) + GetWrittenBytes() int64 +} + +// PipeReader defines an interface representing a SFTPGo pipe reader +type PipeReader interface { + io.Reader + io.ReaderAt + io.Closer + setMetadata(value map[string]string) + setMetadataFromPointerVal(value map[string]*string) + Metadata() map[string]string +} + +type pipeReaderAt interface { + Read(p []byte) (int, error) + ReadAt(p []byte, offset int64) (int, error) + GetReadedBytes() int64 + Close() error + CloseWithError(err error) error +} + +type pipeWriterAt interface { + Write(p []byte) (int, error) + WriteAt(p []byte, offset int64) (int, error) + GetWrittenBytes() int64 + Close() error + CloseWithError(err error) error +} + +// DirLister defines an interface for a directory lister +type DirLister interface { + Next(limit int) ([]os.FileInfo, error) + Close() error +} + +// Metadater defines an interface to implement to return metadata for a file +type Metadater interface { + Metadata() 
map[string]string +} + +type baseDirLister struct { + cache []os.FileInfo +} + +func (l *baseDirLister) Next(limit int) ([]os.FileInfo, error) { + if limit <= 0 { + return nil, errInvalidDirListerLimit + } + if len(l.cache) >= limit { + return l.returnFromCache(limit), nil + } + return l.returnFromCache(limit), io.EOF +} + +func (l *baseDirLister) returnFromCache(limit int) []os.FileInfo { + if len(l.cache) >= limit { + result := l.cache[:limit] + l.cache = l.cache[limit:] + return result + } + result := l.cache + l.cache = nil + return result +} + +func (l *baseDirLister) Close() error { + l.cache = nil + return nil +} + +// QuotaCheckResult defines the result for a quota check +type QuotaCheckResult struct { + HasSpace bool + AllowedSize int64 + AllowedFiles int + UsedSize int64 + UsedFiles int + QuotaSize int64 + QuotaFiles int +} + +// GetRemainingSize returns the remaining allowed size +func (q *QuotaCheckResult) GetRemainingSize() int64 { + if q.QuotaSize > 0 { + return q.QuotaSize - q.UsedSize + } + return 0 +} + +// GetRemainingFiles returns the remaining allowed files +func (q *QuotaCheckResult) GetRemainingFiles() int { + if q.QuotaFiles > 0 { + return q.QuotaFiles - q.UsedFiles + } + return 0 +} + +// S3FsConfig defines the configuration for S3 based filesystem +type S3FsConfig struct { + sdk.BaseS3FsConfig + AccessSecret *kms.Secret `json:"access_secret,omitempty"` + SSECustomerKey *kms.Secret `json:"sse_customer_key,omitempty"` +} + +// HideConfidentialData hides confidential data +func (c *S3FsConfig) HideConfidentialData() { + if c.AccessSecret != nil { + c.AccessSecret.Hide() + } + if c.SSECustomerKey != nil { + c.SSECustomerKey.Hide() + } +} + +func (c *S3FsConfig) isEqual(other S3FsConfig) bool { + if c.Bucket != other.Bucket { + return false + } + if c.KeyPrefix != other.KeyPrefix { + return false + } + if c.Region != other.Region { + return false + } + if c.AccessKey != other.AccessKey { + return false + } + if c.RoleARN != other.RoleARN { + 
// isEqual reports whether the two configurations are identical, including
// their secrets.
func (c *S3FsConfig) isEqual(other S3FsConfig) bool {
	if c.Bucket != other.Bucket {
		return false
	}
	if c.KeyPrefix != other.KeyPrefix {
		return false
	}
	if c.Region != other.Region {
		return false
	}
	if c.AccessKey != other.AccessKey {
		return false
	}
	if c.RoleARN != other.RoleARN {
		return false
	}
	if c.Endpoint != other.Endpoint {
		return false
	}
	if c.StorageClass != other.StorageClass {
		return false
	}
	if c.ACL != other.ACL {
		return false
	}
	if !c.areMultipartFieldsEqual(other) {
		return false
	}
	if c.ForcePathStyle != other.ForcePathStyle {
		return false
	}
	if c.SkipTLSVerify != other.SkipTLSVerify {
		return false
	}
	return c.isSecretEqual(other)
}

// areMultipartFieldsEqual compares the multipart upload/download tuning
// fields.
func (c *S3FsConfig) areMultipartFieldsEqual(other S3FsConfig) bool {
	if c.UploadPartSize != other.UploadPartSize {
		return false
	}
	if c.UploadConcurrency != other.UploadConcurrency {
		return false
	}
	if c.DownloadConcurrency != other.DownloadConcurrency {
		return false
	}
	if c.DownloadPartSize != other.DownloadPartSize {
		return false
	}
	if c.DownloadPartMaxTime != other.DownloadPartMaxTime {
		return false
	}
	if c.UploadPartMaxTime != other.UploadPartMaxTime {
		return false
	}
	return true
}

// isSecretEqual compares the secret fields. Nil secrets are normalized to
// empty ones first so nil and empty compare as equal.
func (c *S3FsConfig) isSecretEqual(other S3FsConfig) bool {
	if c.SSECustomerKey == nil {
		c.SSECustomerKey = kms.NewEmptySecret()
	}
	if other.SSECustomerKey == nil {
		other.SSECustomerKey = kms.NewEmptySecret()
	}
	if !c.SSECustomerKey.IsEqual(other.SSECustomerKey) {
		return false
	}
	if c.AccessSecret == nil {
		c.AccessSecret = kms.NewEmptySecret()
	}
	if other.AccessSecret == nil {
		other.AccessSecret = kms.NewEmptySecret()
	}
	return c.AccessSecret.IsEqual(other.AccessSecret)
}

// checkCredentials validates the access key/secret pair and the optional
// SSE customer key. Key and secret must be both set or both empty.
func (c *S3FsConfig) checkCredentials() error {
	if c.AccessKey == "" && !c.AccessSecret.IsEmpty() {
		return util.NewI18nError(
			errors.New("access_key cannot be empty with access_secret not empty"),
			util.I18nErrorAccessKeyRequired,
		)
	}
	if c.AccessSecret.IsEmpty() && c.AccessKey != "" {
		return util.NewI18nError(
			errors.New("access_secret cannot be empty with access_key not empty"),
			util.I18nErrorAccessSecretRequired,
		)
	}
	if c.AccessSecret.IsEncrypted() && !c.AccessSecret.IsValid() {
		return errors.New("invalid encrypted access_secret")
	}
	if !c.AccessSecret.IsEmpty() && !c.AccessSecret.IsValidInput() {
		return errors.New("invalid access_secret")
	}
	if c.SSECustomerKey.IsEncrypted() && !c.SSECustomerKey.IsValid() {
		return errors.New("invalid encrypted sse_customer_key")
	}
	if !c.SSECustomerKey.IsEmpty() && !c.SSECustomerKey.IsValidInput() {
		return errors.New("invalid sse_customer_key")
	}
	return nil
}

// ValidateAndEncryptCredentials validates the configuration and encrypts access secret if it is in plain text
func (c *S3FsConfig) ValidateAndEncryptCredentials(additionalData string) error {
	if err := c.validate(); err != nil {
		var errI18n *util.I18nError
		errValidation := util.NewValidationError(fmt.Sprintf("could not validate s3config: %v", err))
		if errors.As(err, &errI18n) {
			// preserve the specific i18n message from the validation error
			return util.NewI18nError(errValidation, errI18n.Message)
		}
		return util.NewI18nError(errValidation, util.I18nErrorFsValidation)
	}
	if c.AccessSecret.IsPlain() {
		c.AccessSecret.SetAdditionalData(additionalData)
		err := c.AccessSecret.Encrypt()
		if err != nil {
			return util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("could not encrypt s3 access secret: %v", err)),
				util.I18nErrorFsValidation,
			)
		}
	}
	if c.SSECustomerKey.IsPlain() {
		c.SSECustomerKey.SetAdditionalData(additionalData)
		err := c.SSECustomerKey.Encrypt()
		if err != nil {
			return util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("could not encrypt s3 SSE customer key: %v", err)),
				util.I18nErrorFsValidation,
			)
		}
	}
	return nil
}

// checkPartSizeAndConcurrency validates the multipart tuning values.
// Part sizes are expressed in MB; 0 means "use the SDK default".
func (c *S3FsConfig) checkPartSizeAndConcurrency() error {
	if c.UploadPartSize != 0 && (c.UploadPartSize < 5 || c.UploadPartSize > 2000) {
		return util.NewI18nError(
			errors.New("upload_part_size cannot be != 0, lower than 5 (MB) or greater than 2000 (MB)"),
			util.I18nErrorULPartSizeInvalid,
		)
	}
	if c.UploadConcurrency < 0 || c.UploadConcurrency > 64 {
		return util.NewI18nError(
			fmt.Errorf("invalid upload concurrency: %v", c.UploadConcurrency),
			util.I18nErrorULConcurrencyInvalid,
		)
	}
	if c.DownloadPartSize != 0 && (c.DownloadPartSize < 5 || c.DownloadPartSize > 2000) {
		return util.NewI18nError(
			errors.New("download_part_size cannot be != 0, lower than 5 (MB) or greater than 2000 (MB)"),
			util.I18nErrorDLPartSizeInvalid,
		)
	}
	if c.DownloadConcurrency < 0 || c.DownloadConcurrency > 64 {
		return util.NewI18nError(
			fmt.Errorf("invalid download concurrency: %v", c.DownloadConcurrency),
			util.I18nErrorDLConcurrencyInvalid,
		)
	}
	return nil
}

// isSameResource reports whether both configs point to the same bucket on
// the same endpoint/region (secrets are intentionally ignored).
func (c *S3FsConfig) isSameResource(other S3FsConfig) bool {
	if c.Bucket != other.Bucket {
		return false
	}
	if c.Endpoint != other.Endpoint {
		return false
	}
	return c.Region == other.Region
}

// validate returns an error if the configuration is not valid
func (c *S3FsConfig) validate() error {
	if c.AccessSecret == nil {
		c.AccessSecret = kms.NewEmptySecret()
	}
	if c.SSECustomerKey == nil {
		c.SSECustomerKey = kms.NewEmptySecret()
	}
	if c.Bucket == "" {
		return util.NewI18nError(errors.New("bucket cannot be empty"), util.I18nErrorBucketRequired)
	}
	// the region may be embedded within the endpoint for some S3 compatible
	// object storage, for example B2
	if c.Endpoint == "" && c.Region == "" {
		return util.NewI18nError(errors.New("region cannot be empty"), util.I18nErrorRegionRequired)
	}
	if err := c.checkCredentials(); err != nil {
		return err
	}
	if c.KeyPrefix != "" {
		if strings.HasPrefix(c.KeyPrefix, "/") {
			return util.NewI18nError(errors.New("key_prefix cannot start with /"), util.I18nErrorKeyPrefixInvalid)
		}
		// normalize the prefix so it always ends with a single "/"
		c.KeyPrefix = path.Clean(c.KeyPrefix)
		if !strings.HasSuffix(c.KeyPrefix, "/") {
			c.KeyPrefix += "/"
		}
	}
	c.StorageClass = strings.TrimSpace(c.StorageClass)
	c.ACL = strings.TrimSpace(c.ACL)
	return c.checkPartSizeAndConcurrency()
}
// GCSFsConfig defines the configuration for Google Cloud Storage based filesystem
type GCSFsConfig struct {
	sdk.BaseGCSFsConfig
	Credentials *kms.Secret `json:"credentials,omitempty"`
}

// HideConfidentialData hides confidential data
func (c *GCSFsConfig) HideConfidentialData() {
	if c.Credentials != nil {
		c.Credentials.Hide()
	}
}

// ValidateAndEncryptCredentials validates the configuration and encrypts credentials if they are in plain text
func (c *GCSFsConfig) ValidateAndEncryptCredentials(additionalData string) error {
	if err := c.validate(); err != nil {
		var errI18n *util.I18nError
		errValidation := util.NewValidationError(fmt.Sprintf("could not validate GCS config: %v", err))
		if errors.As(err, &errI18n) {
			// preserve the specific i18n message from the validation error
			return util.NewI18nError(errValidation, errI18n.Message)
		}
		return util.NewI18nError(errValidation, util.I18nErrorFsValidation)
	}
	if c.Credentials.IsPlain() {
		c.Credentials.SetAdditionalData(additionalData)
		err := c.Credentials.Encrypt()
		if err != nil {
			return util.NewI18nError(
				util.NewValidationError(fmt.Sprintf("could not encrypt GCS credentials: %v", err)),
				util.I18nErrorFsValidation,
			)
		}
	}
	return nil
}

// isEqual reports whether the two configurations are identical, including
// their credentials. Nil credentials are normalized to empty secrets.
func (c *GCSFsConfig) isEqual(other GCSFsConfig) bool {
	if c.Bucket != other.Bucket {
		return false
	}
	if c.KeyPrefix != other.KeyPrefix {
		return false
	}
	if c.AutomaticCredentials != other.AutomaticCredentials {
		return false
	}
	if c.StorageClass != other.StorageClass {
		return false
	}
	if c.ACL != other.ACL {
		return false
	}
	if c.UploadPartSize != other.UploadPartSize {
		return false
	}
	if c.UploadPartMaxTime != other.UploadPartMaxTime {
		return false
	}
	if c.Credentials == nil {
		c.Credentials = kms.NewEmptySecret()
	}
	if other.Credentials == nil {
		other.Credentials = kms.NewEmptySecret()
	}
	return c.Credentials.IsEqual(other.Credentials)
}

// isSameResource reports whether both configs point to the same bucket
// (credentials are intentionally ignored).
func (c *GCSFsConfig) isSameResource(other GCSFsConfig) bool {
	return c.Bucket == other.Bucket
}

// validate returns an error if the configuration is not valid
//
//nolint:gocyclo
func (c *GCSFsConfig) validate() error {
	// with automatic credentials any explicit credentials are discarded
	if c.Credentials == nil || c.AutomaticCredentials == 1 {
		c.Credentials = kms.NewEmptySecret()
	}
	if c.Bucket == "" {
		return util.NewI18nError(errors.New("bucket cannot be empty"), util.I18nErrorBucketRequired)
	}
	if c.KeyPrefix != "" {
		if strings.HasPrefix(c.KeyPrefix, "/") {
			return util.NewI18nError(errors.New("key_prefix cannot start with /"), util.I18nErrorKeyPrefixInvalid)
		}
		// normalize the prefix so it always ends with a single "/"
		c.KeyPrefix = path.Clean(c.KeyPrefix)
		if !strings.HasSuffix(c.KeyPrefix, "/") {
			c.KeyPrefix += "/"
		}
	}
	if c.Credentials.IsEncrypted() && !c.Credentials.IsValid() {
		return errors.New("invalid encrypted credentials")
	}
	if c.AutomaticCredentials == 0 && !c.Credentials.IsValidInput() {
		return util.NewI18nError(errors.New("invalid credentials"), util.I18nErrorFsCredentialsRequired)
	}
	c.StorageClass = strings.TrimSpace(c.StorageClass)
	c.ACL = strings.TrimSpace(c.ACL)
	// silently reset out-of-range tuning values to the SDK defaults
	if c.UploadPartSize < 0 || c.UploadPartSize > 2000 {
		c.UploadPartSize = 0
	}
	if c.UploadPartMaxTime < 0 {
		c.UploadPartMaxTime = 0
	}
	return nil
}
+ // The access key is stored encrypted based on the kms configuration + AccountKey *kms.Secret `json:"account_key,omitempty"` + // Shared access signature URL, leave blank if using account/key + SASURL *kms.Secret `json:"sas_url,omitempty"` +} + +// HideConfidentialData hides confidential data +func (c *AzBlobFsConfig) HideConfidentialData() { + if c.AccountKey != nil { + c.AccountKey.Hide() + } + if c.SASURL != nil { + c.SASURL.Hide() + } +} + +func (c *AzBlobFsConfig) isEqual(other AzBlobFsConfig) bool { + if c.Container != other.Container { + return false + } + if c.AccountName != other.AccountName { + return false + } + if c.Endpoint != other.Endpoint { + return false + } + if c.SASURL.IsEmpty() { + c.SASURL = kms.NewEmptySecret() + } + if other.SASURL.IsEmpty() { + other.SASURL = kms.NewEmptySecret() + } + if !c.SASURL.IsEqual(other.SASURL) { + return false + } + if c.KeyPrefix != other.KeyPrefix { + return false + } + if c.UploadPartSize != other.UploadPartSize { + return false + } + if c.UploadConcurrency != other.UploadConcurrency { + return false + } + if c.DownloadPartSize != other.DownloadPartSize { + return false + } + if c.DownloadConcurrency != other.DownloadConcurrency { + return false + } + if c.UseEmulator != other.UseEmulator { + return false + } + if c.AccessTier != other.AccessTier { + return false + } + return c.isSecretEqual(other) +} + +func (c *AzBlobFsConfig) isSecretEqual(other AzBlobFsConfig) bool { + if c.AccountKey == nil { + c.AccountKey = kms.NewEmptySecret() + } + if other.AccountKey == nil { + other.AccountKey = kms.NewEmptySecret() + } + return c.AccountKey.IsEqual(other.AccountKey) +} + +// ValidateAndEncryptCredentials validates the configuration and encrypts access secret if it is in plain text +func (c *AzBlobFsConfig) ValidateAndEncryptCredentials(additionalData string) error { + if err := c.validate(); err != nil { + var errI18n *util.I18nError + errValidation := util.NewValidationError(fmt.Sprintf("could not validate Azure 
Blob config: %v", err)) + if errors.As(err, &errI18n) { + return util.NewI18nError(errValidation, errI18n.Message) + } + return util.NewI18nError(errValidation, util.I18nErrorFsValidation) + } + if c.AccountKey.IsPlain() { + c.AccountKey.SetAdditionalData(additionalData) + if err := c.AccountKey.Encrypt(); err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("could not encrypt Azure blob account key: %v", err)), + util.I18nErrorFsValidation, + ) + } + } + if c.SASURL.IsPlain() { + c.SASURL.SetAdditionalData(additionalData) + if err := c.SASURL.Encrypt(); err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("could not encrypt Azure blob SAS URL: %v", err)), + util.I18nErrorFsValidation, + ) + } + } + return nil +} + +func (c *AzBlobFsConfig) checkCredentials() error { + if c.SASURL.IsPlain() { + _, err := url.Parse(c.SASURL.GetPayload()) + if err != nil { + return util.NewI18nError(err, util.I18nErrorSASURLInvalid) + } + return nil + } + if c.SASURL.IsEncrypted() && !c.SASURL.IsValid() { + return errors.New("invalid encrypted sas_url") + } + if !c.SASURL.IsEmpty() { + return nil + } + if c.AccountName == "" { + return util.NewI18nError(errors.New("account name is required"), util.I18nErrorAccountNameRequired) + } + if c.AccountKey.IsEncrypted() && !c.AccountKey.IsValid() { + return errors.New("invalid encrypted account_key") + } + return nil +} + +func (c *AzBlobFsConfig) checkPartSizeAndConcurrency() error { + if c.UploadPartSize < 0 || c.UploadPartSize > 2000 { + return util.NewI18nError( + fmt.Errorf("invalid upload part size: %v", c.UploadPartSize), + util.I18nErrorULPartSizeInvalid, + ) + } + if c.UploadConcurrency < 0 || c.UploadConcurrency > 64 { + return util.NewI18nError( + fmt.Errorf("invalid upload concurrency: %v", c.UploadConcurrency), + util.I18nErrorULConcurrencyInvalid, + ) + } + if c.DownloadPartSize < 0 || c.DownloadPartSize > 2000 { + return util.NewI18nError( + fmt.Errorf("invalid download part 
size: %v", c.DownloadPartSize), + util.I18nErrorDLPartSizeInvalid, + ) + } + if c.DownloadConcurrency < 0 || c.DownloadConcurrency > 64 { + return util.NewI18nError( + fmt.Errorf("invalid upload concurrency: %v", c.DownloadConcurrency), + util.I18nErrorDLConcurrencyInvalid, + ) + } + return nil +} + +func (c *AzBlobFsConfig) tryDecrypt() error { + if err := c.AccountKey.TryDecrypt(); err != nil { + return fmt.Errorf("unable to decrypt account key: %w", err) + } + if err := c.SASURL.TryDecrypt(); err != nil { + return fmt.Errorf("unable to decrypt SAS URL: %w", err) + } + return nil +} + +func (c *AzBlobFsConfig) isSameResource(other AzBlobFsConfig) bool { + if c.AccountName != other.AccountName { + return false + } + if c.Endpoint != other.Endpoint { + return false + } + if c.SASURL == nil { + c.SASURL = kms.NewEmptySecret() + } + if other.SASURL == nil { + other.SASURL = kms.NewEmptySecret() + } + return c.SASURL.GetPayload() == other.SASURL.GetPayload() +} + +// validate returns an error if the configuration is not valid +func (c *AzBlobFsConfig) validate() error { + if c.AccountKey == nil { + c.AccountKey = kms.NewEmptySecret() + } + if c.SASURL == nil { + c.SASURL = kms.NewEmptySecret() + } + // container could be embedded within SAS URL we check this at runtime + if c.SASURL.IsEmpty() && c.Container == "" { + return util.NewI18nError(errors.New("container cannot be empty"), util.I18nErrorContainerRequired) + } + if err := c.checkCredentials(); err != nil { + return err + } + if c.KeyPrefix != "" { + if strings.HasPrefix(c.KeyPrefix, "/") { + return util.NewI18nError(errors.New("key_prefix cannot start with /"), util.I18nErrorKeyPrefixInvalid) + } + c.KeyPrefix = path.Clean(c.KeyPrefix) + if !strings.HasSuffix(c.KeyPrefix, "/") { + c.KeyPrefix += "/" + } + } + if err := c.checkPartSizeAndConcurrency(); err != nil { + return err + } + if !slices.Contains(validAzAccessTier, c.AccessTier) { + return fmt.Errorf("invalid access tier %q, valid values: \"''%v\"", 
c.AccessTier, strings.Join(validAzAccessTier, ", ")) + } + return nil +} + +// CryptFsConfig defines the configuration to store local files as encrypted +type CryptFsConfig struct { + sdk.OSFsConfig + Passphrase *kms.Secret `json:"passphrase,omitempty"` +} + +// HideConfidentialData hides confidential data +func (c *CryptFsConfig) HideConfidentialData() { + if c.Passphrase != nil { + c.Passphrase.Hide() + } +} + +func (c *CryptFsConfig) isEqual(other CryptFsConfig) bool { + if c.Passphrase == nil { + c.Passphrase = kms.NewEmptySecret() + } + if other.Passphrase == nil { + other.Passphrase = kms.NewEmptySecret() + } + return c.Passphrase.IsEqual(other.Passphrase) +} + +// ValidateAndEncryptCredentials validates the configuration and encrypts the passphrase if it is in plain text +func (c *CryptFsConfig) ValidateAndEncryptCredentials(additionalData string) error { + if err := c.validate(); err != nil { + var errI18n *util.I18nError + errValidation := util.NewValidationError(fmt.Sprintf("could not validate crypt fs config: %v", err)) + if errors.As(err, &errI18n) { + return util.NewI18nError(errValidation, errI18n.Message) + } + return util.NewI18nError(errValidation, util.I18nErrorFsValidation) + } + if c.Passphrase.IsPlain() { + c.Passphrase.SetAdditionalData(additionalData) + if err := c.Passphrase.Encrypt(); err != nil { + return util.NewI18nError( + util.NewValidationError(fmt.Sprintf("could not encrypt Crypt fs passphrase: %v", err)), + util.I18nErrorFsValidation, + ) + } + } + return nil +} + +func (c *CryptFsConfig) isSameResource(other CryptFsConfig) bool { + return c.Passphrase.GetPayload() == other.Passphrase.GetPayload() +} + +// validate returns an error if the configuration is not valid +func (c *CryptFsConfig) validate() error { + if c.Passphrase == nil || c.Passphrase.IsEmpty() { + return util.NewI18nError(errors.New("invalid passphrase"), util.I18nErrorPassphraseRequired) + } + if !c.Passphrase.IsValidInput() { + return 
util.NewI18nError(errors.New("passphrase cannot be empty or invalid"), util.I18nErrorPassphraseRequired) + } + if c.Passphrase.IsEncrypted() && !c.Passphrase.IsValid() { + return errors.New("invalid encrypted passphrase") + } + return nil +} + +// pipeWriter defines a wrapper for a pipeWriterAt. +type pipeWriter struct { + pipeWriterAt + err error + done chan bool +} + +// NewPipeWriter initializes a new PipeWriter +func NewPipeWriter(w pipeWriterAt) PipeWriter { + return &pipeWriter{ + pipeWriterAt: w, + err: nil, + done: make(chan bool), + } +} + +// Close waits for the upload to end, closes the pipeWriterAt and returns an error if any. +func (p *pipeWriter) Close() error { + p.pipeWriterAt.Close() //nolint:errcheck // the returned error is always null + <-p.done + return p.err +} + +// Done unlocks other goroutines waiting on Close(). +// It must be called when the upload ends +func (p *pipeWriter) Done(err error) { + p.err = err + p.done <- true +} + +func newPipeWriterAtOffset(w pipeWriterAt, offset int64) PipeWriter { + return &pipeWriterAtOffset{ + pipeWriter: &pipeWriter{ + pipeWriterAt: w, + err: nil, + done: make(chan bool), + }, + offset: offset, + writeOffset: offset, + } +} + +type pipeWriterAtOffset struct { + *pipeWriter + offset int64 + writeOffset int64 +} + +func (p *pipeWriterAtOffset) WriteAt(buf []byte, off int64) (int, error) { + if off < p.offset { + return 0, fmt.Errorf("invalid offset %d, minimum accepted %d", off, p.offset) + } + return p.pipeWriter.WriteAt(buf, off-p.offset) +} + +func (p *pipeWriterAtOffset) Write(buf []byte) (int, error) { + n, err := p.WriteAt(buf, p.writeOffset) + p.writeOffset += int64(n) + return n, err +} + +// NewPipeReader initializes a new PipeReader +func NewPipeReader(r pipeReaderAt) PipeReader { + return &pipeReader{ + pipeReaderAt: r, + } +} + +// pipeReader defines a wrapper for pipeat.PipeReaderAt. 
// isEqualityCheckModeValid returns true if mode is a supported equality
// check mode, i.e. 0 or 1.
// The previous expression used "||", which is a tautology for integers
// (every int is >= 0 or <= 1) and therefore accepted any value; it is
// fixed to "&&" so only 0 and 1 are considered valid.
func isEqualityCheckModeValid(mode int) bool {
	return mode >= 0 && mode <= 1
}
+} + +// FsOpenReturnsFile returns true if fs.Open returns a *os.File handle +func FsOpenReturnsFile(fs Fs) bool { + if osFs, ok := fs.(*OsFs); ok { + return osFs.readBufferSize == 0 + } + if sftpFs, ok := fs.(*SFTPFs); ok { + return sftpFs.config.BufferSize == 0 + } + return false +} + +// IsLocalOrSFTPFs returns true if fs is local or SFTP +func IsLocalOrSFTPFs(fs Fs) bool { + return IsLocalOsFs(fs) || IsSFTPFs(fs) +} + +// HasTruncateSupport returns true if the fs supports truncate files +func HasTruncateSupport(fs Fs) bool { + return IsLocalOsFs(fs) || IsSFTPFs(fs) || IsHTTPFs(fs) +} + +// IsRenameAtomic returns true if renaming a directory is supposed to be atomic +func IsRenameAtomic(fs Fs) bool { + if strings.HasPrefix(fs.Name(), s3fsName) { + return false + } + if strings.HasPrefix(fs.Name(), gcsfsName) { + return false + } + if strings.HasPrefix(fs.Name(), azBlobFsName) { + return false + } + return true +} + +// HasImplicitAtomicUploads returns true if the fs don't persists partial files on error +func HasImplicitAtomicUploads(fs Fs) bool { + if strings.HasPrefix(fs.Name(), s3fsName) { + return uploadMode&4 == 0 + } + if strings.HasPrefix(fs.Name(), gcsfsName) { + return uploadMode&8 == 0 + } + if strings.HasPrefix(fs.Name(), azBlobFsName) { + return uploadMode&16 == 0 + } + return false +} + +// HasOpenRWSupport returns true if the fs can open a file +// for reading and writing at the same time +func HasOpenRWSupport(fs Fs) bool { + if IsLocalOsFs(fs) { + return true + } + if IsSFTPFs(fs) && fs.IsUploadResumeSupported() { + return true + } + return false +} + +// IsLocalOrCryptoFs returns true if fs is local or local encrypted +func IsLocalOrCryptoFs(fs Fs) bool { + return IsLocalOsFs(fs) || IsCryptOsFs(fs) +} + +// SetPathPermissions calls fs.Chown. 
+// It does nothing for local filesystem on windows +func SetPathPermissions(fs Fs, path string, uid int, gid int) { + if uid == -1 && gid == -1 { + return + } + if IsLocalOsFs(fs) { + if runtime.GOOS == "windows" { + return + } + } + if err := fs.Chown(path, uid, gid); err != nil { + fsLog(fs, logger.LevelWarn, "error chowning path %v: %v", path, err) + } +} + +// IsUploadResumeSupported returns true if resuming uploads is supported +func IsUploadResumeSupported(fs Fs, size int64) bool { + if fs.IsUploadResumeSupported() { + return true + } + return fs.IsConditionalUploadResumeSupported(size) +} + +func getLastModified(metadata map[string]string) int64 { + if val, ok := metadata[lastModifiedField]; ok && val != "" { + lastModified, err := strconv.ParseInt(val, 10, 64) + if err == nil { + return lastModified + } + } + return 0 +} + +func getAzureLastModified(metadata map[string]*string) int64 { + for k, v := range metadata { + if strings.EqualFold(k, lastModifiedField) { + if val := util.GetStringFromPointer(v); val != "" { + lastModified, err := strconv.ParseInt(val, 10, 64) + if err == nil { + return lastModified + } + } + return 0 + } + } + return 0 +} + +func validateOSFsConfig(config *sdk.OSFsConfig) error { + if config.ReadBufferSize < 0 || config.ReadBufferSize > 10 { + return fmt.Errorf("invalid read buffer size must be between 0 and 10 MB") + } + if config.WriteBufferSize < 0 || config.WriteBufferSize > 10 { + return fmt.Errorf("invalid write buffer size must be between 0 and 10 MB") + } + return nil +} + +func doCopy(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { + if buf == nil { + buf = make([]byte, 32768) + } + for { + nr, er := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errors.New("invalid write") + } + } + written += int64(nw) + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != 
io.EOF { + err = er + } + break + } + } + return written, err +} + +func getMountPath(mountPath string) string { + if mountPath == "/" { + return "" + } + return mountPath +} + +func getLocalTempDir() string { + if tempPath != "" { + return tempPath + } + return filepath.Clean(os.TempDir()) +} + +func doRecursiveRename(fs Fs, source, target string, + renameFn func(string, string, os.FileInfo, int, bool) (int, int64, error), + recursion int, updateModTime bool, +) (int, int64, error) { + var numFiles int + var filesSize int64 + + if recursion > util.MaxRecursion { + return numFiles, filesSize, util.ErrRecursionTooDeep + } + recursion++ + + lister, err := fs.ReadDir(source) + if err != nil { + return numFiles, filesSize, err + } + defer lister.Close() + + for { + entries, err := lister.Next(ListerBatchSize) + finished := errors.Is(err, io.EOF) + if err != nil && !finished { + return numFiles, filesSize, err + } + for _, info := range entries { + sourceEntry := fs.Join(source, info.Name()) + targetEntry := fs.Join(target, info.Name()) + files, size, err := renameFn(sourceEntry, targetEntry, info, recursion, updateModTime) + if err != nil { + if fs.IsNotExist(err) { + fsLog(fs, logger.LevelInfo, "skipping rename for %q: %v", sourceEntry, err) + continue + } + return numFiles, filesSize, err + } + numFiles += files + filesSize += size + } + if finished { + return numFiles, filesSize, nil + } + } +} + +// copied from rclone +func readFill(r io.Reader, buf []byte) (n int, err error) { + var nn int + for n < len(buf) && err == nil { + nn, err = r.Read(buf[n:]) + n += nn + } + return n, err +} + +func writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) error { + written := 0 + for written < count { + n, err := w.WriteAt(buf[written:count], offset+int64(written)) + written += n + if err != nil { + return err + } + } + return nil +} + +type bytesReaderWrapper struct { + *bytes.Reader +} + +func (b *bytesReaderWrapper) Close() error { + return nil +} + +type 
// bufferAllocator is a mutex-protected free list of equally sized byte
// buffers. Released buffers are kept for reuse until free() is called,
// after which the allocator discards everything handed back to it.
type bufferAllocator struct {
	sync.Mutex
	available  [][]byte
	bufferSize int
	finalized  bool
}

// newBufferAllocator returns an allocator handing out buffers of the given size.
func newBufferAllocator(size int) *bufferAllocator {
	return &bufferAllocator{
		bufferSize: size,
		finalized:  false,
	}
}

// getBuffer returns a pooled buffer if one is available, otherwise it
// allocates a fresh one of the configured size.
func (b *bufferAllocator) getBuffer() []byte {
	b.Lock()
	defer b.Unlock()

	n := len(b.available)
	if n == 0 {
		return make([]byte, b.bufferSize)
	}
	buf := b.available[n-1]
	// clear the slot so the backing array does not pin the buffer
	b.available[n-1] = nil
	b.available = b.available[:n-1]
	return buf
}

// releaseBuffer returns buf to the pool. Buffers with an unexpected size
// and buffers released after free() are silently discarded.
func (b *bufferAllocator) releaseBuffer(buf []byte) {
	b.Lock()
	defer b.Unlock()

	if b.finalized || len(buf) != b.bufferSize {
		return
	}
	b.available = append(b.available, buf)
}

// free drops all pooled buffers and marks the allocator as finalized so
// later releases are ignored.
func (b *bufferAllocator) free() {
	b.Lock()
	defer b.Unlock()

	b.available = nil
	b.finalized = true
}
+ +package webdavd + +import ( + "context" + "encoding/xml" + "errors" + "io" + "mime" + "net/http" + "os" + "path" + "slices" + "sync/atomic" + "time" + + "github.com/drakkan/webdav" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +var ( + errTransferAborted = errors.New("transfer aborted") + lastModifiedProps = []string{"Win32LastModifiedTime", "getlastmodified"} +) + +type webDavFile struct { + *common.BaseTransfer + writer io.WriteCloser + reader io.ReadCloser + info os.FileInfo + startOffset int64 + isFinished bool + readTried atomic.Bool +} + +func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter vfs.PipeWriter, pipeReader vfs.PipeReader) *webDavFile { + var writer io.WriteCloser + var reader io.ReadCloser + if baseTransfer.File != nil { + writer = baseTransfer.File + reader = baseTransfer.File + } else if pipeWriter != nil { + writer = pipeWriter + } else if pipeReader != nil { + reader = pipeReader + } + f := &webDavFile{ + BaseTransfer: baseTransfer, + writer: writer, + reader: reader, + isFinished: false, + startOffset: 0, + info: nil, + } + f.readTried.Store(false) + return f +} + +type webDavFileInfo struct { + os.FileInfo + Fs vfs.Fs + virtualPath string + fsPath string +} + +// ContentType implements webdav.ContentTyper interface +func (fi *webDavFileInfo) ContentType(_ context.Context) (string, error) { + extension := path.Ext(fi.virtualPath) + if ctype, ok := customMimeTypeMapping[extension]; ok { + return ctype, nil + } + if extension == "" || extension == ".dat" { + return "application/octet-stream", nil + } + contentType := mime.TypeByExtension(extension) + if contentType != "" { + return contentType, nil + } + contentType = mimeTypeCache.getMimeFromCache(extension) + if contentType != "" { + return contentType, nil + } + contentType, 
err := fi.Fs.GetMimeType(fi.fsPath) + if contentType != "" { + mimeTypeCache.addMimeToCache(extension, contentType) + return contentType, err + } + return "", webdav.ErrNotImplemented +} + +// Readdir reads directory entries from the handle +func (f *webDavFile) Readdir(_ int) ([]os.FileInfo, error) { + return nil, webdav.ErrNotImplemented +} + +// ReadDir implements the FileDirLister interface +func (f *webDavFile) ReadDir() (webdav.DirLister, error) { + if !f.Connection.User.HasPerm(dataprovider.PermListItems, f.GetVirtualPath()) { + return nil, f.Connection.GetPermissionDeniedError() + } + lister, err := f.Connection.ListDir(f.GetVirtualPath()) + if err != nil { + return nil, err + } + return &webDavDirLister{ + DirLister: lister, + fs: f.Fs, + virtualDirPath: f.GetVirtualPath(), + fsDirPath: f.GetFsPath(), + }, nil +} + +// Stat the handle +func (f *webDavFile) Stat() (os.FileInfo, error) { + if f.GetType() == common.TransferDownload && !f.Connection.User.HasPerm(dataprovider.PermListItems, path.Dir(f.GetVirtualPath())) { + return nil, f.Connection.GetPermissionDeniedError() + } + f.Lock() + errUpload := f.ErrTransfer + f.Unlock() + if f.GetType() == common.TransferUpload && errUpload == nil { + info := &webDavFileInfo{ + FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, f.BytesReceived.Load(), time.Now(), false), + Fs: f.Fs, + virtualPath: f.GetVirtualPath(), + fsPath: f.GetFsPath(), + } + return info, nil + } + info, err := f.Fs.Stat(f.GetFsPath()) + if err != nil { + return nil, f.Connection.GetFsError(f.Fs, err) + } + if vfs.IsCryptOsFs(f.Fs) { + info = f.Fs.(*vfs.CryptFs).ConvertFileInfo(info) + } + fi := &webDavFileInfo{ + FileInfo: info, + Fs: f.Fs, + virtualPath: f.GetVirtualPath(), + fsPath: f.GetFsPath(), + } + return fi, nil +} + +func (f *webDavFile) checkFirstRead() error { + if !f.Connection.User.HasPerm(dataprovider.PermDownload, path.Dir(f.GetVirtualPath())) { + return f.Connection.GetPermissionDeniedError() + } + transferQuota := 
f.GetTransferQuota() + if !transferQuota.HasDownloadSpace() { + f.Connection.Log(logger.LevelInfo, "denying file read due to quota limits") + return f.Connection.GetReadQuotaExceededError() + } + if ok, policy := f.Connection.User.IsFileAllowed(f.GetVirtualPath()); !ok { + f.Connection.Log(logger.LevelWarn, "reading file %q is not allowed", f.GetVirtualPath()) + return f.Connection.GetErrorForDeniedFile(policy) + } + _, err := common.ExecutePreAction(f.Connection, common.OperationPreDownload, f.GetFsPath(), f.GetVirtualPath(), 0, 0) + if err != nil { + f.Connection.Log(logger.LevelDebug, "download for file %q denied by pre action: %v", f.GetVirtualPath(), err) + return f.Connection.GetPermissionDeniedError() + } + f.readTried.Store(true) + return nil +} + +// Read reads the contents to downloads. +func (f *webDavFile) Read(p []byte) (n int, err error) { + if f.AbortTransfer.Load() { + return 0, errTransferAborted + } + if !f.readTried.Load() { + if err := f.checkFirstRead(); err != nil { + return 0, err + } + } + f.Connection.UpdateLastActivity() + + // the file is read sequentially we don't need to check for concurrent reads and so + // lock the transfer while opening the remote file + if f.reader == nil { + if f.GetType() != common.TransferDownload { + f.TransferError(common.ErrOpUnsupported) + return 0, common.ErrOpUnsupported + } + file, r, cancelFn, e := f.Fs.Open(f.GetFsPath(), 0) + f.Lock() + if e == nil { + if file != nil { + f.File = file + f.writer = f.File + f.reader = f.File + } else if r != nil { + f.reader = r + } + f.SetCancelFn(cancelFn) + } + f.ErrTransfer = e + f.startOffset = 0 + f.Unlock() + if e != nil { + return 0, f.Connection.GetFsError(f.Fs, e) + } + } + + n, err = f.reader.Read(p) + f.BytesSent.Add(int64(n)) + if err == nil { + err = f.CheckRead() + } + if err != nil && err != io.EOF { + f.TransferError(err) + err = f.ConvertError(err) + return + } + f.HandleThrottle() + return +} + +// Write writes the uploaded contents. 
+func (f *webDavFile) Write(p []byte) (n int, err error) { + if f.AbortTransfer.Load() { + return 0, errTransferAborted + } + + f.Connection.UpdateLastActivity() + + n, err = f.writer.Write(p) + f.BytesReceived.Add(int64(n)) + + if err == nil { + err = f.CheckWrite() + } + if err != nil { + f.TransferError(err) + err = f.ConvertError(err) + return + } + f.HandleThrottle() + return +} + +func (f *webDavFile) updateStatInfo() error { + if f.info != nil { + return nil + } + info, err := f.Fs.Stat(f.GetFsPath()) + if err != nil { + return err + } + if vfs.IsCryptOsFs(f.Fs) { + info = f.Fs.(*vfs.CryptFs).ConvertFileInfo(info) + } + f.info = info + return nil +} + +func (f *webDavFile) updateTransferQuotaOnSeek() { + transferQuota := f.GetTransferQuota() + if transferQuota.HasSizeLimits() { + go func(ulSize, dlSize int64, user dataprovider.User) { + dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck + }(f.BytesReceived.Load(), f.BytesSent.Load(), f.Connection.User) + } +} + +func (f *webDavFile) checkFile() error { + if f.File == nil && vfs.FsOpenReturnsFile(f.Fs) { + file, _, _, err := f.Fs.Open(f.GetFsPath(), 0) + if err != nil { + f.Connection.Log(logger.LevelWarn, "could not open file %q for seeking: %v", + f.GetFsPath(), err) + f.TransferError(err) + return err + } + f.File = file + f.reader = file + f.writer = file + } + return nil +} + +func (f *webDavFile) seekFile(offset int64, whence int) (int64, error) { + ret, err := f.File.Seek(offset, whence) + if err != nil { + f.TransferError(err) + } + return ret, err +} + +// Seek sets the offset for the next Read or Write on the writer to offset, +// interpreted according to whence: 0 means relative to the origin of the file, +// 1 means relative to the current offset, and 2 means relative to the end. +// It returns the new offset and an error, if any. 
+func (f *webDavFile) Seek(offset int64, whence int) (int64, error) { + f.Connection.UpdateLastActivity() + if err := f.checkFile(); err != nil { + return 0, err + } + if f.File != nil { + return f.seekFile(offset, whence) + } + if f.GetType() == common.TransferDownload { + readOffset := f.startOffset + f.BytesSent.Load() + if offset == 0 && readOffset == 0 { + switch whence { + case io.SeekStart: + return 0, nil + case io.SeekEnd: + if err := f.updateStatInfo(); err != nil { + return 0, err + } + return f.info.Size(), nil + } + } + + // close the reader and create a new one at startByte + if f.reader != nil { + f.reader.Close() //nolint:errcheck + f.reader = nil + } + startByte := int64(0) + f.BytesReceived.Store(0) + f.BytesSent.Store(0) + f.updateTransferQuotaOnSeek() + + switch whence { + case io.SeekStart: + startByte = offset + case io.SeekCurrent: + startByte = readOffset + offset + case io.SeekEnd: + if err := f.updateStatInfo(); err != nil { + f.TransferError(err) + return 0, err + } + startByte = f.info.Size() - offset + } + + _, r, cancelFn, err := f.Fs.Open(f.GetFsPath(), startByte) + + f.Lock() + if err == nil { + f.startOffset = startByte + f.reader = r + } + f.ErrTransfer = err + f.SetCancelFn(cancelFn) + f.Unlock() + + return startByte, err + } + return 0, common.ErrOpUnsupported +} + +// Close closes the open directory or the current transfer +func (f *webDavFile) Close() error { + if err := f.setFinished(); err != nil { + return err + } + err := f.closeIO() + if f.isTransfer() { + errBaseClose := f.BaseTransfer.Close() + if errBaseClose != nil { + err = errBaseClose + } + } else { + f.Connection.RemoveTransfer(f.BaseTransfer) + } + return f.Connection.GetFsError(f.Fs, err) +} + +func (f *webDavFile) closeIO() error { + var err error + if f.File != nil { + err = f.File.Close() + } else if f.writer != nil { + err = f.writer.Close() + f.Lock() + // we set ErrTransfer here so quota is not updated, in this case the uploads are atomic + if err != nil && 
f.ErrTransfer == nil { + f.ErrTransfer = err + } + f.Unlock() + } else if f.reader != nil { + err = f.reader.Close() + if metadater, ok := f.reader.(vfs.Metadater); ok { + f.SetMetadata(metadater.Metadata()) + } + } + return err +} + +func (f *webDavFile) setFinished() error { + f.Lock() + defer f.Unlock() + + if f.isFinished { + return common.ErrTransferClosed + } + f.isFinished = true + return nil +} + +func (f *webDavFile) isTransfer() bool { + if f.GetType() == common.TransferDownload { + return f.readTried.Load() + } + return true +} + +// DeadProps returns a copy of the dead properties held. +// We always return nil for now, we only support the last modification time +// and it is already included in "live" properties +func (f *webDavFile) DeadProps() (map[xml.Name]webdav.Property, error) { + return nil, nil +} + +// Patch patches the dead properties held. +// In our minimal implementation we just support Win32LastModifiedTime and +// getlastmodified to set the the modification time. 
+// We ignore any other property and just return an OK response if the patch sets +// the modification time, otherwise a Forbidden response +func (f *webDavFile) Patch(patches []webdav.Proppatch) ([]webdav.Propstat, error) { + resp := make([]webdav.Propstat, 0, len(patches)) + hasError := false + for _, patch := range patches { + status := http.StatusForbidden + pstat := webdav.Propstat{} + for _, p := range patch.Props { + if status == http.StatusForbidden && !hasError { + if !patch.Remove && slices.Contains(lastModifiedProps, p.XMLName.Local) { + parsed, err := parseTime(util.BytesToString(p.InnerXML)) + if err != nil { + f.Connection.Log(logger.LevelWarn, "unsupported last modification time: %q, err: %v", + p.InnerXML, err) + hasError = true + continue + } + attrs := &common.StatAttributes{ + Flags: common.StatAttrTimes, + Atime: parsed, + Mtime: parsed, + } + if err := f.Connection.SetStat(f.GetVirtualPath(), attrs); err != nil { + f.Connection.Log(logger.LevelWarn, "unable to set modification time for %q, err :%v", + f.GetVirtualPath(), err) + hasError = true + continue + } + status = http.StatusOK + } + } + pstat.Props = append(pstat.Props, webdav.Property{XMLName: p.XMLName}) + } + pstat.Status = status + resp = append(resp, pstat) + } + return resp, nil +} + +type webDavDirLister struct { + vfs.DirLister + fs vfs.Fs + virtualDirPath string + fsDirPath string +} + +func (l *webDavDirLister) Next(limit int) ([]os.FileInfo, error) { + files, err := l.DirLister.Next(limit) + for idx := range files { + info := files[idx] + files[idx] = &webDavFileInfo{ + FileInfo: info, + Fs: l.fs, + virtualPath: path.Join(l.virtualDirPath, info.Name()), + fsPath: l.fs.Join(l.fsDirPath, info.Name()), + } + } + return files, err +} diff --git a/internal/webdavd/handler.go b/internal/webdavd/handler.go new file mode 100644 index 00000000..2fc3ca1b --- /dev/null +++ b/internal/webdavd/handler.go @@ -0,0 +1,311 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free 
software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package webdavd + +import ( + "context" + "net/http" + "os" + "path" + "strconv" + "strings" + "time" + + "github.com/drakkan/webdav" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +// Connection details for a WebDav connection. +type Connection struct { + *common.BaseConnection + request *http.Request + rc *http.ResponseController +} + +func newConnection(conn *common.BaseConnection, w http.ResponseWriter, r *http.Request) *Connection { + rc := http.NewResponseController(w) + responseControllerDeadlines(rc, time.Time{}, time.Time{}) + return &Connection{ + BaseConnection: conn, + request: r, + rc: rc, + } +} + +func (c *Connection) getModificationTime() time.Time { + if c.request == nil { + return time.Time{} + } + if val := c.request.Header.Get("X-OC-Mtime"); val != "" { + if unixTime, err := strconv.ParseInt(val, 10, 64); err == nil { + return time.Unix(unixTime, 0) + } + } + return time.Time{} +} + +// GetClientVersion returns the connected client's version. 
+func (c *Connection) GetClientVersion() string {
+	if c.request != nil {
+		return c.request.UserAgent()
+	}
+	return ""
+}
+
+// GetLocalAddress returns local connection address
+func (c *Connection) GetLocalAddress() string {
+	return util.GetHTTPLocalAddress(c.request)
+}
+
+// GetRemoteAddress returns the connected client's address
+func (c *Connection) GetRemoteAddress() string {
+	if c.request != nil {
+		return c.request.RemoteAddr
+	}
+	return ""
+}
+
+// Disconnect sets short read/write deadlines on the response controller,
+// if any, and signals any active transfers to abort
+func (c *Connection) Disconnect() error {
+	if c.rc != nil {
+		responseControllerDeadlines(c.rc, time.Now().Add(5*time.Second), time.Now().Add(5*time.Second))
+	}
+	return c.SignalTransfersAbort()
+}
+
+// GetCommand returns the request method, uppercased
+func (c *Connection) GetCommand() string {
+	if c.request != nil {
+		return strings.ToUpper(c.request.Method)
+	}
+	return ""
+}
+
+// Mkdir creates a directory using the connection filesystem
+func (c *Connection) Mkdir(_ context.Context, name string, _ os.FileMode) error {
+	c.UpdateLastActivity()
+
+	name = util.CleanPath(name)
+	return c.CreateDir(name, true)
+}
+
+// Rename renames a file or a directory
+func (c *Connection) Rename(_ context.Context, oldName, newName string) error {
+	c.UpdateLastActivity()
+
+	oldName = util.CleanPath(oldName)
+	newName = util.CleanPath(newName)
+
+	err := c.BaseConnection.Rename(oldName, newName)
+	if err == nil {
+		// preserve the client-provided modification time, if one was sent
+		if mtime := c.getModificationTime(); !mtime.IsZero() {
+			attrs := &common.StatAttributes{
+				Flags: common.StatAttrTimes,
+				Atime: mtime,
+				Mtime: mtime,
+			}
+			setStatErr := c.SetStat(newName, attrs)
+			c.Log(logger.LevelDebug, "mtime header found for %q, value: %s, err: %v", newName, mtime, setStatErr)
+		}
+	}
+	return err
+}
+
+// Stat returns a FileInfo describing the named file/directory, or an error,
+// if any happens
+func (c *Connection) Stat(_ context.Context, name string) (os.FileInfo, error) {
+	c.UpdateLastActivity()
+
+	name = util.CleanPath(name)
+	if !c.User.HasPerm(dataprovider.PermListItems, path.Dir(name)) {
+		return nil, c.GetPermissionDeniedError()
+	}
+
+	fi, err := c.DoStat(name, 0, true)
+	if err != nil {
+		return nil, err
+	}
+	return fi, err
+}
+
+// RemoveAll removes path and any children it contains.
+// If the path does not exist, RemoveAll returns nil (no error).
+func (c *Connection) RemoveAll(_ context.Context, name string) error {
+	c.UpdateLastActivity()
+
+	name = util.CleanPath(name)
+	return c.BaseConnection.RemoveAll(name)
+}
+
+// OpenFile opens the named file with specified flag.
+// This method is used for uploads and downloads but also for Stat and Readdir
+func (c *Connection) OpenFile(_ context.Context, name string, flag int, _ os.FileMode) (webdav.File, error) {
+	c.UpdateLastActivity()
+
+	if err := common.Connections.IsNewTransferAllowed(c.User.Username); err != nil {
+		c.Log(logger.LevelInfo, "denying transfer due to count limits")
+		return nil, c.GetPermissionDeniedError()
+	}
+
+	name = util.CleanPath(name)
+	fs, p, err := c.GetFsAndResolvedPath(name)
+	if err != nil {
+		return nil, err
+	}
+
+	if flag == os.O_RDONLY || c.request.Method == "PROPPATCH" {
+		// Download, Stat, Readdir or simply open/close
+		return c.getFile(fs, p, name)
+	}
+	return c.putFile(fs, p, name)
+}
+
+// getFile returns a webdav.File for download/Stat/Readdir use
+func (c *Connection) getFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File, error) {
+	var cancelFn func()
+
+	// we open the file when we receive the first read so we only open the file if necessary
+	baseTransfer := common.NewBaseTransfer(nil, c.BaseConnection, cancelFn, fsPath, fsPath, virtualPath,
+		common.TransferDownload, 0, 0, 0, 0, false, fs, c.GetTransferQuota())
+
+	return newWebDavFile(baseTransfer, nil, nil), nil
+}
+
+// putFile handles a write open: it checks filename patterns and permissions,
+// then dispatches to new-file or overwrite handling based on the target state
+func (c *Connection) putFile(fs vfs.Fs, fsPath, virtualPath string) (webdav.File, error) {
+	if ok, _ := c.User.IsFileAllowed(virtualPath); !ok {
+		c.Log(logger.LevelWarn, "writing file %q is not allowed", virtualPath)
+		return nil, c.GetPermissionDeniedError()
+	}
+
+	filePath := fsPath
+	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
+		filePath = fs.GetAtomicUploadPath(fsPath)
+	}
+
+	stat, statErr := fs.Lstat(fsPath)
+	if (statErr == nil && stat.Mode()&os.ModeSymlink != 0) || fs.IsNotExist(statErr) {
+		if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(virtualPath)) {
+			return nil, c.GetPermissionDeniedError()
+		}
+		return c.handleUploadToNewFile(fs, fsPath, filePath, virtualPath)
+	}
+
+	if statErr != nil {
+		c.Log(logger.LevelError, "error performing file stat %q: %+v", fsPath, statErr)
+		return nil, c.GetFsError(fs, statErr)
+	}
+
+	// This happens if we upload a file with the same name as an existing directory
+	if stat.IsDir() {
+		c.Log(logger.LevelError, "attempted to open a directory for writing to: %q", fsPath)
+		return nil, c.GetOpUnsupportedError()
+	}
+
+	if !c.User.HasPerm(dataprovider.PermOverwrite, path.Dir(virtualPath)) {
+		return nil, c.GetPermissionDeniedError()
+	}
+
+	return c.handleUploadToExistingFile(fs, fsPath, filePath, stat.Size(), virtualPath)
+}
+
+// handleUploadToNewFile sets up an upload transfer for a target that does not
+// exist yet (or is a symlink), after quota and pre-action checks
+func (c *Connection) handleUploadToNewFile(fs vfs.Fs, resolvedPath, filePath, requestPath string) (webdav.File, error) {
+	diskQuota, transferQuota := c.HasSpace(true, false, requestPath)
+	if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
+		c.Log(logger.LevelInfo, "denying file write due to quota limits")
+		return nil, common.ErrQuotaExceeded
+	}
+	if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath, 0, 0); err != nil {
+		c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err)
+		return nil, c.GetPermissionDeniedError()
+	}
+	file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.GetCreateChecks(requestPath, true, false))
+	if err != nil {
+		c.Log(logger.LevelError, "error creating file %q: %+v", resolvedPath, err)
+		return nil, c.GetFsError(fs, err)
+	}
+
+	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
+
+	// we can get an error only for resume
+	maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, 0, fs.IsUploadResumeSupported())
+
+	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
+		common.TransferUpload, 0, 0, maxWriteSize, 0, true, fs, transferQuota)
+	mtime := c.getModificationTime()
+	baseTransfer.SetTimes(resolvedPath, mtime, mtime)
+
+	return newWebDavFile(baseTransfer, w, nil), nil
+}
+
+// handleUploadToExistingFile sets up an upload transfer that overwrites
+// (truncates) an existing file, adjusting quota accounting accordingly
+func (c *Connection) handleUploadToExistingFile(fs vfs.Fs, resolvedPath, filePath string, fileSize int64,
+	requestPath string,
+) (webdav.File, error) {
+	var err error
+	diskQuota, transferQuota := c.HasSpace(false, false, requestPath)
+	if !diskQuota.HasSpace || !transferQuota.HasUploadSpace() {
+		c.Log(logger.LevelInfo, "denying file write due to quota limits")
+		return nil, common.ErrQuotaExceeded
+	}
+	if _, err := common.ExecutePreAction(c.BaseConnection, common.OperationPreUpload, resolvedPath, requestPath,
+		fileSize, os.O_TRUNC); err != nil {
+		c.Log(logger.LevelDebug, "upload for file %q denied by pre action: %v", requestPath, err)
+		return nil, c.GetPermissionDeniedError()
+	}
+
+	// if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace
+	// will return false in this case and we deny the upload before
+	maxWriteSize, _ := c.GetMaxWriteSize(diskQuota, false, fileSize, fs.IsUploadResumeSupported())
+
+	if common.Config.IsAtomicUploadEnabled() && fs.IsAtomicUploadSupported() {
+		_, _, err = fs.Rename(resolvedPath, filePath, 0)
+		if err != nil {
+			c.Log(logger.LevelError, "error renaming existing file for atomic upload, source: %q, dest: %q, err: %+v",
+				resolvedPath, filePath, err)
+			return nil, c.GetFsError(fs, err)
+		}
+	}
+
+	file, w, cancelFn, err := fs.Create(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, c.GetCreateChecks(requestPath, false, false))
+	if err != nil {
+		c.Log(logger.LevelError, "error creating file %q: %+v", resolvedPath, err)
+		return nil, c.GetFsError(fs, err)
+	}
+	initialSize := int64(0)
+	truncatedSize := int64(0) // bytes truncated and not included in quota
+	if vfs.HasTruncateSupport(fs) {
+		vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath))
+		if err == nil {
+			dataprovider.UpdateUserFolderQuota(&vfolder, &c.User, 0, -fileSize, false)
+		} else {
+			dataprovider.UpdateUserQuota(&c.User, 0, -fileSize, false) //nolint:errcheck
+		}
+	} else {
+		initialSize = fileSize
+		truncatedSize = fileSize
+	}
+
+	vfs.SetPathPermissions(fs, filePath, c.User.GetUID(), c.User.GetGID())
+
+	baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, filePath, requestPath,
+		common.TransferUpload, 0, initialSize, maxWriteSize, truncatedSize, false, fs, transferQuota)
+	mtime := c.getModificationTime()
+	baseTransfer.SetTimes(resolvedPath, mtime, mtime)
+
+	return newWebDavFile(baseTransfer, w, nil), nil
+}
diff --git a/internal/webdavd/internal_test.go b/internal/webdavd/internal_test.go
new file mode 100644
index 00000000..35f71035
--- /dev/null
+++ b/internal/webdavd/internal_test.go
@@ -0,0 +1,1808 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +package webdavd + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/drakkan/webdav" + "github.com/eikenb/pipeat" + "github.com/sftpgo/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" +) + +const ( + testFile = "test_dav_file" + webDavCert = `-----BEGIN CERTIFICATE----- +MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw +RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu +dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw +OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD +VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA +IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA +NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM +3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME +GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG +SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY +/8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI +dV4vKmHUzwK/eIx+8Ay3neE= +-----END CERTIFICATE-----` + webDavKey = `-----BEGIN EC PARAMETERS----- +BgUrgQQAIg== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 +UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq +WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV +CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= +-----END EC PRIVATE KEY-----` + caCRT = `-----BEGIN CERTIFICATE----- +MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 
+QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT +CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m +fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 +v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh +yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL +CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U +4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb +KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo +NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb +S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i +M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY +/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy +OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD +AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu +XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F +sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki +pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE +jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX +y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy +WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV +oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV +L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 +DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E +eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli +SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w +kK8u1YiiVenmoQ== +-----END CERTIFICATE-----` + caKey = `-----BEGIN RSA PRIVATE KEY----- +MIIJKgIBAAKCAgEA7WHW216mfi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7x +b64rkpdzx1aWetSiCrEyc3D1v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvD +OBUYZgtMqHZzpE6xRrqQ84zhyzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKz 
+n/2uEVt33qmO85WtN3RzbSqLCdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj +7B5P5MeamkkogwbExUjdHp3U4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZ +De67V/Q8iB2May1k7zBz1ZtbKF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmOb +cn8AIfH6smLQrn0C3cs7CYfoNlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLq +R/BJTbbyXUB0imne1u00fuzbS7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb7 +8x7ivdyXSF5LVQJ1JvhhWu6iM6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpB +P8d2jcRZVUVrSXGc2mAGuGOY/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2Z +xugCXULtRWJ9p4C9zUl40HEyOQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEA +AQKCAgEA4x0OoceG54ZrVxifqVaQd8qw3uRmUKUMIMdfuMlsdideeLO97ynmSlRY +00kGo/I4Lp6mNEjI9gUie9+uBrcUhri4YLcujHCH+YlNnCBDbGjwbe0ds9SLCWaa +KztZHMSlW5Q4Bqytgu+MpOnxSgqjlOk+vz9TcGFKVnUkHIkAcqKFJX8gOFxPZA/t +Ob1kJaz4kuv5W2Kur/ISKvQtvFvOtQeV0aJyZm8LqXnvS4cPI7yN4329NDU0HyDR +y/deqS2aqV4zII3FFqbz8zix/m1xtVQzWCugZGMKrz0iuJMfNeCABb8rRGc6GsZz ++465v/kobqgeyyneJ1s5rMFrLp2o+dwmnIVMNsFDUiN1lIZDHLvlgonaUO3IdTZc +9asamFWKFKUMgWqM4zB1vmUO12CKowLNIIKb0L+kf1ixaLLDRGf/f9vLtSHE+oyx +lATiS18VNA8+CGsHF6uXMRwf2auZdRI9+s6AAeyRISSbO1khyWKHo+bpOvmPAkDR +nknTjbYgkoZOV+mrsU5oxV8s6vMkuvA3rwFhT2gie8pokuACFcCRrZi9MVs4LmUQ +u0GYTHvp2WJUjMWBm6XX7Hk3g2HV842qpk/mdtTjNsXws81djtJPn4I/soIXSgXz +pY3SvKTuOckP9OZVF0yqKGeZXKpD288PKpC+MAg3GvEJaednagECggEBAPsfLwuP +L1kiDjXyMcRoKlrQ6Q/zBGyBmJbZ5uVGa02+XtYtDAzLoVupPESXL0E7+r8ZpZ39 +0dV4CEJKpbVS/BBtTEkPpTK5kz778Ib04TAyj+YLhsZjsnuja3T5bIBZXFDeDVDM +0ZaoFoKpIjTu2aO6pzngsgXs6EYbo2MTuJD3h0nkGZsICL7xvT9Mw0P1p2Ftt/hN ++jKk3vN220wTWUsq43AePi45VwK+PNP12ZXv9HpWDxlPo3j0nXtgYXittYNAT92u +BZbFAzldEIX9WKKZgsWtIzLaASjVRntpxDCTby/nlzQ5dw3DHU1DV3PIqxZS2+Oe +KV+7XFWgZ44YjYECggEBAPH+VDu3QSrqSahkZLkgBtGRkiZPkZFXYvU6kL8qf5wO +Z/uXMeqHtznAupLea8I4YZLfQim/NfC0v1cAcFa9Ckt9g3GwTSirVcN0AC1iOyv3 +/hMZCA1zIyIcuUplNr8qewoX71uPOvCNH0dix77423mKFkJmNwzy4Q+rV+qkRdLn +v+AAgh7g5N91pxNd6LQJjoyfi1Ka6rRP2yGXM5v7QOwD16eN4JmExUxX1YQ7uNuX +pVS+HRxnBquA+3/DB1LtBX6pa2cUa+LRUmE/NCPHMvJcyuNkYpJKlNTd9vnbfo0H +RNSJSWm+aGxDFMjuPjV3JLj2OdKMPwpnXdh2vBZCPpMCggEAM+yTvrEhmi2HgLIO 
+hkz/jP2rYyfdn04ArhhqPLgd0dpuI5z24+Jq/9fzZT9ZfwSW6VK1QwDLlXcXRhXH +Q8Hf6smev3CjuORURO61IkKaGWwrAucZPAY7ToNQ4cP9ImDXzMTNPgrLv3oMBYJR +V16X09nxX+9NABqnQG/QjdjzDc6Qw7+NZ9f2bvzvI5qMuY2eyW91XbtJ45ThoLfP +ymAp03gPxQwL0WT7z85kJ3OrROxzwaPvxU0JQSZbNbqNDPXmFTiECxNDhpRAAWlz +1DC5Vg2l05fkMkyPdtD6nOQWs/CYSfB5/EtxiX/xnBszhvZUIe6KFvuKFIhaJD5h +iykagQKCAQEAoBRm8k3KbTIo4ZzvyEq4V/+dF3zBRczx6FkCkYLygXBCNvsQiR2Y +BjtI8Ijz7bnQShEoOmeDriRTAqGGrspEuiVgQ1+l2wZkKHRe/aaij/Zv+4AuhH8q +uZEYvW7w5Uqbs9SbgQzhp2kjTNy6V8lVnjPLf8cQGZ+9Y9krwktC6T5m/i435WdN +38h7amNP4XEE/F86Eb3rDrZYtgLIoCF4E+iCyxMehU+AGH1uABhls9XAB6vvo+8/ +SUp8lEqWWLP0U5KNOtYWfCeOAEiIHDbUq+DYUc4BKtbtV1cx3pzlPTOWw6XBi5Lq +jttdL4HyYvnasAQpwe8GcMJqIRyCVZMiwwKCAQEAhQTTS3CC8PwcoYrpBdTjW1ck +vVFeF1YbfqPZfYxASCOtdx6wRnnEJ+bjqntagns9e88muxj9UhxSL6q9XaXQBD8+ +2AmKUxphCZQiYFZcTucjQEQEI2nN+nAKgRrUSMMGiR8Ekc2iFrcxBU0dnSohw+aB +PbMKVypQCREu9PcDFIp9rXQTeElbaNsIg1C1w/SQjODbmN/QFHTVbRODYqLeX1J/ +VcGsykSIq7hv6bjn7JGkr2JTdANbjk9LnMjMdJFsKRYxPKkOQfYred6Hiojp5Sor +PW5am8ejnNSPhIfqQp3uV3KhwPDKIeIpzvrB4uPfTjQWhekHCb8cKSWux3flqw== +-----END RSA PRIVATE KEY-----` + caCRL = `-----BEGIN X509 CRL----- +MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN +MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg +ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r +rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV ++jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 +XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE +jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF +DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD +MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 +iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE +ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 +hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 +cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 
+sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn +w7bXJGfJcIMrrKs= +-----END X509 CRL-----` + client1Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 +HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE +OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 +Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf +XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO +c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo +Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU +OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC +sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I +l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb +BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz +MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 +QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz +J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ +offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX +G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr +kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au +MU3Bo0A= +-----END CERTIFICATE-----` + client1Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki +J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi +85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW +eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 
+ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy +bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO +hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw +eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl +X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ +CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL +j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU +NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf +sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB +Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB +LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g +mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 +T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn +RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa +Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW +asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ +DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 +0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls +ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY +nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe +yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== +-----END RSA PRIVATE KEY-----` + // client 2 crt is revoked + client2Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b +Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua +eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ +cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT 
+A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 +6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f +Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz +PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY +xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D +EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB +4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO +78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 +7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e +qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N +f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv +/ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S +ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR +E0+r2+4= +-----END CERTIFICATE-----` + client2Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV +jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP +kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK +h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P +QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu +4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq +op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM +qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt +Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL +VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO +qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy +VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl +NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 +4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl 
+/owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd +aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu +khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz +3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ +eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i +vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB +GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb +Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D +8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl +RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA +ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl +-----END RSA PRIVATE KEY-----` + osWindows = "windows" +) + +// MockOsFs mockable OsFs +type MockOsFs struct { + vfs.Fs + err error + isAtomicUploadSupported bool + reader *pipeat.PipeReaderAt +} + +// Name returns the name for the Fs implementation +func (fs *MockOsFs) Name() string { + return "mockOsFs" +} + +// Open returns nil +func (fs *MockOsFs) Open(name string, offset int64) (vfs.File, vfs.PipeReader, func(), error) { + if fs.reader != nil { + return nil, vfs.NewPipeReader(fs.reader), nil, nil + } + return fs.Fs.Open(name, offset) +} + +// IsUploadResumeSupported returns true if resuming uploads is supported +func (*MockOsFs) IsUploadResumeSupported() bool { + return false +} + +// IsAtomicUploadSupported returns true if atomic upload is supported +func (fs *MockOsFs) IsAtomicUploadSupported() bool { + return fs.isAtomicUploadSupported +} + +// Remove removes the named file or (empty) directory. 
+func (fs *MockOsFs) Remove(name string, _ bool) error { + if fs.err != nil { + return fs.err + } + return os.Remove(name) +} + +// Rename renames (moves) source to target +func (fs *MockOsFs) Rename(source, target string, _ int) (int, int64, error) { + err := os.Rename(source, target) + return -1, -1, err +} + +// GetMimeType returns the content type +func (fs *MockOsFs) GetMimeType(_ string) (string, error) { + if fs.err != nil { + return "", fs.err + } + return "application/custom-mime", nil +} + +func newMockOsFs(atomicUpload bool, connectionID, rootDir string, reader *pipeat.PipeReaderAt, err error) vfs.Fs { + return &MockOsFs{ + Fs: vfs.NewOsFs(connectionID, rootDir, "", nil), + isAtomicUploadSupported: atomicUpload, + reader: reader, + err: err, + } +} + +func TestUserInvalidParams(t *testing.T) { + u := &dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: "username", + HomeDir: "invalid", + }, + } + c := &Configuration{ + Bindings: []Binding{ + { + Port: 9000, + }, + }, + } + + server := webDavServer{ + config: c, + binding: c.Bindings[0], + } + + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", u.Username), nil) + assert.NoError(t, err) + + _, err = server.validateUser(u, req, dataprovider.LoginMethodPassword) + if assert.Error(t, err) { + assert.EqualError(t, err, fmt.Sprintf("cannot login user with invalid home dir: %q", u.HomeDir)) + } + + req.TLS = &tls.ConnectionState{} + writeLog(req, http.StatusOK, nil) +} + +func TestAllowedProxyUnixDomainSocket(t *testing.T) { + b := Binding{ + Address: filepath.Join(os.TempDir(), "sock"), + ProxyAllowed: []string{"127.0.0.1", "127.0.1.1"}, + } + err := b.parseAllowedProxy() + assert.NoError(t, err) + if assert.Len(t, b.allowHeadersFrom, 1) { + assert.True(t, b.allowHeadersFrom[0](nil)) + } +} + +func TestProxyListenerWrapper(t *testing.T) { + b := Binding{ + ProxyMode: 0, + } + require.Nil(t, b.listenerWrapper()) + b.ProxyMode = 1 + require.NotNil(t, b.listenerWrapper()) +} + +func 
TestRemoteAddress(t *testing.T) {
+	remoteAddr1 := "100.100.100.100"
+	remoteAddr2 := "172.172.172.172"
+
+	c := &Configuration{
+		Bindings: []Binding{
+			{
+				Port:         9000,
+				ProxyAllowed: []string{remoteAddr2, "10.8.0.0/30"},
+			},
+		},
+	}
+
+	server := webDavServer{
+		config:  c,
+		binding: c.Bindings[0],
+	}
+	err := server.binding.parseAllowedProxy()
+	assert.NoError(t, err)
+	req, err := http.NewRequest(http.MethodGet, "/", nil)
+	assert.NoError(t, err)
+	assert.Empty(t, req.RemoteAddr)
+
+	trueClientIP := "True-Client-IP"
+	cfConnectingIP := "CF-Connecting-IP"
+	xff := "X-Forwarded-For"
+	xRealIP := "X-Real-IP"
+
+	req.Header.Set(trueClientIP, remoteAddr1)
+	ip := util.GetRealIP(req, trueClientIP, 0)
+	assert.Equal(t, remoteAddr1, ip)
+	ip = util.GetRealIP(req, trueClientIP, 2)
+	assert.Empty(t, ip)
+	req.Header.Del(trueClientIP)
+	req.Header.Set(cfConnectingIP, remoteAddr1)
+	ip = util.GetRealIP(req, cfConnectingIP, 0)
+	assert.Equal(t, remoteAddr1, ip)
+	req.Header.Del(cfConnectingIP)
+	req.Header.Set(xff, remoteAddr1)
+	ip = util.GetRealIP(req, xff, 0)
+	assert.Equal(t, remoteAddr1, ip)
+	// this will be ignored, remoteAddr1 is not allowed to set this header
+	req.Header.Set(xff, remoteAddr2)
+	req.RemoteAddr = remoteAddr1
+	ip = server.checkRemoteAddress(req)
+	assert.Equal(t, remoteAddr1, ip)
+	req.RemoteAddr = ""
+	ip = server.checkRemoteAddress(req)
+	assert.Empty(t, ip)
+
+	req.Header.Set(xff, fmt.Sprintf("%v , %v", remoteAddr2, remoteAddr1))
+	ip = util.GetRealIP(req, xff, 1)
+	assert.Equal(t, remoteAddr2, ip)
+
+	req.RemoteAddr = remoteAddr2
+	req.Header.Set(xff, fmt.Sprintf("%v,%v", "12.34.56.78", "172.16.2.4"))
+	server.binding.ClientIPHeaderDepth = 1
+	server.binding.ClientIPProxyHeader = xff
+	ip = server.checkRemoteAddress(req)
+	assert.Equal(t, "12.34.56.78", ip)
+	assert.Equal(t, ip, req.RemoteAddr)
+
+	req.RemoteAddr = remoteAddr2
+	req.Header.Set(xff, fmt.Sprintf("%v,%v", "12.34.56.79", "172.16.2.5"))
+	server.binding.ClientIPHeaderDepth = 0
+	
ip = server.checkRemoteAddress(req) + assert.Equal(t, "172.16.2.5", ip) + assert.Equal(t, ip, req.RemoteAddr) + + req.RemoteAddr = "10.8.0.2" + req.Header.Set(xff, remoteAddr1) + ip = server.checkRemoteAddress(req) + assert.Equal(t, remoteAddr1, ip) + assert.Equal(t, ip, req.RemoteAddr) + + req.RemoteAddr = "10.8.0.3" + req.Header.Set(xff, "not an ip") + ip = server.checkRemoteAddress(req) + assert.Equal(t, "10.8.0.3", ip) + assert.Equal(t, ip, req.RemoteAddr) + + req.Header.Del(xff) + req.RemoteAddr = "" + req.Header.Set(xRealIP, remoteAddr1) + ip = util.GetRealIP(req, "x-real-ip", 0) + assert.Equal(t, remoteAddr1, ip) + req.RemoteAddr = "" +} + +func TestConnWithNilRequest(t *testing.T) { + c := &Connection{} + assert.Empty(t, c.GetClientVersion()) + assert.Empty(t, c.GetCommand()) + assert.Empty(t, c.GetRemoteAddress()) + assert.True(t, c.getModificationTime().IsZero()) +} + +func TestResolvePathErrors(t *testing.T) { + ctx := context.Background() + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: "invalid", + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), + } + + err := connection.Mkdir(ctx, "", os.ModePerm) + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + + err = connection.Rename(ctx, "oldName", "newName") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + + _, err = connection.Stat(ctx, "name") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + + err = connection.RemoveAll(ctx, "") + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + + _, err = connection.OpenFile(ctx, "", 0, os.ModePerm) + if assert.Error(t, err) { + 
assert.EqualError(t, err, common.ErrGenericFailure.Error()) + } + + if runtime.GOOS != osWindows { + user.HomeDir = filepath.Clean(os.TempDir()) + connection.User = user + fs := vfs.NewOsFs("connID", connection.User.HomeDir, "", nil) + subDir := "sub" + testTxtFile := "file.txt" + err = os.MkdirAll(filepath.Join(os.TempDir(), subDir, subDir), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(os.TempDir(), subDir, subDir, testTxtFile), []byte("content"), os.ModePerm) + assert.NoError(t, err) + err = os.Chmod(filepath.Join(os.TempDir(), subDir, subDir), 0001) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(os.TempDir(), testTxtFile), []byte("test content"), os.ModePerm) + assert.NoError(t, err) + err = connection.Rename(ctx, testTxtFile, path.Join(subDir, subDir, testTxtFile)) + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrPermissionDenied.Error()) + } + _, err = connection.putFile(fs, filepath.Join(connection.User.HomeDir, subDir, subDir, testTxtFile), + path.Join(subDir, subDir, testTxtFile)) + if assert.Error(t, err) { + assert.EqualError(t, err, common.ErrPermissionDenied.Error()) + } + err = os.Chmod(filepath.Join(os.TempDir(), subDir, subDir), os.ModePerm) + assert.NoError(t, err) + err = os.RemoveAll(filepath.Join(os.TempDir(), subDir)) + assert.NoError(t, err) + err = os.Remove(filepath.Join(os.TempDir(), testTxtFile)) + assert.NoError(t, err) + } +} + +func TestFileAccessErrors(t *testing.T) { + ctx := context.Background() + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: filepath.Clean(os.TempDir()), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), + } + missingPath := "missing path" + fsMissingPath := filepath.Join(user.HomeDir, missingPath) + 
err := connection.RemoveAll(ctx, missingPath) + assert.ErrorIs(t, err, os.ErrNotExist) + davFile, err := connection.getFile(fs, fsMissingPath, missingPath) + assert.NoError(t, err) + buf := make([]byte, 64) + _, err = davFile.Read(buf) + assert.ErrorIs(t, err, os.ErrNotExist) + err = davFile.Close() + assert.ErrorIs(t, err, os.ErrNotExist) + p := filepath.Join(user.HomeDir, "adir", missingPath) + _, err = connection.handleUploadToNewFile(fs, p, p, path.Join("adir", missingPath)) + assert.ErrorIs(t, err, os.ErrNotExist) + _, err = connection.handleUploadToExistingFile(fs, p, "_"+p, 0, path.Join("adir", missingPath)) + if assert.Error(t, err) { + assert.ErrorIs(t, err, os.ErrNotExist) + } + + fs = newMockOsFs(false, fs.ConnectionID(), user.HomeDir, nil, nil) + _, err = connection.handleUploadToExistingFile(fs, p, p, 0, path.Join("adir", missingPath)) + assert.ErrorIs(t, err, os.ErrNotExist) + + f, err := os.CreateTemp("", "temp") + assert.NoError(t, err) + err = f.Close() + assert.NoError(t, err) + davFile, err = connection.handleUploadToExistingFile(fs, f.Name(), f.Name(), 123, f.Name()) + if assert.NoError(t, err) { + transfer := davFile.(*webDavFile) + transfers := connection.GetTransfers() + if assert.Equal(t, 1, len(transfers)) { + assert.Equal(t, transfers[0].ID, transfer.GetID()) + assert.Equal(t, int64(123), transfer.InitialSize) + err = transfer.Close() + assert.NoError(t, err) + assert.Equal(t, 0, len(connection.GetTransfers())) + } + // test PROPPATCH date parsing error + pstats, err := transfer.Patch([]webdav.Proppatch{ + { + Props: []webdav.Property{ + { + XMLName: xml.Name{ + Space: "DAV", + Local: "getlastmodified", + }, + InnerXML: []byte(`Wid, 04 Nov 2020 13:25:51 GMT`), + }, + }, + }, + }) + assert.NoError(t, err) + for _, pstat := range pstats { + assert.Equal(t, http.StatusForbidden, pstat.Status) + } + + err = os.Remove(f.Name()) + assert.NoError(t, err) + // the file is deleted PROPPATCH should fail + pstats, err = 
transfer.Patch([]webdav.Proppatch{ + { + Props: []webdav.Property{ + { + XMLName: xml.Name{ + Space: "DAV", + Local: "getlastmodified", + }, + InnerXML: []byte(`Wed, 04 Nov 2020 13:25:51 GMT`), + }, + }, + }, + }) + assert.NoError(t, err) + for _, pstat := range pstats { + assert.Equal(t, http.StatusForbidden, pstat.Status) + } + } +} + +func TestCheckRequestMethodWithPrefix(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: filepath.Clean(os.TempDir()), + Permissions: map[string][]string{ + "/": {dataprovider.PermAny}, + }, + }, + } + fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), + } + server := webDavServer{ + binding: Binding{ + Prefix: "/dav", + }, + } + req, err := http.NewRequest(http.MethodGet, "/../dav", nil) + require.NoError(t, err) + server.checkRequestMethod(context.Background(), req, connection) + require.Equal(t, "PROPFIND", req.Method) + require.Equal(t, "1", req.Header.Get("Depth")) +} + +func TestContentType(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: filepath.Clean(os.TempDir()), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), + } + testFilePath := filepath.Join(user.HomeDir, testFile) + ctx := context.Background() + baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".unknown", + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + fs = newMockOsFs(false, fs.ConnectionID(), user.GetHomeDir(), nil, nil) + err := os.WriteFile(testFilePath, []byte(""), os.ModePerm) + assert.NoError(t, err) + davFile := 
newWebDavFile(baseTransfer, nil, nil) + davFile.Fs = fs + fi, err := davFile.Stat() + if assert.NoError(t, err) { + ctype, err := fi.(*webDavFileInfo).ContentType(ctx) + assert.NoError(t, err) + assert.Equal(t, "application/custom-mime", ctype) + } + _, err = davFile.Readdir(-1) + assert.ErrorIs(t, err, webdav.ErrNotImplemented) + _, err = davFile.ReadDir() + assert.Error(t, err) + err = davFile.Close() + assert.NoError(t, err) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".unknown1", + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.Fs = vfs.NewOsFs("id", user.HomeDir, "", nil) + fi, err = davFile.Stat() + if assert.NoError(t, err) { + ctype, err := fi.(*webDavFileInfo).ContentType(ctx) + assert.NoError(t, err) + assert.Equal(t, "text/plain; charset=utf-8", ctype) + } + err = davFile.Close() + assert.NoError(t, err) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.Fs = vfs.NewOsFs("id", user.HomeDir, "", nil) + fi, err = davFile.Stat() + if assert.NoError(t, err) { + ctype, err := fi.(*webDavFileInfo).ContentType(ctx) + assert.NoError(t, err) + assert.Equal(t, "application/octet-stream", ctype) + } + err = davFile.Close() + assert.NoError(t, err) + + for i := 0; i < 2; i++ { + // the second time the cache will be used + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".custom", + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.Fs = vfs.NewOsFs("id", user.HomeDir, "", nil) + fi, err = davFile.Stat() + if assert.NoError(t, err) { + ctype, err := 
fi.(*webDavFileInfo).ContentType(ctx) + assert.NoError(t, err) + assert.Equal(t, "text/plain; charset=utf-8", ctype) + } + err = davFile.Close() + assert.NoError(t, err) + } + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".sftpgo", + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + fs = newMockOsFs(false, fs.ConnectionID(), user.GetHomeDir(), nil, os.ErrInvalid) + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.Fs = fs + fi, err = davFile.Stat() + if assert.NoError(t, err) { + ctype, err := fi.(*webDavFileInfo).ContentType(ctx) + assert.NoError(t, err) + assert.Equal(t, "application/sftpgo", ctype) + } + err = davFile.Close() + assert.NoError(t, err) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile+".unknown2", + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + fs = newMockOsFs(false, fs.ConnectionID(), user.GetHomeDir(), nil, os.ErrInvalid) + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.Fs = fs + fi, err = davFile.Stat() + if assert.NoError(t, err) { + ctype, err := fi.(*webDavFileInfo).ContentType(ctx) + assert.EqualError(t, err, webdav.ErrNotImplemented.Error(), "unexpected content type %q", ctype) + } + cache := mimeCache{ + maxSize: 10, + mimeTypes: map[string]string{}, + } + cache.addMimeToCache("", "") + cache.RLock() + assert.Len(t, cache.mimeTypes, 0) + cache.RUnlock() + + err = os.Remove(testFilePath) + assert.NoError(t, err) +} + +func TestTransferReadWriteErrors(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: filepath.Clean(os.TempDir()), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + fs := vfs.NewOsFs("connID", user.HomeDir, "", nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), 
common.ProtocolWebDAV, "", "", user), + } + testFilePath := filepath.Join(user.HomeDir, testFile) + baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + davFile := newWebDavFile(baseTransfer, nil, nil) + p := make([]byte, 1) + _, err := davFile.Read(p) + assert.EqualError(t, err, common.ErrOpUnsupported.Error()) + + r, w, err := pipeat.Pipe() + assert.NoError(t, err) + davFile = newWebDavFile(baseTransfer, nil, vfs.NewPipeReader(r)) + davFile.Connection.RemoveTransfer(davFile.BaseTransfer) + davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil) + davFile.Connection.RemoveTransfer(davFile.BaseTransfer) + err = r.Close() + assert.NoError(t, err) + err = w.Close() + assert.NoError(t, err) + err = davFile.BaseTransfer.Close() + assert.Error(t, err) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + davFile = newWebDavFile(baseTransfer, nil, nil) + _, err = davFile.Read(p) + assert.True(t, fs.IsNotExist(err)) + _, err = davFile.Stat() + assert.True(t, fs.IsNotExist(err)) + err = davFile.Close() + assert.Error(t, err) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + err = os.WriteFile(testFilePath, []byte(""), os.ModePerm) + assert.NoError(t, err) + f, err := os.Open(testFilePath) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + } + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.reader = f + err = davFile.Close() + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + err = davFile.Close() + assert.EqualError(t, err, common.ErrTransferClosed.Error()) + _, err = davFile.Read(p) + 
assert.Error(t, err) + info, err := davFile.Stat() + if assert.NoError(t, err) { + assert.Equal(t, int64(0), info.Size()) + } + err = davFile.Close() + assert.Error(t, err) + + r, w, err = pipeat.Pipe() + assert.NoError(t, err) + mockFs := newMockOsFs(false, fs.ConnectionID(), user.HomeDir, r, nil) + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferDownload, 0, 0, 0, 0, false, mockFs, dataprovider.TransferQuota{}) + davFile = newWebDavFile(baseTransfer, nil, nil) + + writeContent := []byte("content\r\n") + go func() { + n, err := w.Write(writeContent) + assert.NoError(t, err) + assert.Equal(t, len(writeContent), n) + err = w.Close() + assert.NoError(t, err) + }() + + p = make([]byte, 64) + n, err := davFile.Read(p) + assert.EqualError(t, err, io.EOF.Error()) + assert.Equal(t, len(writeContent), n) + err = davFile.Close() + assert.NoError(t, err) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{}) + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.writer = f + err = davFile.Close() + assert.EqualError(t, err, common.ErrGenericFailure.Error()) + + err = os.Remove(testFilePath) + assert.NoError(t, err) +} + +func TestTransferSeek(t *testing.T) { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: filepath.Clean(os.TempDir()), + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = []string{dataprovider.PermAny} + fs := newMockOsFs(true, "connID", user.HomeDir, nil, nil) + connection := &Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), + } + testFilePath := filepath.Join(user.HomeDir, testFile) + testFileContents := []byte("content") + baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, 
testFile, + common.TransferUpload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) + davFile := newWebDavFile(baseTransfer, nil, nil) + _, err := davFile.Seek(0, io.SeekStart) + assert.EqualError(t, err, common.ErrOpUnsupported.Error()) + err = davFile.Close() + assert.NoError(t, err) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) + davFile = newWebDavFile(baseTransfer, nil, nil) + _, err = davFile.Seek(0, io.SeekCurrent) + assert.True(t, fs.IsNotExist(err)) + davFile.Connection.RemoveTransfer(davFile.BaseTransfer) + + err = os.WriteFile(testFilePath, testFileContents, os.ModePerm) + assert.NoError(t, err) + f, err := os.Open(testFilePath) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + } + baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) + davFile = newWebDavFile(baseTransfer, nil, nil) + _, err = davFile.Seek(0, io.SeekStart) + assert.Error(t, err) + davFile.Connection.RemoveTransfer(davFile.BaseTransfer) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) + davFile = newWebDavFile(baseTransfer, nil, nil) + res, err := davFile.Seek(0, io.SeekStart) + assert.NoError(t, err) + assert.Equal(t, int64(0), res) + err = davFile.Close() + assert.NoError(t, err) + davFile.Connection.RemoveTransfer(davFile.BaseTransfer) + + davFile = newWebDavFile(baseTransfer, nil, nil) + res, err = davFile.Seek(0, io.SeekEnd) + assert.NoError(t, err) + assert.Equal(t, int64(len(testFileContents)), res) + err = davFile.updateStatInfo() + 
assert.NoError(t, err) + err = davFile.Close() + assert.NoError(t, err) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) + davFile = newWebDavFile(baseTransfer, nil, nil) + _, err = davFile.Seek(0, io.SeekEnd) + assert.True(t, fs.IsNotExist(err)) + davFile.Connection.RemoveTransfer(davFile.BaseTransfer) + + fs = vfs.NewOsFs(fs.ConnectionID(), user.GetHomeDir(), "", nil) + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) + davFile = newWebDavFile(baseTransfer, nil, nil) + _, err = davFile.Seek(0, io.SeekEnd) + assert.True(t, fs.IsNotExist(err)) + davFile.Connection.RemoveTransfer(davFile.BaseTransfer) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFilePath, testFile, + common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.reader = f + r, _, err := pipeat.Pipe() + assert.NoError(t, err) + davFile.Fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), r, nil) + res, err = davFile.Seek(2, io.SeekStart) + assert.NoError(t, err) + assert.Equal(t, int64(2), res) + err = davFile.Close() + assert.NoError(t, err) + + r, _, err = pipeat.Pipe() + assert.NoError(t, err) + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.Fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), r, nil) + res, err = davFile.Seek(2, io.SeekEnd) + assert.NoError(t, err) + assert.Equal(t, int64(5), res) + err = davFile.Close() + assert.NoError(t, err) + + baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFilePath+"1", testFile, + 
common.TransferDownload, 0, 0, 0, 0, false, fs, dataprovider.TransferQuota{AllowedTotalSize: 100}) + + davFile = newWebDavFile(baseTransfer, nil, nil) + davFile.Fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), nil, nil) + res, err = davFile.Seek(2, io.SeekEnd) + assert.True(t, fs.IsNotExist(err)) + assert.Equal(t, int64(0), res) + err = davFile.Close() + assert.NoError(t, err) + + assert.Len(t, common.Connections.GetStats(""), 0) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + err = os.Remove(testFilePath) + assert.NoError(t, err) +} + +func TestBasicUsersCache(t *testing.T) { + username := "webdav_internal_test" + password := "pwd" + u := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: password, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + ExpirationDate: 0, + }, + } + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{dataprovider.PermAny} + err := dataprovider.AddUser(&u, "", "", "") + assert.NoError(t, err) + user, err := dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + + c := &Configuration{ + Bindings: []Binding{ + { + Port: 9000, + }, + }, + Cache: Cache{ + Users: UsersCacheConfig{ + MaxSize: 50, + ExpirationTime: 1, + }, + }, + } + dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) + server := webDavServer{ + config: c, + binding: c.Bindings[0], + } + + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil) + assert.NoError(t, err) + + ipAddr := "127.0.0.1" + + _, _, _, _, err = server.authenticate(req, ipAddr) //nolint:dogsled + assert.Error(t, err) + + now := time.Now() + req.SetBasicAuth(username, password) + _, isCached, _, loginMethod, err := server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + // now the user should be cached + cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) 
+ if assert.True(t, ok) { + assert.False(t, cachedUser.IsExpired()) + assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute))) + // authenticate must return the cached user now + authUser, isCached, _, _, err := server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.True(t, isCached) + assert.Equal(t, cachedUser.User, authUser) + } + // a wrong password must fail + req.SetBasicAuth(username, "wrong") + _, _, _, _, err = server.authenticate(req, ipAddr) //nolint:dogsled + assert.EqualError(t, err, dataprovider.ErrInvalidCredentials.Error()) + req.SetBasicAuth(username, password) + + // force cached user expiration + cachedUser.Expiration = now + dataprovider.CacheWebDAVUser(cachedUser) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + assert.True(t, cachedUser.IsExpired()) + } + // now authenticate should get the user from the data provider and update the cache + _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + assert.False(t, cachedUser.IsExpired()) + } + // cache is not invalidated after a user modification if the fs does not change + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + _, ok = dataprovider.GetCachedWebDAVUser(username) + assert.True(t, ok) + folderName := "testFolder" + f := &vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), "mapped"), + } + err = dataprovider.AddFolder(f, "", "", "") + assert.NoError(t, err) + user.VirtualFolders = append(user.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: "/vdir", + }) + + err = dataprovider.UpdateUser(&user, "", "", "") + assert.NoError(t, err) + _, ok = 
dataprovider.GetCachedWebDAVUser(username) + assert.False(t, ok) + + _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + _, ok = dataprovider.GetCachedWebDAVUser(username) + assert.True(t, ok) + // cache is invalidated after user deletion + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + _, ok = dataprovider.GetCachedWebDAVUser(username) + assert.False(t, ok) + + err = dataprovider.DeleteFolder(folderName, "", "", "") + assert.NoError(t, err) + + err = os.RemoveAll(u.GetHomeDir()) + assert.NoError(t, err) +} + +func TestCachedUserWithFolders(t *testing.T) { + username := "webdav_internal_folder_test" + password := "dav_pwd" + folderName := "test_folder" + u := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: password, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + ExpirationDate: 0, + }, + } + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{dataprovider.PermAny} + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: "/vpath", + }) + f := &vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: filepath.Join(os.TempDir(), folderName), + } + err := dataprovider.AddFolder(f, "", "", "") + assert.NoError(t, err) + err = dataprovider.AddUser(&u, "", "", "") + assert.NoError(t, err) + user, err := dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + + c := &Configuration{ + Bindings: []Binding{ + { + Port: 9000, + }, + }, + Cache: Cache{ + Users: UsersCacheConfig{ + MaxSize: 50, + ExpirationTime: 1, + }, + }, + } + dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) + server := webDavServer{ + config: c, + binding: c.Bindings[0], + } + + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", 
user.Username), nil) + assert.NoError(t, err) + + ipAddr := "127.0.0.1" + + _, _, _, _, err = server.authenticate(req, ipAddr) //nolint:dogsled + assert.Error(t, err) + + now := time.Now() + req.SetBasicAuth(username, password) + _, isCached, _, loginMethod, err := server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + // now the user should be cached + cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + assert.False(t, cachedUser.IsExpired()) + assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute))) + // authenticate must return the cached user now + authUser, isCached, _, _, err := server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.True(t, isCached) + assert.Equal(t, cachedUser.User, authUser) + } + + folder, err := dataprovider.GetFolderByName(folderName) + assert.NoError(t, err) + // updating a used folder should invalidate the cache only if the fs changed + err = dataprovider.UpdateFolder(&folder, folder.Users, folder.Groups, "", "", "") + assert.NoError(t, err) + + _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.True(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + assert.False(t, cachedUser.IsExpired()) + } + // changing the folder path should invalidate the cache + folder.MappedPath = filepath.Join(os.TempDir(), "anotherpath") + err = dataprovider.UpdateFolder(&folder, folder.Users, folder.Groups, "", "", "") + assert.NoError(t, err) + _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) + if 
assert.True(t, ok) { + assert.False(t, cachedUser.IsExpired()) + } + + err = dataprovider.DeleteFolder(folderName, "", "", "") + assert.NoError(t, err) + // removing a used folder should invalidate the cache + _, isCached, _, loginMethod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMethod) + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + assert.False(t, cachedUser.IsExpired()) + } + + err = dataprovider.DeleteUser(user.Username, "", "", "") + assert.NoError(t, err) + _, ok = dataprovider.GetCachedWebDAVUser(username) + assert.False(t, ok) + + err = os.RemoveAll(u.GetHomeDir()) + assert.NoError(t, err) + + err = os.RemoveAll(folder.MappedPath) + assert.NoError(t, err) +} + +func TestUsersCacheSizeAndExpiration(t *testing.T) { + username := "webdav_internal_test" + password := "pwd" + u := dataprovider.User{ + BaseUser: sdk.BaseUser{ + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + ExpirationDate: 0, + }, + } + u.Username = username + "1" + u.Password = password + "1" + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{dataprovider.PermAny} + err := dataprovider.AddUser(&u, "", "", "") + assert.NoError(t, err) + user1, err := dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + u.Username = username + "2" + u.Password = password + "2" + err = dataprovider.AddUser(&u, "", "", "") + assert.NoError(t, err) + user2, err := dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + u.Username = username + "3" + u.Password = password + "3" + err = dataprovider.AddUser(&u, "", "", "") + assert.NoError(t, err) + user3, err := dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + u.Username = username + "4" + u.Password = password + "4" + err = dataprovider.AddUser(&u, "", "", "") + assert.NoError(t, err) + user4, err := dataprovider.UserExists(u.Username, 
"") + assert.NoError(t, err) + + c := &Configuration{ + Bindings: []Binding{ + { + Port: 9000, + }, + }, + Cache: Cache{ + Users: UsersCacheConfig{ + MaxSize: 3, + ExpirationTime: 1, + }, + }, + } + dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) + server := webDavServer{ + config: c, + binding: c.Bindings[0], + } + + ipAddr := "127.0.1.1" + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) + assert.NoError(t, err) + req.SetBasicAuth(user1.Username, password+"1") + _, isCached, _, loginMehod, err := server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) + + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil) + assert.NoError(t, err) + req.SetBasicAuth(user2.Username, password+"2") + _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) + + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil) + assert.NoError(t, err) + req.SetBasicAuth(user3.Username, password+"3") + _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) + + // the first 3 users are now cached + _, ok := dataprovider.GetCachedWebDAVUser(user1.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) + assert.True(t, ok) + + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil) + assert.NoError(t, err) + req.SetBasicAuth(user4.Username, password+"4") + _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, 
dataprovider.LoginMethodPassword, loginMehod) + // user1, the first cached, should be removed now + _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) + assert.False(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user4.Username) + assert.True(t, ok) + + // a sleep ensures that expiration times are different + time.Sleep(20 * time.Millisecond) + // user1 logins, user2 should be removed + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) + assert.NoError(t, err) + req.SetBasicAuth(user1.Username, password+"1") + _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) + _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) + assert.False(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user4.Username) + assert.True(t, ok) + + // a sleep ensures that expiration times are different + time.Sleep(20 * time.Millisecond) + // user2 logins, user3 should be removed + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil) + assert.NoError(t, err) + req.SetBasicAuth(user2.Username, password+"2") + _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) + _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) + assert.False(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) + assert.True(t, ok) + _, ok = 
dataprovider.GetCachedWebDAVUser(user4.Username) + assert.True(t, ok) + + // a sleep ensures that expiration times are different + time.Sleep(20 * time.Millisecond) + // user3 logins, user4 should be removed + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil) + assert.NoError(t, err) + req.SetBasicAuth(user3.Username, password+"3") + _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) + _, ok = dataprovider.GetCachedWebDAVUser(user4.Username) + assert.False(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) + assert.True(t, ok) + + // now remove user1 after an update + user1.HomeDir += "_mod" + err = dataprovider.UpdateUser(&user1, "", "", "") + assert.NoError(t, err) + _, ok = dataprovider.GetCachedWebDAVUser(user1.Username) + assert.False(t, ok) + + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil) + assert.NoError(t, err) + req.SetBasicAuth(user4.Username, password+"4") + _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) + + // a sleep ensures that expiration times are different + time.Sleep(20 * time.Millisecond) + req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil) + assert.NoError(t, err) + req.SetBasicAuth(user1.Username, password+"1") + _, isCached, _, loginMehod, err = server.authenticate(req, ipAddr) + assert.NoError(t, err) + assert.False(t, isCached) + assert.Equal(t, dataprovider.LoginMethodPassword, loginMehod) + _, ok = dataprovider.GetCachedWebDAVUser(user2.Username) + assert.False(t, ok) + _, ok = 
dataprovider.GetCachedWebDAVUser(user1.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user3.Username) + assert.True(t, ok) + _, ok = dataprovider.GetCachedWebDAVUser(user4.Username) + assert.True(t, ok) + + err = dataprovider.DeleteUser(user1.Username, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteUser(user2.Username, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteUser(user3.Username, "", "", "") + assert.NoError(t, err) + err = dataprovider.DeleteUser(user4.Username, "", "", "") + assert.NoError(t, err) + + err = os.RemoveAll(u.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUserCacheIsolation(t *testing.T) { + dataprovider.InitializeWebDAVUserCache(10) + username := "webdav_internal_cache_test" + password := "dav_pwd" + u := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: username, + Password: password, + HomeDir: filepath.Join(os.TempDir(), username), + Status: 1, + ExpirationDate: 0, + }, + } + u.Permissions = make(map[string][]string) + u.Permissions["/"] = []string{dataprovider.PermAny} + err := dataprovider.AddUser(&u, "", "", "") + assert.NoError(t, err) + user, err := dataprovider.UserExists(u.Username, "") + assert.NoError(t, err) + cachedUser := &dataprovider.CachedUser{ + User: user, + Expiration: time.Now().Add(24 * time.Hour), + Password: password, + LockSystem: webdav.NewMemLS(), + } + cachedUser.User.FsConfig.S3Config.AccessSecret = kms.NewPlainSecret("test secret") + cachedUser.User.FsConfig.S3Config.SSECustomerKey = kms.NewPlainSecret("test key") + err = cachedUser.User.FsConfig.S3Config.AccessSecret.Encrypt() + assert.NoError(t, err) + err = cachedUser.User.FsConfig.S3Config.SSECustomerKey.Encrypt() + assert.NoError(t, err) + dataprovider.CacheWebDAVUser(cachedUser) + cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) + + if assert.True(t, ok) { + _, err = cachedUser.User.GetFilesystem("") + assert.NoError(t, err) + // the filesystem is now cached + } + 
cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + assert.True(t, cachedUser.User.FsConfig.S3Config.AccessSecret.IsEncrypted()) + err = cachedUser.User.FsConfig.S3Config.AccessSecret.Decrypt() + assert.NoError(t, err) + assert.True(t, cachedUser.User.FsConfig.S3Config.SSECustomerKey.IsEncrypted()) + err = cachedUser.User.FsConfig.S3Config.SSECustomerKey.Decrypt() + assert.NoError(t, err) + cachedUser.User.FsConfig.Provider = sdk.S3FilesystemProvider + _, err = cachedUser.User.GetFilesystem("") + assert.Error(t, err, "we don't have to get the previously cached filesystem!") + } + cachedUser, ok = dataprovider.GetCachedWebDAVUser(username) + if assert.True(t, ok) { + assert.Equal(t, sdk.LocalFilesystemProvider, cachedUser.User.FsConfig.Provider) + assert.False(t, cachedUser.User.FsConfig.S3Config.AccessSecret.IsEncrypted()) + assert.False(t, cachedUser.User.FsConfig.S3Config.SSECustomerKey.IsEncrypted()) + } + + err = dataprovider.DeleteUser(username, "", "", "") + assert.NoError(t, err) + _, ok = dataprovider.GetCachedWebDAVUser(username) + assert.False(t, ok) +} + +func TestRecoverer(t *testing.T) { + c := &Configuration{ + Bindings: []Binding{ + { + Port: 9000, + }, + }, + } + server := webDavServer{ + config: c, + binding: c.Bindings[0], + } + rr := httptest.NewRecorder() + server.ServeHTTP(rr, nil) + assert.Equal(t, http.StatusInternalServerError, rr.Code) +} + +func TestMimeCache(t *testing.T) { + cache := mimeCache{ + maxSize: 0, + mimeTypes: make(map[string]string), + } + cache.addMimeToCache(".zip", "application/zip") + mtype := cache.getMimeFromCache(".zip") + assert.Equal(t, "", mtype) + cache.maxSize = 1 + cache.addMimeToCache(".zip", "application/zip") + mtype = cache.getMimeFromCache(".zip") + assert.Equal(t, "application/zip", mtype) + cache.addMimeToCache(".jpg", "image/jpeg") + mtype = cache.getMimeFromCache(".jpg") + assert.Equal(t, "", mtype) +} + +func TestVerifyTLSConnection(t *testing.T) { + oldCertMgr := 
certMgr + + caCrlPath := filepath.Join(os.TempDir(), "testcrl.crt") + certPath := filepath.Join(os.TempDir(), "test.crt") + keyPath := filepath.Join(os.TempDir(), "test.key") + err := os.WriteFile(caCrlPath, []byte(caCRL), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(certPath, []byte(webDavCert), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(keyPath, []byte(webDavKey), os.ModePerm) + assert.NoError(t, err) + + keyPairs := []common.TLSKeyPair{ + { + Cert: certPath, + Key: keyPath, + ID: common.DefaultTLSKeyPaidID, + }, + } + certMgr, err = common.NewCertManager(keyPairs, "", "webdav_test") + assert.NoError(t, err) + + certMgr.SetCARevocationLists([]string{caCrlPath}) + err = certMgr.LoadCRLs() + assert.NoError(t, err) + + crt, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + x509crt, err := x509.ParseCertificate(crt.Certificate[0]) + assert.NoError(t, err) + + server := webDavServer{} + state := tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{x509crt}, + } + + err = server.verifyTLSConnection(state) + assert.Error(t, err) // no verified certification chain + + crt, err = tls.X509KeyPair([]byte(caCRT), []byte(caKey)) + assert.NoError(t, err) + + x509CAcrt, err := x509.ParseCertificate(crt.Certificate[0]) + assert.NoError(t, err) + + state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crt, x509CAcrt}) + err = server.verifyTLSConnection(state) + assert.NoError(t, err) + + crt, err = tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) + assert.NoError(t, err) + x509crtRevoked, err := x509.ParseCertificate(crt.Certificate[0]) + assert.NoError(t, err) + + state.VerifiedChains = append(state.VerifiedChains, []*x509.Certificate{x509crtRevoked, x509CAcrt}) + state.PeerCertificates = []*x509.Certificate{x509crtRevoked} + err = server.verifyTLSConnection(state) + assert.EqualError(t, err, common.ErrCrtRevoked.Error()) + + err = os.Remove(caCrlPath) + assert.NoError(t, 
err) + err = os.Remove(certPath) + assert.NoError(t, err) + err = os.Remove(keyPath) + assert.NoError(t, err) + + certMgr = oldCertMgr +} + +func TestMisc(t *testing.T) { + oldCertMgr := certMgr + + certMgr = nil + err := ReloadCertificateMgr() + assert.Nil(t, err) + val := getConfigPath("", ".") + assert.Empty(t, val) + + certMgr = oldCertMgr +} + +func TestParseTime(t *testing.T) { + res, err := parseTime("Sat, 4 Feb 2023 17:00:50 GMT") + require.NoError(t, err) + require.Equal(t, int64(1675530050), res.Unix()) + res, err = parseTime("Wed, 04 Nov 2020 13:25:51 GMT") + require.NoError(t, err) + require.Equal(t, int64(1604496351), res.Unix()) +} + +func TestConfigsFromProvider(t *testing.T) { + configDir := "." + err := dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) + c := Configuration{ + Bindings: []Binding{ + { + Port: 1234, + }, + }, + } + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + configs := dataprovider.Configs{ + ACME: &dataprovider.ACMEConfigs{ + Domain: "domain.com", + Email: "info@domain.com", + HTTP01Challenge: dataprovider.ACMEHTTP01Challenge{Port: 80}, + Protocols: 7, + }, + } + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + util.CertsBasePath = "" + // crt and key empty + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + util.CertsBasePath = filepath.Clean(os.TempDir()) + // crt not found + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + keyPairs := c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 0) + crtPath := filepath.Join(util.CertsBasePath, util.SanitizeDomain(configs.ACME.Domain)+".crt") + err = os.WriteFile(crtPath, nil, 0666) + assert.NoError(t, err) + // key not found + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + keyPairs = c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 0) + keyPath := filepath.Join(util.CertsBasePath, 
util.SanitizeDomain(configs.ACME.Domain)+".key") + err = os.WriteFile(keyPath, nil, 0666) + assert.NoError(t, err) + // acme cert used + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Equal(t, configs.ACME.Domain, c.acmeDomain) + keyPairs = c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 1) + assert.True(t, c.Bindings[0].EnableHTTPS) + // protocols does not match + configs.ACME.Protocols = 3 + err = dataprovider.UpdateConfigs(&configs, "", "", "") + assert.NoError(t, err) + c.acmeDomain = "" + err = c.loadFromProvider() + assert.NoError(t, err) + assert.Empty(t, c.acmeDomain) + keyPairs = c.getKeyPairs(configDir) + assert.Len(t, keyPairs, 0) + + err = os.Remove(crtPath) + assert.NoError(t, err) + err = os.Remove(keyPath) + assert.NoError(t, err) + util.CertsBasePath = "" + err = dataprovider.UpdateConfigs(nil, "", "", "") + assert.NoError(t, err) +} + +func TestGetCacheExpirationTime(t *testing.T) { + c := UsersCacheConfig{} + assert.True(t, c.getExpirationTime().IsZero()) + c.ExpirationTime = 1 + assert.False(t, c.getExpirationTime().IsZero()) +} + +func TestBindingGetAddress(t *testing.T) { + tests := []struct { + name string + binding Binding + want string + }{ + { + name: "IP address with port", + binding: Binding{Address: "127.0.0.1", Port: 8080}, + want: "127.0.0.1:8080", + }, + { + name: "Unix socket path (no port)", + binding: Binding{Address: "/tmp/app.sock", Port: 0}, + want: "/tmp/app.sock", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.binding.GetAddress(); got != tt.want { + t.Errorf("GetAddress() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBindingIsValid(t *testing.T) { + tests := []struct { + name string + binding Binding + want bool + }{ + { + name: "Valid: Positive port", + binding: Binding{Address: "127.0.0.1", Port: 10080}, + want: true, + }, + { + name: "Valid: Absolute path on Unix (non-Windows)", + binding: Binding{Address: "/var/run/app.sock", Port: 0}, + // This 
test outcome is dynamic based on the OS + want: runtime.GOOS != osWindows, + }, + { + name: "Invalid: Port 0 and relative path", + binding: Binding{Address: "relative/path", Port: 0}, + want: false, + }, + { + name: "Invalid: Empty address and port 0", + binding: Binding{Address: "", Port: 0}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.binding.IsValid(); got != tt.want { + t.Errorf("IsValid() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/webdavd/mimecache.go b/internal/webdavd/mimecache.go new file mode 100644 index 00000000..3fa3d2b9 --- /dev/null +++ b/internal/webdavd/mimecache.go @@ -0,0 +1,52 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . 
+
+package webdavd
+
+import "sync"
+
+// mimeCache is a bounded, concurrency-safe cache mapping file extensions
+// to mime types. A maxSize of 0 effectively disables caching.
+type mimeCache struct {
+	maxSize int
+	sync.RWMutex
+	mimeTypes map[string]string
+}
+
+var (
+	mimeTypeCache         mimeCache
+	customMimeTypeMapping map[string]string
+)
+
+// addMimeToCache stores the mime type for the given extension key.
+// Empty keys or values are ignored and nothing is added once the cache
+// already holds maxSize entries.
+func (c *mimeCache) addMimeToCache(key, value string) {
+	c.Lock()
+	defer c.Unlock()
+
+	if key == "" || value == "" || len(c.mimeTypes) >= c.maxSize {
+		return
+	}
+	c.mimeTypes[key] = value
+}
+
+// getMimeFromCache returns the cached mime type for the given extension,
+// or the empty string on a cache miss.
+func (c *mimeCache) getMimeFromCache(key string) string {
+	c.RLock()
+	defer c.RUnlock()
+
+	// a missing key yields the zero value "" which signals a miss
+	return c.mimeTypes[key]
+}
diff --git a/internal/webdavd/server.go b/internal/webdavd/server.go
new file mode 100644
index 00000000..06533320
--- /dev/null
+++ b/internal/webdavd/server.go
@@ -0,0 +1,473 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+ +package webdavd + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "log" + "net" + "net/http" + "path" + "path/filepath" + "runtime/debug" + "slices" + "strings" + "time" + + "github.com/drakkan/webdav" + "github.com/go-chi/chi/v5/middleware" + "github.com/rs/cors" + "github.com/rs/xid" + "github.com/rs/zerolog" + "github.com/sftpgo/sdk/plugin/notifier" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/metric" + "github.com/drakkan/sftpgo/v2/internal/plugin" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/version" +) + +type webDavServer struct { + config *Configuration + binding Binding +} + +func (s *webDavServer) listenAndServe(compressor *middleware.Compressor) error { + handler := compressor.Handler(s) + httpServer := &http.Server{ + ReadHeaderTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + MaxHeaderBytes: 1 << 16, // 64KB + ErrorLog: log.New(&logger.StdLoggerWrapper{Sender: logSender}, "", 0), + } + if s.config.Cors.Enabled { + c := cors.New(cors.Options{ + AllowedOrigins: util.RemoveDuplicates(s.config.Cors.AllowedOrigins, true), + AllowedMethods: util.RemoveDuplicates(s.config.Cors.AllowedMethods, true), + AllowedHeaders: util.RemoveDuplicates(s.config.Cors.AllowedHeaders, true), + ExposedHeaders: util.RemoveDuplicates(s.config.Cors.ExposedHeaders, true), + MaxAge: s.config.Cors.MaxAge, + AllowCredentials: s.config.Cors.AllowCredentials, + OptionsPassthrough: s.config.Cors.OptionsPassthrough, + OptionsSuccessStatus: s.config.Cors.OptionsSuccessStatus, + AllowPrivateNetwork: s.config.Cors.AllowPrivateNetwork, + }) + handler = c.Handler(handler) + } + httpServer.Handler = handler + if certMgr != nil && s.binding.EnableHTTPS { + serviceStatus.Bindings = append(serviceStatus.Bindings, s.binding) + certID := common.DefaultTLSKeyPaidID + 
if getConfigPath(s.binding.CertificateFile, "") != "" && getConfigPath(s.binding.CertificateKeyFile, "") != "" { + certID = s.binding.GetAddress() + } + httpServer.TLSConfig = &tls.Config{ + GetCertificate: certMgr.GetCertificateFunc(certID), + MinVersion: util.GetTLSVersion(s.binding.MinTLSVersion), + NextProtos: util.GetALPNProtocols(s.binding.Protocols), + CipherSuites: util.GetTLSCiphersFromNames(s.binding.TLSCipherSuites), + } + logger.Debug(logSender, "", "configured TLS cipher suites for binding %q: %v, certID: %v", + s.binding.GetAddress(), httpServer.TLSConfig.CipherSuites, certID) + if s.binding.isMutualTLSEnabled() { + httpServer.TLSConfig.ClientCAs = certMgr.GetRootCAs() + httpServer.TLSConfig.VerifyConnection = s.verifyTLSConnection + switch s.binding.ClientAuthType { + case 1: + httpServer.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + case 2: + httpServer.TLSConfig.ClientAuth = tls.VerifyClientCertIfGiven + } + } + return util.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, true, + s.binding.listenerWrapper(), logSender) + } + s.binding.EnableHTTPS = false + serviceStatus.Bindings = append(serviceStatus.Bindings, s.binding) + return util.HTTPListenAndServe(httpServer, s.binding.Address, s.binding.Port, false, + s.binding.listenerWrapper(), logSender) +} + +func (s *webDavServer) verifyTLSConnection(state tls.ConnectionState) error { + if certMgr != nil { + var clientCrt *x509.Certificate + var clientCrtName string + if len(state.PeerCertificates) > 0 { + clientCrt = state.PeerCertificates[0] + clientCrtName = clientCrt.Subject.String() + } + if len(state.VerifiedChains) == 0 { + if s.binding.ClientAuthType == 2 { + return nil + } + logger.Warn(logSender, "", "TLS connection cannot be verified: unable to get verification chain") + return errors.New("TLS connection cannot be verified: unable to get verification chain") + } + for _, verifiedChain := range state.VerifiedChains { + var caCrt *x509.Certificate + if 
len(verifiedChain) > 0 { + caCrt = verifiedChain[len(verifiedChain)-1] + } + if certMgr.IsRevoked(clientCrt, caCrt) { + logger.Debug(logSender, "", "tls handshake error, client certificate %q has been revoked", clientCrtName) + return common.ErrCrtRevoked + } + } + } + + return nil +} + +// returns true if we have to handle a HEAD response, for a directory, ourself +func (s *webDavServer) checkRequestMethod(ctx context.Context, r *http.Request, connection *Connection) bool { + // see RFC4918, section 9.4 + if r.Method == http.MethodGet || r.Method == http.MethodHead { + p := path.Clean(r.URL.Path) + if s.binding.Prefix != "" { + p = strings.TrimPrefix(p, s.binding.Prefix) + } + info, err := connection.Stat(ctx, p) + if err == nil && info.IsDir() { + if r.Method == http.MethodHead { + return true + } + r.Method = "PROPFIND" + if r.Header.Get("Depth") == "" { + r.Header.Add("Depth", "1") + } + } + } + return false +} + +// ServeHTTP implements the http.Handler interface +func (s *webDavServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer func() { + if r := recover(); r != nil { + logger.Error(logSender, "", "panic in ServeHTTP: %q stack trace: %v", r, string(debug.Stack())) + http.Error(w, common.ErrGenericFailure.Error(), http.StatusInternalServerError) + } + }() + + responseControllerDeadlines( + http.NewResponseController(w), + time.Now().Add(60*time.Second), + time.Now().Add(60*time.Second), + ) + w.Header().Set("Server", version.GetServerVersion("/", false)) + ipAddr := s.checkRemoteAddress(r) + + common.Connections.AddClientConnection(ipAddr) + defer common.Connections.RemoveClientConnection(ipAddr) + + if err := common.Connections.IsNewConnectionAllowed(ipAddr, common.ProtocolWebDAV); err != nil { + logger.Log(logger.LevelDebug, common.ProtocolWebDAV, "", "connection not allowed from ip %q: %v", ipAddr, err) + http.Error(w, err.Error(), http.StatusServiceUnavailable) + return + } + if common.IsBanned(ipAddr, common.ProtocolWebDAV) { + 
http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) + return + } + delay, err := common.LimitRate(common.ProtocolWebDAV, ipAddr) + if err != nil { + delay += 499999999 * time.Nanosecond + w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds())) + w.Header().Set("X-Retry-In", delay.String()) + http.Error(w, err.Error(), http.StatusTooManyRequests) + return + } + if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolWebDAV); err != nil { + http.Error(w, common.ErrConnectionDenied.Error(), http.StatusForbidden) + return + } + user, isCached, lockSystem, loginMethod, err := s.authenticate(r, ipAddr) + if err != nil { + if !s.binding.DisableWWWAuthHeader { + w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s WebDAV\"", version.GetServerVersion("_", false))) + } + http.Error(w, fmt.Sprintf("Authentication error: %v", err), http.StatusUnauthorized) + return + } + + connectionID, err := s.validateUser(&user, r, loginMethod) + if err != nil { + // remove the cached user, we have not yet validated its filesystem + dataprovider.RemoveCachedWebDAVUser(user.Username) + updateLoginMetrics(&user, ipAddr, loginMethod, err, r) + http.Error(w, err.Error(), http.StatusForbidden) + return + } + + if !isCached { + err = user.CheckFsRoot(connectionID) + } else { + _, err = user.GetFilesystemForPath("/", connectionID) + } + if err != nil { + errClose := user.CloseFs() + logger.Warn(logSender, connectionID, "unable to check fs root: %v close fs error: %v", err, errClose) + updateLoginMetrics(&user, ipAddr, loginMethod, common.ErrInternalFailure, r) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + baseConn := common.NewBaseConnection(connectionID, common.ProtocolWebDAV, util.GetHTTPLocalAddress(r), + r.RemoteAddr, user) + connection := newConnection(baseConn, w, r) + if err = common.Connections.Add(connection); err != nil { + errClose := user.CloseFs() + logger.Warn(logSender, connectionID, "unable 
add connection: %v close fs error: %v", err, errClose) + updateLoginMetrics(&user, ipAddr, loginMethod, err, r) + http.Error(w, err.Error(), http.StatusTooManyRequests) + return + } + defer common.Connections.Remove(connection.GetID()) + + updateLoginMetrics(&user, ipAddr, loginMethod, err, r) + + ctx := context.WithValue(r.Context(), requestIDKey, connectionID) + ctx = context.WithValue(ctx, requestStartKey, time.Now()) + + dataprovider.UpdateLastLogin(&user) + + if s.checkRequestMethod(ctx, r, connection) { + w.Header().Set("Content-Type", "text/xml; charset=utf-8") + w.WriteHeader(http.StatusMultiStatus) + w.Write([]byte("")) //nolint:errcheck + writeLog(r, http.StatusMultiStatus, nil) + return + } + + handler := webdav.Handler{ + Prefix: s.binding.Prefix, + FileSystem: connection, + LockSystem: lockSystem, + Logger: writeLog, + } + handler.ServeHTTP(w, r.WithContext(ctx)) +} + +func (s *webDavServer) getCredentialsAndLoginMethod(r *http.Request) (string, string, string, *x509.Certificate, bool) { + var tlsCert *x509.Certificate + loginMethod := dataprovider.LoginMethodPassword + username, password, ok := r.BasicAuth() + if s.binding.isMutualTLSEnabled() && r.TLS != nil { + if len(r.TLS.PeerCertificates) > 0 { + tlsCert = r.TLS.PeerCertificates[0] + if ok { + loginMethod = dataprovider.LoginMethodTLSCertificateAndPwd + } else { + loginMethod = dataprovider.LoginMethodTLSCertificate + username = tlsCert.Subject.CommonName + password = "" + } + ok = true + } + } + return username, password, loginMethod, tlsCert, ok +} + +func (s *webDavServer) authenticate(r *http.Request, ip string) (dataprovider.User, bool, webdav.LockSystem, string, error) { + var user dataprovider.User + var err error + username, password, loginMethod, tlsCert, ok := s.getCredentialsAndLoginMethod(r) + if !ok { + user.Username = username + return user, false, nil, loginMethod, common.ErrNoCredentials + } + cachedUser, ok := dataprovider.GetCachedWebDAVUser(username) + if ok { + if 
cachedUser.IsExpired() { + dataprovider.RemoveCachedWebDAVUser(username) + } else { + if !cachedUser.User.IsTLSVerificationEnabled() { + // for backward compatibility with 2.0.x we only check the password + tlsCert = nil + loginMethod = dataprovider.LoginMethodPassword + } + cu, u, err := dataprovider.CheckCachedUserCredentials(cachedUser, password, ip, loginMethod, common.ProtocolWebDAV, tlsCert) + if err == nil { + if cu != nil { + return cu.User, true, cu.LockSystem, loginMethod, nil + } + lockSystem := webdav.NewMemLS() + cachedUser = &dataprovider.CachedUser{ + User: *u, + Password: password, + LockSystem: lockSystem, + Expiration: s.config.Cache.Users.getExpirationTime(), + } + dataprovider.CacheWebDAVUser(cachedUser) + return cachedUser.User, false, cachedUser.LockSystem, loginMethod, nil + } + updateLoginMetrics(&cachedUser.User, ip, loginMethod, dataprovider.ErrInvalidCredentials, r) + return user, false, nil, loginMethod, dataprovider.ErrInvalidCredentials + } + } + user, loginMethod, err = dataprovider.CheckCompositeCredentials(username, password, ip, loginMethod, + common.ProtocolWebDAV, tlsCert) + if err != nil { + user.Username = username + updateLoginMetrics(&user, ip, loginMethod, err, r) + return user, false, nil, loginMethod, dataprovider.ErrInvalidCredentials + } + lockSystem := webdav.NewMemLS() + cachedUser = &dataprovider.CachedUser{ + User: user, + Password: password, + LockSystem: lockSystem, + Expiration: s.config.Cache.Users.getExpirationTime(), + } + dataprovider.CacheWebDAVUser(cachedUser) + return user, false, lockSystem, loginMethod, nil +} + +func (s *webDavServer) validateUser(user *dataprovider.User, r *http.Request, loginMethod string) (string, error) { + connID := xid.New().String() + connectionID := fmt.Sprintf("%v_%v", common.ProtocolWebDAV, connID) + + if !filepath.IsAbs(user.HomeDir) { + logger.Warn(logSender, connectionID, "user %q has an invalid home dir: %q. 
Home dir must be an absolute path, login not allowed", + user.Username, user.HomeDir) + return connID, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir) + } + if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolWebDAV) { + logger.Info(logSender, connectionID, "cannot login user %q, protocol DAV is not allowed", user.Username) + return connID, fmt.Errorf("protocol DAV is not allowed for user %q", user.Username) + } + if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolWebDAV) { + logger.Info(logSender, connectionID, "cannot login user %q, %v login method is not allowed", + user.Username, loginMethod) + return connID, fmt.Errorf("login method %v is not allowed for user %q", loginMethod, user.Username) + } + if !user.IsLoginFromAddrAllowed(r.RemoteAddr) { + logger.Info(logSender, connectionID, "cannot login user %q, remote address is not allowed: %v", + user.Username, r.RemoteAddr) + return connID, fmt.Errorf("login for user %q is not allowed from this address: %v", user.Username, r.RemoteAddr) + } + return connID, nil +} + +func (s *webDavServer) checkRemoteAddress(r *http.Request) string { + ipAddr := util.GetIPFromRemoteAddress(r.RemoteAddr) + var ip net.IP + isUnixSocket := filepath.IsAbs(s.binding.Address) + if !isUnixSocket { + ip = net.ParseIP(ipAddr) + } + if isUnixSocket || ip != nil { + for _, allow := range s.binding.allowHeadersFrom { + if allow(ip) { + parsedIP := util.GetRealIP(r, s.binding.ClientIPProxyHeader, s.binding.ClientIPHeaderDepth) + if parsedIP != "" { + ipAddr = parsedIP + r.RemoteAddr = ipAddr + } + break + } + } + } + return ipAddr +} + +func responseControllerDeadlines(rc *http.ResponseController, read, write time.Time) { + if err := rc.SetReadDeadline(read); err != nil { + logger.Error(logSender, "", "unable to set read timeout to %s: %v", read, err) + } + if err := rc.SetWriteDeadline(write); err != nil { + logger.Error(logSender, "", "unable to set write timeout to %s: %v", write, err) + } +} + 
+// writeLog emits a structured log entry for a WebDAV request. It is used
+// as the webdav.Handler logger and for responses SFTPGo writes itself.
+// Severity is derived from the response status: >=500 error, >=400 warn,
+// everything else debug.
+func writeLog(r *http.Request, status int, err error) {
+	scheme := "http"
+	cipherSuite := ""
+	if r.TLS != nil {
+		scheme = "https"
+		cipherSuite = tls.CipherSuiteName(r.TLS.CipherSuite)
+	}
+	fields := map[string]any{
+		"remote_addr":  r.RemoteAddr,
+		"proto":        r.Proto,
+		"method":       r.Method,
+		"user_agent":   r.UserAgent(),
+		"uri":          fmt.Sprintf("%s://%s%s", scheme, r.Host, r.RequestURI),
+		"cipher_suite": cipherSuite,
+	}
+	// request id and start time are stored in the context by ServeHTTP
+	if reqID, ok := r.Context().Value(requestIDKey).(string); ok {
+		fields["request_id"] = reqID
+	}
+	if reqStart, ok := r.Context().Value(requestStartKey).(time.Time); ok {
+		fields["elapsed_ms"] = time.Since(reqStart).Nanoseconds() / 1000000
+	}
+	// WebDAV specific headers, logged only when present
+	if depth := r.Header.Get("Depth"); depth != "" {
+		fields["depth"] = depth
+	}
+	if contentLength := r.Header.Get("Content-Length"); contentLength != "" {
+		fields["content_length"] = contentLength
+	}
+	if timeout := r.Header.Get("Timeout"); timeout != "" {
+		fields["timeout"] = timeout
+	}
+	if status != 0 {
+		fields["resp_status"] = status
+	}
+	var ev *zerolog.Event
+	if status >= http.StatusInternalServerError {
+		ev = logger.GetLogger().Error()
+	} else if status >= http.StatusBadRequest {
+		ev = logger.GetLogger().Warn()
+	} else {
+		ev = logger.GetLogger().Debug()
+	}
+	ev.
+		Timestamp().
+		Str("sender", logSender).
+		Fields(fields).
+		Err(err).
+		Send()
+}
+
+// updateLoginMetrics records the outcome of a login attempt: it updates
+// metrics, writes login/failure logs, notifies the plugin system and the
+// defender, and runs the post-login hook. Internal failures and missing
+// credentials are not reported as authentication failures; the login delay
+// is skipped for TLS-certificate-only authentication errors.
+func updateLoginMetrics(user *dataprovider.User, ip, loginMethod string, err error, r *http.Request) {
+	metric.AddLoginAttempt(loginMethod)
+	if err == nil {
+		logger.LoginLog(user.Username, ip, loginMethod, common.ProtocolWebDAV, "", r.UserAgent(), r.TLS != nil, "")
+		plugin.Handler.NotifyLogEvent(notifier.LogEventTypeLoginOK, common.ProtocolWebDAV, user.Username, ip, "", nil)
+		common.DelayLogin(nil)
+	} else if err != common.ErrInternalFailure && err != common.ErrNoCredentials {
+		logger.ConnectionFailedLog(user.Username, ip, loginMethod, common.ProtocolWebDAV, err.Error())
+		event := common.HostEventLoginFailed
+		logEv := notifier.LogEventTypeLoginFailed
+		if errors.Is(err, util.ErrNotFound) {
+			// unknown user: report a distinct defender/log event
+			event = common.HostEventUserNotFound
+			logEv = notifier.LogEventTypeLoginNoUser
+		}
+		common.AddDefenderEvent(ip, common.ProtocolWebDAV, event)
+		plugin.Handler.NotifyLogEvent(logEv, common.ProtocolWebDAV, user.Username, ip, "", err)
+		if loginMethod != dataprovider.LoginMethodTLSCertificate {
+			common.DelayLogin(err)
+		}
+	}
+	metric.AddLoginResult(loginMethod, err)
+	dataprovider.ExecutePostLoginHook(user, loginMethod, ip, common.ProtocolWebDAV, err)
+}
diff --git a/internal/webdavd/webdavd.go b/internal/webdavd/webdavd.go
new file mode 100644
index 00000000..36b3b33c
--- /dev/null
+++ b/internal/webdavd/webdavd.go
@@ -0,0 +1,401 @@
+// Copyright (C) 2019 Nicola Murino
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+ +// Package webdavd implements the WebDAV protocol +package webdavd + +import ( + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "runtime" + "time" + + "github.com/go-chi/chi/v5/middleware" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/util" +) + +type ctxReqParams int + +const ( + requestIDKey ctxReqParams = iota + requestStartKey +) + +const ( + logSender = "webdavd" +) + +var ( + certMgr *common.CertManager + serviceStatus ServiceStatus + timeFormats = []string{ + http.TimeFormat, + "Mon, _2 Jan 2006 15:04:05 GMT", + time.RFC850, + time.ANSIC, + } +) + +// ServiceStatus defines the service status +type ServiceStatus struct { + IsActive bool `json:"is_active"` + Bindings []Binding `json:"bindings"` +} + +// CorsConfig defines the CORS configuration +type CorsConfig struct { + AllowedOrigins []string `json:"allowed_origins" mapstructure:"allowed_origins"` + AllowedMethods []string `json:"allowed_methods" mapstructure:"allowed_methods"` + AllowedHeaders []string `json:"allowed_headers" mapstructure:"allowed_headers"` + ExposedHeaders []string `json:"exposed_headers" mapstructure:"exposed_headers"` + AllowCredentials bool `json:"allow_credentials" mapstructure:"allow_credentials"` + Enabled bool `json:"enabled" mapstructure:"enabled"` + MaxAge int `json:"max_age" mapstructure:"max_age"` + OptionsPassthrough bool `json:"options_passthrough" mapstructure:"options_passthrough"` + OptionsSuccessStatus int `json:"options_success_status" mapstructure:"options_success_status"` + AllowPrivateNetwork bool `json:"allow_private_network" mapstructure:"allow_private_network"` +} + +// CustomMimeMapping defines additional, user defined mime mappings +type CustomMimeMapping struct { + Ext string `json:"ext" mapstructure:"ext"` + Mime string `json:"mime" mapstructure:"mime"` +} + +// UsersCacheConfig defines the cache 
+// configuration for users
+type UsersCacheConfig struct {
+	// ExpirationTime is the cache TTL in minutes; values <= 0 disable expiration
+	ExpirationTime int `json:"expiration_time" mapstructure:"expiration_time"`
+	MaxSize        int `json:"max_size" mapstructure:"max_size"`
+}
+
+// getExpirationTime returns the absolute expiration instant for a cache
+// entry created now; the zero time is returned when expiration is disabled.
+func (c *UsersCacheConfig) getExpirationTime() time.Time {
+	if c.ExpirationTime > 0 {
+		return time.Now().Add(time.Duration(c.ExpirationTime) * time.Minute)
+	}
+	return time.Time{}
+}
+
+// MimeCacheConfig defines the cache configuration for mime types
+type MimeCacheConfig struct {
+	Enabled bool `json:"enabled" mapstructure:"enabled"`
+	MaxSize int  `json:"max_size" mapstructure:"max_size"`
+	// CustomMappings are user defined extension -> mime type overrides
+	CustomMappings []CustomMimeMapping `json:"custom_mappings" mapstructure:"custom_mappings"`
+}
+
+// Cache configuration
+type Cache struct {
+	Users     UsersCacheConfig `json:"users" mapstructure:"users"`
+	MimeTypes MimeCacheConfig  `json:"mime_types" mapstructure:"mime_types"`
+}
+
+// Binding defines the configuration for a network listener
+type Binding struct {
+	// The address to listen on. A blank value means listen on all available network interfaces.
+	Address string `json:"address" mapstructure:"address"`
+	// The port used for serving requests
+	Port int `json:"port" mapstructure:"port"`
+	// you also need to provide a certificate for enabling HTTPS
+	EnableHTTPS bool `json:"enable_https" mapstructure:"enable_https"`
+	// Certificate and matching private key for this specific binding, if empty the global
+	// ones will be used, if any
+	CertificateFile    string `json:"certificate_file" mapstructure:"certificate_file"`
+	CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"`
+	// Defines the minimum TLS version. 13 means TLS 1.3, default is TLS 1.2
+	MinTLSVersion int `json:"min_tls_version" mapstructure:"min_tls_version"`
+	// set to 1 to require client certificate authentication in addition to basic auth.
+ // You need to define at least a certificate authority for this to work + ClientAuthType int `json:"client_auth_type" mapstructure:"client_auth_type"` + // TLSCipherSuites is a list of supported cipher suites for TLS version 1.2. + // If CipherSuites is nil/empty, a default list of secure cipher suites + // is used, with a preference order based on hardware performance. + // Note that TLS 1.3 ciphersuites are not configurable. + // The supported ciphersuites names are defined here: + // + // https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53 + // + // any invalid name will be silently ignored. + // The order matters, the ciphers listed first will be the preferred ones. + TLSCipherSuites []string `json:"tls_cipher_suites" mapstructure:"tls_cipher_suites"` + // HTTP protocols to enable in preference order. Supported values: http/1.1, h2 + Protocols []string `json:"tls_protocols" mapstructure:"tls_protocols"` + // Prefix for WebDAV resources, if empty WebDAV resources will be available at the + // root ("/") URI. If defined it must be an absolute URI. + Prefix string `json:"prefix" mapstructure:"prefix"` + // Defines whether to use the common proxy protocol configuration or the + // binding-specific proxy header configuration. + ProxyMode int `json:"proxy_mode" mapstructure:"proxy_mode"` + // List of IP addresses and IP ranges allowed to set client IP proxy headers + ProxyAllowed []string `json:"proxy_allowed" mapstructure:"proxy_allowed"` + // Allowed client IP proxy header such as "X-Forwarded-For", "X-Real-IP" + ClientIPProxyHeader string `json:"client_ip_proxy_header" mapstructure:"client_ip_proxy_header"` + // Some client IP headers such as "X-Forwarded-For" can contain multiple IP address, this setting + // define the position to trust starting from the right. 
+	// For example if we have:
+	// "10.0.0.1,11.0.0.1,12.0.0.1,13.0.0.1" and the depth is 0, SFTPGo will use "13.0.0.1"
+	// as client IP, if depth is 1, "12.0.0.1" will be used and so on
+	ClientIPHeaderDepth int `json:"client_ip_header_depth" mapstructure:"client_ip_header_depth"`
+	// Do not add the WWW-Authenticate header after an authentication error,
+	// only the 401 status code will be sent
+	DisableWWWAuthHeader bool `json:"disable_www_auth_header" mapstructure:"disable_www_auth_header"`
+	// allowHeadersFrom holds the IP matchers built from ProxyAllowed
+	allowHeadersFrom []func(net.IP) bool
+}
+
+// parseAllowedProxy builds the allowHeadersFrom matchers from ProxyAllowed.
+// For a Unix domain socket binding (absolute path address) with a non-empty
+// ProxyAllowed list, proxy headers are trusted from any peer.
+func (b *Binding) parseAllowedProxy() error {
+	if filepath.IsAbs(b.Address) && len(b.ProxyAllowed) > 0 {
+		// unix domain socket
+		b.allowHeadersFrom = []func(net.IP) bool{func(_ net.IP) bool { return true }}
+		return nil
+	}
+	allowedFuncs, err := util.ParseAllowedIPAndRanges(b.ProxyAllowed)
+	if err != nil {
+		return err
+	}
+	b.allowHeadersFrom = allowedFuncs
+	return nil
+}
+
+// isMutualTLSEnabled returns true if client certificate authentication is
+// configured for this binding (1 = required, 2 = verified if given).
+func (b *Binding) isMutualTLSEnabled() bool {
+	return b.ClientAuthType == 1 || b.ClientAuthType == 2
+}
+
+// GetAddress returns the binding address
+func (b *Binding) GetAddress() string {
+	if b.Port > 0 {
+		return fmt.Sprintf("%s:%d", b.Address, b.Port)
+	}
+	// no port: the address is used as-is (Unix domain socket path)
+	return b.Address
+}
+
+// IsValid returns true if the binding is valid
+func (b *Binding) IsValid() bool {
+	if b.Port > 0 {
+		return true
+	}
+	// without a port, an absolute path is accepted as a Unix domain
+	// socket on non-Windows systems only
+	if filepath.IsAbs(b.Address) && runtime.GOOS != "windows" {
+		return true
+	}
+	return false
+}
+
+// listenerWrapper returns the common proxy protocol listener wrapper when
+// ProxyMode is 1, nil otherwise.
+func (b *Binding) listenerWrapper() func(net.Listener) (net.Listener, error) {
+	if b.ProxyMode == 1 {
+		return common.Config.GetProxyListener
+	}
+	return nil
+}
+
+// Configuration defines the configuration for the WebDAV server
+type Configuration struct {
+	// Addresses and ports to bind to
+	Bindings []Binding `json:"bindings" mapstructure:"bindings"`
+	// If files containing a certificate and matching private key for the server are provided you
+	// can enable HTTPS connections for the configured bindings
+	// Certificate and key files can be reloaded
on demand sending a "SIGHUP" signal on Unix based systems and a + // "paramchange" request to the running service on Windows. + CertificateFile string `json:"certificate_file" mapstructure:"certificate_file"` + CertificateKeyFile string `json:"certificate_key_file" mapstructure:"certificate_key_file"` + // CACertificates defines the set of root certificate authorities to be used to verify client certificates. + CACertificates []string `json:"ca_certificates" mapstructure:"ca_certificates"` + // CARevocationLists defines a set a revocation lists, one for each root CA, to be used to check + // if a client certificate has been revoked + CARevocationLists []string `json:"ca_revocation_lists" mapstructure:"ca_revocation_lists"` + // CORS configuration + Cors CorsConfig `json:"cors" mapstructure:"cors"` + // Cache configuration + Cache Cache `json:"cache" mapstructure:"cache"` + acmeDomain string +} + +// GetStatus returns the server status +func GetStatus() ServiceStatus { + return serviceStatus +} + +// ShouldBind returns true if there is at least a valid binding +func (c *Configuration) ShouldBind() bool { + for _, binding := range c.Bindings { + if binding.IsValid() { + return true + } + } + + return false +} + +func (c *Configuration) getKeyPairs(configDir string) []common.TLSKeyPair { + var keyPairs []common.TLSKeyPair + + for _, binding := range c.Bindings { + certificateFile := getConfigPath(binding.CertificateFile, configDir) + certificateKeyFile := getConfigPath(binding.CertificateKeyFile, configDir) + if certificateFile != "" && certificateKeyFile != "" { + keyPairs = append(keyPairs, common.TLSKeyPair{ + Cert: certificateFile, + Key: certificateKeyFile, + ID: binding.GetAddress(), + }) + } + } + var certificateFile, certificateKeyFile string + if c.acmeDomain != "" { + certificateFile, certificateKeyFile = util.GetACMECertificateKeyPair(c.acmeDomain) + } else { + certificateFile = getConfigPath(c.CertificateFile, configDir) + certificateKeyFile = 
getConfigPath(c.CertificateKeyFile, configDir) + } + if certificateFile != "" && certificateKeyFile != "" { + keyPairs = append(keyPairs, common.TLSKeyPair{ + Cert: certificateFile, + Key: certificateKeyFile, + ID: common.DefaultTLSKeyPaidID, + }) + } + return keyPairs +} + +func (c *Configuration) loadFromProvider() error { + configs, err := dataprovider.GetConfigs() + if err != nil { + return fmt.Errorf("unable to load config from provider: %w", err) + } + configs.SetNilsToEmpty() + if configs.ACME.Domain == "" || !configs.ACME.HasProtocol(common.ProtocolWebDAV) { + return nil + } + crt, key := util.GetACMECertificateKeyPair(configs.ACME.Domain) + if crt != "" && key != "" { + if _, err := os.Stat(crt); err != nil { + logger.Error(logSender, "", "unable to load acme cert file %q: %v", crt, err) + return nil + } + if _, err := os.Stat(key); err != nil { + logger.Error(logSender, "", "unable to load acme key file %q: %v", key, err) + return nil + } + for idx := range c.Bindings { + c.Bindings[idx].EnableHTTPS = true + } + c.acmeDomain = configs.ACME.Domain + logger.Info(logSender, "", "acme domain set to %q", c.acmeDomain) + return nil + } + return nil +} + +// Initialize configures and starts the WebDAV server +func (c *Configuration) Initialize(configDir string) error { + if err := c.loadFromProvider(); err != nil { + return err + } + logger.Info(logSender, "", "initializing WebDAV server with config %+v", *c) + mimeTypeCache = mimeCache{ + maxSize: c.Cache.MimeTypes.MaxSize, + mimeTypes: make(map[string]string), + } + if !c.Cache.MimeTypes.Enabled { + mimeTypeCache.maxSize = 0 + } else { + customMimeTypeMapping = make(map[string]string) + for _, m := range c.Cache.MimeTypes.CustomMappings { + if m.Mime != "" { + logger.Debug(logSender, "", "adding custom mime mapping for extension %q, mime type %q", m.Ext, m.Mime) + customMimeTypeMapping[m.Ext] = m.Mime + } + } + } + + if !c.ShouldBind() { + return common.ErrNoBinding + } + + keyPairs := c.getKeyPairs(configDir) 
+ if len(keyPairs) > 0 { + mgr, err := common.NewCertManager(keyPairs, configDir, logSender) + if err != nil { + return err + } + mgr.SetCACertificates(c.CACertificates) + if err := mgr.LoadRootCAs(); err != nil { + return err + } + mgr.SetCARevocationLists(c.CARevocationLists) + if err := mgr.LoadCRLs(); err != nil { + return err + } + certMgr = mgr + } + compressor := middleware.NewCompressor(5, "text/*") + dataprovider.InitializeWebDAVUserCache(c.Cache.Users.MaxSize) + + serviceStatus = ServiceStatus{ + Bindings: nil, + } + + exitChannel := make(chan error, 1) + + for _, binding := range c.Bindings { + if !binding.IsValid() { + continue + } + if err := binding.parseAllowedProxy(); err != nil { + return err + } + + go func(binding Binding) { + server := webDavServer{ + config: c, + binding: binding, + } + exitChannel <- server.listenAndServe(compressor) + }(binding) + } + + serviceStatus.IsActive = true + + return <-exitChannel +} + +// ReloadCertificateMgr reloads the certificate manager +func ReloadCertificateMgr() error { + if certMgr != nil { + return certMgr.Reload() + } + return nil +} + +func getConfigPath(name, configDir string) string { + if !util.IsFileInputValid(name) { + return "" + } + if name != "" && !filepath.IsAbs(name) { + return filepath.Join(configDir, name) + } + return name +} + +func parseTime(text string) (t time.Time, err error) { + for _, layout := range timeFormats { + t, err = time.Parse(layout, text) + if err == nil { + return + } + } + return +} diff --git a/internal/webdavd/webdavd_test.go b/internal/webdavd/webdavd_test.go new file mode 100644 index 00000000..f7985c34 --- /dev/null +++ b/internal/webdavd/webdavd_test.go @@ -0,0 +1,3655 @@ +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package webdavd_test + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "net" + "net/http" + "os" + "os/exec" + "path" + "path/filepath" + "regexp" + "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/minio/sio" + "github.com/pkg/sftp" + "github.com/rs/zerolog" + "github.com/sftpgo/sdk" + sdkkms "github.com/sftpgo/sdk/kms" + "github.com/stretchr/testify/assert" + "github.com/studio-b12/gowebdav" + "golang.org/x/crypto/ssh" + + "github.com/drakkan/sftpgo/v2/internal/common" + "github.com/drakkan/sftpgo/v2/internal/config" + "github.com/drakkan/sftpgo/v2/internal/dataprovider" + "github.com/drakkan/sftpgo/v2/internal/httpclient" + "github.com/drakkan/sftpgo/v2/internal/httpdtest" + "github.com/drakkan/sftpgo/v2/internal/kms" + "github.com/drakkan/sftpgo/v2/internal/logger" + "github.com/drakkan/sftpgo/v2/internal/sftpd" + "github.com/drakkan/sftpgo/v2/internal/util" + "github.com/drakkan/sftpgo/v2/internal/vfs" + "github.com/drakkan/sftpgo/v2/internal/webdavd" +) + +const ( + logSender = "webavdTesting" + webDavServerAddr = "localhost:9090" + webDavTLSServerAddr = "localhost:9443" + webDavServerPort = 9090 + webDavTLSServerPort = 9443 + sftpServerAddr = "127.0.0.1:9022" + defaultUsername = "test_user_dav" + defaultPassword = "test_password" + osWindows = "windows" + webDavCert = `-----BEGIN CERTIFICATE----- +MIICHTCCAaKgAwIBAgIUHnqw7QnB1Bj9oUsNpdb+ZkFPOxMwCgYIKoZIzj0EAwIw +RTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGElu +dGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMDAyMDQwOTUzMDRaFw0zMDAyMDEw 
+OTUzMDRaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYD +VQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwdjAQBgcqhkjOPQIBBgUrgQQA +IgNiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVqWvrJ51t5OxV0v25NsOgR82CA +NXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIVCzgWkxiz7XE4lgUwX44FCXZM +3+JeUbKjUzBRMB0GA1UdDgQWBBRhLw+/o3+Z02MI/d4tmaMui9W16jAfBgNVHSME +GDAWgBRhLw+/o3+Z02MI/d4tmaMui9W16jAPBgNVHRMBAf8EBTADAQH/MAoGCCqG +SM49BAMCA2kAMGYCMQDqLt2lm8mE+tGgtjDmtFgdOcI72HSbRQ74D5rYTzgST1rY +/8wTi5xl8TiFUyLMUsICMQC5ViVxdXbhuG7gX6yEqSkMKZICHpO8hqFwOD/uaFVI +dV4vKmHUzwK/eIx+8Ay3neE= +-----END CERTIFICATE-----` + webDavKey = `-----BEGIN EC PARAMETERS----- +BgUrgQQAIg== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDCfMNsN6miEE3rVyUPwElfiJSWaR5huPCzUenZOfJT04GAcQdWvEju3 +UM2lmBLIXpGgBwYFK4EEACKhZANiAARCjRMqJ85rzMC998X5z761nJ+xL3bkmGVq +WvrJ51t5OxV0v25NsOgR82CANXUgvhVYs7vNFN+jxtb2aj6Xg+/2G/BNxkaFspIV +CzgWkxiz7XE4lgUwX44FCXZM3+JeUbI= +-----END EC PRIVATE KEY-----` + caCRT = `-----BEGIN CERTIFICATE----- +MIIE5jCCAs6gAwIBAgIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0 +QXV0aDAeFw0yNDAxMTAxODEyMDRaFw0zNDAxMTAxODIxNTRaMBMxETAPBgNVBAMT +CENlcnRBdXRoMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA7WHW216m +fi4uF8cx6HWf8wvAxaEWgCHTOi2MwFIzOrOtuT7xb64rkpdzx1aWetSiCrEyc3D1 +v03k0Akvlz1gtnDtO64+MA8bqlTnCydZJY4cCTvDOBUYZgtMqHZzpE6xRrqQ84zh +yzjKQ5bR0st+XGfIkuhjSuf2n/ZPS37fge9j6AKzn/2uEVt33qmO85WtN3RzbSqL +CdOJ6cQ216j3la1C5+NWvzIKC7t6NE1bBGI4+tRj7B5P5MeamkkogwbExUjdHp3U +4yasvoGcCHUQDoa4Dej1faywz6JlwB6rTV4ys4aZDe67V/Q8iB2May1k7zBz1Ztb +KF5Em3xewP1LqPEowF1uc4KtPGcP4bxdaIpSpmObcn8AIfH6smLQrn0C3cs7CYfo +NlFuTbwzENUhjz0X6EsoM4w4c87lO+dRNR7YpHLqR/BJTbbyXUB0imne1u00fuzb +S7OtweiA9w7DRCkr2gU4lmHe7l0T+SA9pxIeVLb78x7ivdyXSF5LVQJ1JvhhWu6i +M6GQdLHat/0fpRFUbEe34RQSDJ2eOBifMJqvsvpBP8d2jcRZVUVrSXGc2mAGuGOY +/tmnCJGW8Fd+sgpCVAqM0pxCM+apqrvJYUqqQZ2ZxugCXULtRWJ9p4C9zUl40HEy +OQ+AaiiwFll/doXELglcJdNg8AZPGhugfxMCAwEAAaNFMEMwDgYDVR0PAQH/BAQD +AgEGMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYDVR0OBBYEFNoJhIvDZQrEf/VQbWuu 
+XgNnt2m5MA0GCSqGSIb3DQEBCwUAA4ICAQCYhT5SRqk19hGrQ09hVSZOzynXAa5F +sYkEWJzFyLg9azhnTPE1bFM18FScnkd+dal6mt+bQiJvdh24NaVkDghVB7GkmXki +pAiZwEDHMqtbhiPxY8LtSeCBAz5JqXVU2Q0TpAgNSH4W7FbGWNThhxcJVOoIrXKE +jbzhwl1Etcaf0DBKWliUbdlxQQs65DLy+rNBYtOeK0pzhzn1vpehUlJ4eTFzP9KX +y2Mksuq9AspPbqnqpWW645MdTxMb5T57MCrY3GDKw63z5z3kz88LWJF3nOxZmgQy +WFUhbLmZm7x6N5eiu6Wk8/B4yJ/n5UArD4cEP1i7nqu+mbbM/SZlq1wnGpg/sbRV +oUF+a7pRcSbfxEttle4pLFhS+ErKatjGcNEab2OlU3bX5UoBs+TYodnCWGKOuBKV +L/CYc65QyeYZ+JiwYn9wC8YkzOnnVIQjiCEkLgSL30h9dxpnTZDLrdAA8ItelDn5 +DvjuQq58CGDsaVqpSobiSC1DMXYWot4Ets1wwovUNEq1l0MERB+2olE+JU/8E23E +eL1/aA7Kw/JibkWz1IyzClpFDKXf6kR2onJyxerdwUL+is7tqYFLysiHxZDL1bli +SXbW8hMa5gvo0IilFP9Rznn8PplIfCsvBDVv6xsRr5nTAFtwKaMBVgznE2ghs69w +kK8u1YiiVenmoQ== +-----END CERTIFICATE-----` + caCRL = `-----BEGIN X509 CRL----- +MIICpzCBkAIBATANBgkqhkiG9w0BAQsFADATMREwDwYDVQQDEwhDZXJ0QXV0aBcN +MjQwMTEwMTgyMjU4WhcNMjYwMTA5MTgyMjU4WjAkMCICEQDOaeHbjY4pEj8WBmqg +ZuRRFw0yNDAxMTAxODIyNThaoCMwITAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1r +rl4DZ7dpuTANBgkqhkiG9w0BAQsFAAOCAgEAZzZ4aBqCcAJigR9e/mqKpJa4B6FV ++jZmnWXolGeUuVkjdiG9w614x7mB2S768iioJyALejjCZjqsp6ydxtn0epQw4199 +XSfPIxA9lxc7w79GLe0v3ztojvxDPh5V1+lwPzGf9i8AsGqb2BrcBqgxDeatndnE +jF+18bY1saXOBpukNLjtRScUXzy5YcSuO6mwz4548v+1ebpF7W4Yh+yh0zldJKcF +DouuirZWujJwTwxxfJ+2+yP7GAuefXUOhYs/1y9ylvUgvKFqSyokv6OaVgTooKYD +MSADzmNcbRvwyAC5oL2yJTVVoTFeP6fXl/BdFH3sO/hlKXGy4Wh1AjcVE6T0CSJ4 +iYFX3gLFh6dbP9IQWMlIM5DKtAKSjmgOywEaWii3e4M0NFSf/Cy17p2E5/jXSLlE +ypDileK0aALkx2twGWwogh6sY1dQ6R3GpKSRPD2muQxVOG6wXvuJce0E9WLx1Ud4 +hVUdUEMlKUvm77/15U5awarH2cCJQxzS/GMeIintQiG7hUlgRzRdmWVe3vOOvt94 +cp8+ZUH/QSDOo41ATTHpFeC/XqF5E2G/ahXqra+O5my52V/FP0bSJnkorJ8apy67 +sn6DFbkqX9khTXGtacczh2PcqVjcQjBniYl2sPO3qIrrrY3tic96tMnM/u3JRdcn +w7bXJGfJcIMrrKs= +-----END X509 CRL-----` + client1Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAJr32nHRlhyPiS7IfZ/ZWYowDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjM3WhcNMzQwMTEwMTgy 
+MTUzWjASMRAwDgYDVQQDEwdjbGllbnQxMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+kiJ3X6 +HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi85xE +OkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLWeYl7 +Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4ZdRf +XlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dybbhO +c9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBRUh5Xo +Gzjh6iReaPSOgGatqOw9bDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEAyAK7cOTWqjyLgFM0kyyx1fNPvm2GwKep3MuU +OrSnLuWjoxzb7WcbKNVMlnvnmSUAWuErxsY0PUJNfcuqWiGmEp4d/SWfWPigG6DC +sDej35BlSfX8FCufYrfC74VNk4yBS2LVYmIqcpqUrfay0I2oZA8+ToLEpdUvEv2I +l59eOhJO2jsC3JbOyZZmK2Kv7d94fR+1tg2Rq1Wbnmc9AZKq7KDReAlIJh4u2KHb +BbtF79idusMwZyP777tqSQ4THBMa+VAEc2UrzdZqTIAwqlKQOvO2fRz2P+ARR+Tz +MYJMdCdmPZ9qAc8U1OcFBG6qDDltO8wf/Nu/PsSI5LGCIhIuPPIuKfm0rRfTqCG7 +QPQPWjRoXtGGhwjdIuWbX9fIB+c+NpAEKHgLtV+Rxj8s5IVxqG9a5TtU9VkfVXJz +J20naoz/G+vDsVINpd3kH0ziNvdrKfGRM5UgtnUOPCXB22fVmkIsMH2knI10CKK+ +offI56NTkLRu00xvg98/wdukhkwIAxg6PQI/BHY5mdvoacEHHHdOhMq+GSAh7DDX +G8+HdbABM1ExkPnZLat15q706ztiuUpQv1C2DI8YviUVkMqCslj4cD4F8EFPo4kr +kvme0Cuc9Qlf7N5rjdV3cjwavhFx44dyXj9aesft2Q1okPiIqbGNpcjHcIRlj4Au +MU3Bo0A= +-----END CERTIFICATE-----` + client1Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAtuQFiqvdjd8WLxP0FgPDyDEJ1/uJ+Aoj6QllNV7svWxwW+ki +J3X6HUVNWhhCsNfly4pGW4erF4fZzmesElGx1PoWgQCWZKsa/N08bznelWgdmkyi +85xEOkTj6e/cTWHFSOBURNJaXkGHZ0ROSh7qu0Ld+eqNo3k9W+NqZaqYvs2K7MLW +eYl7Qie8Ctuq5Qaz/jm0XwR2PFBROVQSaCPCukancPQ21ftqHPhAbjxoxvvN5QP4 +ZdRfXlH/LDLhlFnJzPZdHnVy9xisSPPRfFApJiwyfjRYdtslpJOcNgP6oPlpX/dy +bbhOc9FEUgj/Q90Je8EfioBYFYsqVD6/dFv9SwIDAQABAoIBAFjSHK7gENVZxphO +hHg8k9ShnDo8eyDvK8l9Op3U3/yOsXKxolivvyx//7UFmz3vXDahjNHe7YScAXdw +eezbqBXa7xrvghqZzp2HhFYwMJ0210mcdncBKVFzK4ztZHxgQ0PFTqet0R19jZjl +X3A325/eNZeuBeOied4qb/24AD6JGc6A0J55f5/QUQtdwYwrL15iC/KZXDL90PPJ 
+CFJyrSzcXvOMEvOfXIFxhDVKRCppyIYXG7c80gtNC37I6rxxMNQ4mxjwUI2IVhxL +j+nZDu0JgRZ4NaGjOq2e79QxUVm/GG3z25XgmBFBrXkEVV+sCZE1VDyj6kQfv9FU +NhOrwGECgYEAzq47r/HwXifuGYBV/mvInFw3BNLrKry+iUZrJ4ms4g+LfOi0BAgf +sXsWXulpBo2YgYjFdO8G66f69GlB4B7iLscpABXbRtpDZEnchQpaF36/+4g3i8gB +Z29XHNDB8+7t4wbXvlSnLv1tZWey2fS4hPosc2YlvS87DMmnJMJqhs8CgYEA4oiB +LGQP6VNdX0Uigmh5fL1g1k95eC8GP1ylczCcIwsb2OkAq0MT7SHRXOlg3leEq4+g +mCHk1NdjkSYxDL2ZeTKTS/gy4p1jlcDa6Ilwi4pVvatNvu4o80EYWxRNNb1mAn67 +T8TN9lzc6mEi+LepQM3nYJ3F+ZWTKgxH8uoJwMUCgYEArpumE1vbjUBAuEyi2eGn +RunlFW83fBCfDAxw5KM8anNlja5uvuU6GU/6s06QCxg+2lh5MPPrLdXpfukZ3UVa +Itjg+5B7gx1MSALaiY8YU7cibFdFThM3lHIM72wyH2ogkWcrh0GvSFSUQlJcWCSW +asmMGiYXBgBL697FFZomMyMCgYEAkAnp0JcDQwHd4gDsk2zoqnckBsDb5J5J46n+ +DYNAFEww9bgZ08u/9MzG+cPu8xFE621U2MbcYLVfuuBE2ewIlPaij/COMmeO9Z59 +0tPpOuDH6eTtd1SptxqR6P+8pEn8feOlKHBj4Z1kXqdK/EiTlwAVeep4Al2oCFls +ujkz4F0CgYAe8vHnVFHlWi16zAqZx4ZZZhNuqPtgFkvPg9LfyNTA4dz7F9xgtUaY +nXBPyCe/8NtgBfT79HkPiG3TM0xRZY9UZgsJKFtqAu5u4ManuWDnsZI9RK2QTLHe +yEbH5r3Dg3n9k/3GbjXFIWdU9UaYsdnSKHHtMw9ZODc14LaAogEQug== +-----END RSA PRIVATE KEY-----` + // client 2 crt is revoked + client2Crt = `-----BEGIN CERTIFICATE----- +MIIEITCCAgmgAwIBAgIRAM5p4duNjikSPxYGaqBm5FEwDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAxMIQ2VydEF1dGgwHhcNMjQwMTEwMTgxMjUyWhcNMzQwMTEwMTgy +MTUzWjASMRAwDgYDVQQDEwdjbGllbnQyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImVjm/b +Qe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TPkZua +eq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEKh4LQ +cr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81PQAmT +A0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu4Ic0 +6tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABo3EwbzAOBgNVHQ8BAf8EBAMC +A7gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBR5mf0f +Zjf8ZCGXqU2+45th7VkkLDAfBgNVHSMEGDAWgBTaCYSLw2UKxH/1UG1rrl4DZ7dp +uTANBgkqhkiG9w0BAQsFAAOCAgEARhFxNAouwbpEfN1M90+ao5rwyxEewerSoCCz 
+PQzeUZ66MA/FkS/tFUGgGGG+wERN+WLbe1cN6q/XFr0FSMLuUxLXDNV02oUL/FnY +xcyNLaZUZ0pP7sA+Hmx2AdTA6baIwQbyIY9RLAaz6hzo1YbI8yeis645F1bxgL2D +EP5kXa3Obv0tqWByMZtrmJPv3p0W5GJKXVDn51GR/E5KI7pliZX2e0LmMX9mxfPB +4sXFUggMHXxWMMSAmXPVsxC2KX6gMnajO7JUraTwuGm+6V371FzEX+UKXHI+xSvO +78TseTIYsBGLjeiA8UjkKlD3T9qsQm2mb2PlKyqjvIm4i2ilM0E2w4JZmd45b925 +7q/QLV3NZ/zZMi6AMyULu28DWKfAx3RLKwnHWSFcR4lVkxQrbDhEUMhAhLAX+2+e +qc7qZm3dTabi7ZJiiOvYK/yNgFHa/XtZp5uKPB5tigPIa+34hbZF7s2/ty5X3O1N +f5Ardz7KNsxJjZIt6HvB28E/PPOvBqCKJc1Y08J9JbZi8p6QS1uarGoR7l7rT1Hv +/ZXkNTw2bw1VpcWdzDBLLVHYNnJmS14189LVk11PcJJpSmubwCqg+ZZULdgtVr3S +ANas2dgMPVwXhnAalgkcc+lb2QqaEz06axfbRGBsgnyqR5/koKCg1Hr0+vThHSsR +E0+r2+4= +-----END CERTIFICATE-----` + client2Key = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEApNYpNZVmXZtAObpRRIuP2o/7z04H2E161vKZvJ3LSLlUTImV +jm/bQe6DTNCUVLnzQuanmUlu2rUnN3lDSfYoBcJWbvC3y1OCPRkCjDV6KiYMA9TP +kZuaeq6y3+bFFfEmyumsVEe0bSuzNHXCOIBT7PqYMdovECcwBh/RZCA5mqO5omEK +h4LQcr6+sVVkvD3nsyx0Alz/kTLFqc0mVflmpJq+0BpdetHRg4n5vy/I/08jZ81P +QAmTA0kyl0Jh132JBGFdA8eyugPPP8n5edU4f3HXV/nR7XLwBrpSt8KgEg8cwfAu +4Ic06tGzB0CH8lSGtU0tH2/cOlDuguDD7VvokQIDAQABAoIBAQCMnEeg9uXQmdvq +op4qi6bV+ZcDWvvkLwvHikFMnYpIaheYBpF2ZMKzdmO4xgCSWeFCQ4Hah8KxfHCM +qLuWvw2bBBE5J8yQ/JaPyeLbec7RX41GQ2YhPoxDdP0PdErREdpWo4imiFhH/Ewt +Rvq7ufRdpdLoS8dzzwnvX3r+H2MkHoC/QANW2AOuVoZK5qyCH5N8yEAAbWKaQaeL +VBhAYEVKbAkWEtXw7bYXzxRR7WIM3f45v3ncRusDIG+Hf75ZjatoH0lF1gHQNofO +qkCVZVzjkLFuzDic2KZqsNORglNs4J6t5Dahb9v3hnoK963YMnVSUjFvqQ+/RZZy +VILFShilAoGBANucwZU61eJ0tLKBYEwmRY/K7Gu1MvvcYJIOoX8/BL3zNmNO0CLl +NiABtNt9WOVwZxDsxJXdo1zvMtAegNqS6W11R1VAZbL6mQ/krScbLDE6JKA5DmA7 +4nNi1gJOW1ziAfdBAfhe4cLbQOb94xkOK5xM1YpO0xgDJLwrZbehDMmPAoGBAMAl +/owPDAvcXz7JFynT0ieYVc64MSFiwGYJcsmxSAnbEgQ+TR5FtkHYe91OSqauZcCd +aoKXQNyrYKIhyounRPFTdYQrlx6KtEs7LU9wOxuphhpJtGjRnhmA7IqvX703wNvu +khrEavn86G5boH8R80371SrN0Rh9UeAlQGuNBdvfAoGAEAmokW9Ug08miwqrr6Pz +3IZjMZJwALidTM1IufQuMnj6ddIhnQrEIx48yPKkdUz6GeBQkuk2rujA+zXfDxc/ +eMDhzrX/N0zZtLFse7ieR5IJbrH7/MciyG5lVpHGVkgjAJ18uVikgAhm+vd7iC7i 
+vG1YAtuyysQgAKXircBTIL0CgYAHeTLWVbt9NpwJwB6DhPaWjalAug9HIiUjktiB +GcEYiQnBWn77X3DATOA8clAa/Yt9m2HKJIHkU1IV3ESZe+8Fh955PozJJlHu3yVb +Ap157PUHTriSnxyMF2Sb3EhX/rQkmbnbCqqygHC14iBy8MrKzLG00X6BelZV5n0D +8d85dwKBgGWY2nsaemPH/TiTVF6kW1IKSQoIyJChkngc+Xj/2aCCkkmAEn8eqncl +RKjnkiEZeG4+G91Xu7+HmcBLwV86k5I+tXK9O1Okomr6Zry8oqVcxU5TB6VRS+rA +ubwF00Drdvk2+kDZfxIM137nBiy7wgCJi2Ksm5ihN3dUF6Q0oNPl +-----END RSA PRIVATE KEY-----` + testFileName = "test_file_dav.dat" + testDLFileName = "test_download_dav.dat" + tlsClient1Username = "client1" + tlsClient2Username = "client2" + emptyPwdPlaceholder = "empty" + ocMtimeHeader = "X-OC-Mtime" +) + +var ( + configDir = filepath.Join(".", "..", "..") + allPerms = []string{dataprovider.PermAny} + homeBasePath string + hookCmdPath string + extAuthPath string + preLoginPath string + postConnectPath string + preDownloadPath string + preUploadPath string + logFilePath string + certPath string + keyPath string + caCrtPath string + caCRLPath string +) + +func TestMain(m *testing.M) { + logFilePath = filepath.Join(configDir, "sftpgo_webdavd_test.log") + logger.InitLogger(logFilePath, 5, 1, 28, false, false, zerolog.DebugLevel) + os.Setenv("SFTPGO_DATA_PROVIDER__CREATE_DEFAULT_ADMIN", "1") + os.Setenv("SFTPGO_COMMON__ALLOW_SELF_CONNECTIONS", "1") + os.Setenv("SFTPGO_DEFAULT_ADMIN_USERNAME", "admin") + os.Setenv("SFTPGO_DEFAULT_ADMIN_PASSWORD", "password") + os.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__EXT", ".sftpgo") + os.Setenv("SFTPGO_WEBDAVD__CACHE__MIME_TYPES__CUSTOM_MAPPINGS__0__MIME", "application/sftpgo") + err := config.LoadConfig(configDir, "") + if err != nil { + logger.ErrorToConsole("error loading configuration: %v", err) + os.Exit(1) + } + providerConf := config.GetProviderConf() + logger.InfoToConsole("Starting WebDAVD tests, provider: %v", providerConf.Driver) + commonConf := config.GetCommonConfig() + commonConf.UploadMode = 2 + homeBasePath = os.TempDir() + if runtime.GOOS != osWindows { + 
commonConf.Actions.ExecuteOn = []string{"download", "upload", "rename", "delete"}
+		hookCmdPath, err = exec.LookPath("true")
+		if err != nil {
+			logger.Warn(logSender, "", "unable to get hook command: %v", err)
+			logger.WarnToConsole("unable to get hook command: %v", err)
+		}
+		// assign the hook only after hookCmdPath has been resolved by LookPath:
+		// hookCmdPath is a package-level var with no initializer, so assigning it
+		// before the lookup would configure an empty hook command
+		commonConf.Actions.Hook = hookCmdPath
+	}
+
+	certPath = filepath.Join(os.TempDir(), "test_dav.crt")
+	keyPath = filepath.Join(os.TempDir(), "test_dav.key")
+	caCrtPath = filepath.Join(os.TempDir(), "test_dav_ca.crt")
+	caCRLPath = filepath.Join(os.TempDir(), "test_dav_crl.crt")
+	err = os.WriteFile(certPath, []byte(webDavCert), os.ModePerm)
+	if err != nil {
+		logger.ErrorToConsole("error writing WebDAV certificate: %v", err)
+		os.Exit(1)
+	}
+	err = os.WriteFile(keyPath, []byte(webDavKey), os.ModePerm)
+	if err != nil {
+		logger.ErrorToConsole("error writing WebDAV private key: %v", err)
+		os.Exit(1)
+	}
+	err = os.WriteFile(caCrtPath, []byte(caCRT), os.ModePerm)
+	if err != nil {
+		logger.ErrorToConsole("error writing WebDAV CA crt: %v", err)
+		os.Exit(1)
+	}
+	err = os.WriteFile(caCRLPath, []byte(caCRL), os.ModePerm)
+	if err != nil {
+		logger.ErrorToConsole("error writing WebDAV CRL: %v", err)
+		os.Exit(1)
+	}
+
+	err = dataprovider.Initialize(providerConf, configDir, true)
+	if err != nil {
+		logger.ErrorToConsole("error initializing data provider: %v", err)
+		os.Exit(1)
+	}
+
+	err = common.Initialize(commonConf, 0)
+	if err != nil {
+		logger.WarnToConsole("error initializing common: %v", err)
+		os.Exit(1)
+	}
+
+	httpConfig := config.GetHTTPConfig()
+	httpConfig.Initialize(configDir) //nolint:errcheck
+	kmsConfig := config.GetKMSConfig()
+	err = kmsConfig.Initialize()
+	if err != nil {
+		logger.ErrorToConsole("error initializing kms: %v", err)
+		os.Exit(1)
+	}
+
+	httpdConf := config.GetHTTPDConfig()
+	httpdConf.Bindings[0].Port = 8078
+	httpdtest.SetBaseURL("http://127.0.0.1:8078")
+
+	// required to test sftpfs
+	sftpdConf := config.GetSFTPDConfig()
+	sftpdConf.Bindings =
[]sftpd.Binding{ + { + Port: 9022, + }, + } + hostKeyPath := filepath.Join(os.TempDir(), "id_ecdsa") + sftpdConf.HostKeys = []string{hostKeyPath} + + webDavConf := config.GetWebDAVDConfig() + webDavConf.CACertificates = []string{caCrtPath} + webDavConf.CARevocationLists = []string{caCRLPath} + webDavConf.Bindings = []webdavd.Binding{ + { + Port: webDavServerPort, + }, + { + Port: webDavTLSServerPort, + EnableHTTPS: true, + CertificateFile: certPath, + CertificateKeyFile: keyPath, + ClientAuthType: 2, + }, + } + webDavConf.Cors = webdavd.CorsConfig{ + Enabled: true, + AllowedOrigins: []string{"*"}, + AllowedMethods: []string{ + http.MethodHead, + http.MethodGet, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + }, + AllowedHeaders: []string{"*"}, + AllowCredentials: true, + } + + status := webdavd.GetStatus() + if status.IsActive { + logger.ErrorToConsole("webdav server is already active") + os.Exit(1) + } + + extAuthPath = filepath.Join(homeBasePath, "extauth.sh") + preLoginPath = filepath.Join(homeBasePath, "prelogin.sh") + postConnectPath = filepath.Join(homeBasePath, "postconnect.sh") + preDownloadPath = filepath.Join(homeBasePath, "predownload.sh") + preUploadPath = filepath.Join(homeBasePath, "preupload.sh") + + go func() { + logger.Debug(logSender, "", "initializing WebDAV server with config %+v", webDavConf) + if err := webDavConf.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start WebDAV server: %v", err) + os.Exit(1) + } + }() + + go func() { + if err := httpdConf.Initialize(configDir, 0); err != nil { + logger.ErrorToConsole("could not start HTTP server: %v", err) + os.Exit(1) + } + }() + + go func() { + logger.Debug(logSender, "", "initializing SFTP server with config %+v", sftpdConf) + if err := sftpdConf.Initialize(configDir); err != nil { + logger.ErrorToConsole("could not start SFTP server: %v", err) + os.Exit(1) + } + }() + + waitTCPListening(webDavConf.Bindings[0].GetAddress()) + 
waitTCPListening(webDavConf.Bindings[1].GetAddress()) + waitTCPListening(httpdConf.Bindings[0].GetAddress()) + waitTCPListening(sftpdConf.Bindings[0].GetAddress()) + webdavd.ReloadCertificateMgr() //nolint:errcheck + + exitCode := m.Run() + os.Remove(logFilePath) + os.Remove(extAuthPath) + os.Remove(preLoginPath) + os.Remove(postConnectPath) + os.Remove(preDownloadPath) + os.Remove(preUploadPath) + os.Remove(certPath) + os.Remove(keyPath) + os.Remove(caCrtPath) + os.Remove(caCRLPath) + os.Remove(hostKeyPath) + os.Remove(hostKeyPath + ".pub") + os.Exit(exitCode) +} + +func TestInitialization(t *testing.T) { + cfg := webdavd.Configuration{ + Bindings: []webdavd.Binding{ + { + Port: 1234, + EnableHTTPS: true, + }, + { + Port: 0, + }, + }, + CertificateFile: "missing path", + CertificateKeyFile: "bad path", + } + err := cfg.Initialize(configDir) + assert.Error(t, err) + + cfg.Cache = config.GetWebDAVDConfig().Cache + cfg.Bindings[0].Port = webDavServerPort + cfg.CertificateFile = certPath + cfg.CertificateKeyFile = keyPath + err = cfg.Initialize(configDir) + assert.Error(t, err) + err = webdavd.ReloadCertificateMgr() + assert.NoError(t, err) + + cfg.Bindings = []webdavd.Binding{ + { + Port: 0, + }, + } + err = cfg.Initialize(configDir) + assert.EqualError(t, err, common.ErrNoBinding.Error()) + + cfg.CertificateFile = certPath + cfg.CertificateKeyFile = keyPath + cfg.CACertificates = []string{""} + + cfg.Bindings = []webdavd.Binding{ + { + Port: 9022, + ClientAuthType: 1, + EnableHTTPS: true, + }, + } + err = cfg.Initialize(configDir) + assert.Error(t, err) + + cfg.CACertificates = nil + cfg.CARevocationLists = []string{""} + err = cfg.Initialize(configDir) + assert.Error(t, err) + + cfg.CARevocationLists = nil + err = cfg.Initialize(configDir) + assert.Error(t, err) + + cfg.CertificateFile = certPath + cfg.CertificateKeyFile = keyPath + cfg.CACertificates = []string{caCrtPath} + cfg.CARevocationLists = []string{caCRLPath} + cfg.Bindings[0].ProxyAllowed = []string{"not 
valid"} + err = cfg.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "is not a valid IP address") + } + cfg.Bindings[0].ProxyAllowed = nil + err = cfg.Initialize(configDir) + assert.Error(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = cfg.Initialize(configDir) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "unable to load config from provider") + } + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) +} + +func TestBasicHandling(t *testing.T) { + u := getTestUser() + u.QuotaSize = 6553600 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaSize = 6553600 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + client := getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client)) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + expectedQuotaSize := testFileSize + expectedQuotaFiles := 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), user.FirstUpload) + assert.Equal(t, int64(0), user.FirstDownload) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Greater(t, user.FirstUpload, int64(0)) + assert.Greater(t, user.FirstDownload, int64(0)) // webdav read the mime type + // overwrite an existing file + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, 
defaultPassword, + true, testFileSize, client) + assert.NoError(t, err) + // wrong password + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword+"1", + true, testFileSize, client) + assert.Error(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + assert.Greater(t, user.FirstUpload, int64(0)) + assert.Greater(t, user.FirstDownload, int64(0)) + err = client.Rename(testFileName, testFileName+"1", false) + assert.NoError(t, err) + _, err = client.Stat(testFileName) + assert.Error(t, err) + // the webdav client hide the error we check the quota + err = client.Remove(testFileName) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + err = client.Remove(testFileName + "1") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize-testFileSize, user.UsedQuotaSize) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err) + + testDir := "testdir" + err = client.Mkdir(testDir, os.ModePerm) + assert.NoError(t, err) + err = client.MkdirAll(path.Join(testDir, "sub", "sub"), os.ModePerm) + assert.NoError(t, err) + err = client.MkdirAll(path.Join(testDir, "sub1", "sub1"), os.ModePerm) + assert.NoError(t, err) + err = client.MkdirAll(path.Join(testDir, "sub2", "sub2"), os.ModePerm) + assert.NoError(t, err) + err = 
uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName+".txt"), + user.Username, defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(testDir, testFileName), + user.Username, defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + files, err := client.ReadDir(testDir) + assert.NoError(t, err) + assert.Len(t, files, 5) + err = client.Copy(testDir, testDir+"_copy", false) //nolint:goconst + assert.NoError(t, err) + err = client.RemoveAll(testDir) + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + status := webdavd.GetStatus() + assert.True(t, status.IsActive) +} + +func TestBasicHandlingCryptFs(t *testing.T) { + u := getTestUserWithCryptFs() + u.QuotaSize = 6553600 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + assert.NoError(t, checkBasicFunc(client)) + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + encryptedFileSize, err := getEncryptedFileSize(testFileSize) + 
assert.NoError(t, err) + expectedQuotaSize := user.UsedQuotaSize + encryptedFileSize + expectedQuotaFiles := user.UsedQuotaFiles + 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, + user.Username, defaultPassword, false, testFileSize, client) + assert.NoError(t, err) + // overwrite an existing file + err = uploadFileWithRawClient(testFilePath, testFileName, + user.Username, defaultPassword, false, testFileSize, client) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + files, err := client.ReadDir("/") + assert.NoError(t, err) + if assert.Len(t, files, 1) { + assert.Equal(t, testFileSize, files[0].Size()) + } + err = client.Remove(testFileName) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles-1, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize-encryptedFileSize, user.UsedQuotaSize) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err) + testDir := "testdir" + err = client.Mkdir(testDir, os.ModePerm) + assert.NoError(t, err) + err = client.MkdirAll(path.Join(testDir, "sub", "sub"), os.ModePerm) + assert.NoError(t, err) + err = client.MkdirAll(path.Join(testDir, "sub1", "sub1"), os.ModePerm) + assert.NoError(t, err) + err = client.MkdirAll(path.Join(testDir, "sub2", "sub2"), os.ModePerm) + assert.NoError(t, err) + err = 
// TestBufferedUser verifies uploads/downloads for a user whose local
// filesystem uses read/write buffering, with a crypt-fs virtual folder
// mounted at /crypted that uses different (larger) buffer sizes.
func TestBufferedUser(t *testing.T) {
	u := getTestUser()
	// enable OS fs buffering for the user's root filesystem
	u.FsConfig.OSConfig = sdk.OSFsConfig{
		WriteBufferSize: 2,
		ReadBufferSize:  1,
	}
	vdirPath := "/crypted"
	mappedPath := filepath.Join(os.TempDir(), util.GenerateUniqueID())
	folderName := filepath.Base(mappedPath)
	// unlimited quota on the virtual folder
	u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{
		BaseVirtualFolder: vfs.BaseVirtualFolder{
			Name: folderName,
		},
		VirtualPath: vdirPath,
		QuotaFiles:  -1,
		QuotaSize:   -1,
	})
	// the folder itself is a crypt fs with its own buffer sizes
	f := vfs.BaseVirtualFolder{
		Name:       folderName,
		MappedPath: mappedPath,
		FsConfig: vfs.Filesystem{
			Provider: sdk.CryptedFilesystemProvider,
			CryptConfig: vfs.CryptFsConfig{
				OSFsConfig: sdk.OSFsConfig{
					WriteBufferSize: 3,
					ReadBufferSize:  2,
				},
				Passphrase: kms.NewPlainSecret(defaultPassword),
			},
		},
	}
	_, _, err := httpdtest.AddFolder(f, http.StatusCreated)
	assert.NoError(t, err)
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	client := getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))

	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	// upload both to the buffered root fs and to the buffered crypt folder
	err = uploadFileWithRawClient(testFilePath, testFileName,
		user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, path.Join(vdirPath, testFileName),
		user.Username, defaultPassword, false, testFileSize, client)
	assert.NoError(t, err)
	localDownloadPath := filepath.Join(homeBasePath, testDLFileName)
	// and download back from both, checking the size round-trips
	err = downloadFile(testFileName, localDownloadPath, testFileSize, client)
	assert.NoError(t, err)
	err = downloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client)
	assert.NoError(t, err)

	// cleanup: local files, user, home dir, folder, mapped path
	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	err = os.Remove(localDownloadPath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	_, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(mappedPath)
	assert.NoError(t, err)
}
// TestLockAfterDelete verifies that a WebDAV LOCK held on a file does not
// survive the file's deletion: after DELETE succeeds with the lock token,
// a fresh LOCK on the same path must be granted (201 Created).
func TestLockAfterDelete(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	client := getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword,
		false, testFileSize, client)
	assert.NoError(t, err)
	// NOTE(review): the XML LOCK request body, the lock-token regexp and the
	// strings.Replace arguments below appear to have been stripped of their
	// XML/markup content by text extraction — confirm against the upstream
	// source before relying on the literal values here.
	lockBody := ``
	req, err := http.NewRequest("LOCK", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(lockBody)))
	assert.NoError(t, err)
	req.SetBasicAuth(u.Username, u.Password)
	// request a one-hour lock timeout
	req.Header.Set("Timeout", "Second-3600")
	httpClient := httpclient.GetHTTPClient()
	resp, err := httpClient.Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	response, err := io.ReadAll(resp.Body)
	assert.NoError(t, err)
	// extract the lock token from the LOCK response body
	re := regexp.MustCompile(`\.*`)
	lockToken := string(re.Find(response))
	lockToken = strings.Replace(lockToken, "", "", 1)
	lockToken = strings.Replace(lockToken, "", "", 1)
	err = resp.Body.Close()
	assert.NoError(t, err)

	// DELETE the locked file, presenting the token in the "If" header
	req, err = http.NewRequest(http.MethodDelete, fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), nil)
	assert.NoError(t, err)
	req.Header.Set("If", fmt.Sprintf("(%v)", lockToken))
	req.SetBasicAuth(u.Username, u.Password)
	resp, err = httpClient.Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusNoContent, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// if we try to lock again it must succeed, the lock must be deleted with the object
	req, err = http.NewRequest("LOCK", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(lockBody)))
	assert.NoError(t, err)
	req.SetBasicAuth(u.Username, u.Password)
	resp, err = httpClient.Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusCreated, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
// TestRenameWithLock verifies that a MOVE succeeds when the source file is
// locked and the lock token is supplied via the "If" header.
func TestRenameWithLock(t *testing.T) {
	u := getTestUser()
	user, _, err := httpdtest.AddUser(u, http.StatusCreated)
	assert.NoError(t, err)

	client := getWebDavClient(user, false, nil)
	assert.NoError(t, checkBasicFunc(client))
	testFilePath := filepath.Join(homeBasePath, testFileName)
	testFileSize := int64(65535)
	err = createTestFile(testFilePath, testFileSize)
	assert.NoError(t, err)
	err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword,
		false, testFileSize, client)
	assert.NoError(t, err)

	// NOTE(review): the XML LOCK body, the token regexp and the
	// strings.Replace arguments look stripped of their XML content by text
	// extraction — verify against the upstream source.
	lockBody := ``
	req, err := http.NewRequest("LOCK", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(lockBody)))
	assert.NoError(t, err)
	req.SetBasicAuth(u.Username, u.Password)
	httpClient := httpclient.GetHTTPClient()
	resp, err := httpClient.Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusOK, resp.StatusCode)
	response, err := io.ReadAll(resp.Body)
	assert.NoError(t, err)
	// extract the lock token from the LOCK response body
	re := regexp.MustCompile(`\.*`)
	lockToken := string(re.Find(response))
	lockToken = strings.Replace(lockToken, "", "", 1)
	lockToken = strings.Replace(lockToken, "", "", 1)
	err = resp.Body.Close()
	assert.NoError(t, err)
	// MOVE with a lock should succeed
	req, err = http.NewRequest("MOVE", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), nil)
	assert.NoError(t, err)
	req.Header.Set("If", fmt.Sprintf("(%v)", lockToken))
	req.Header.Set("Overwrite", "T")
	req.Header.Set("Destination", path.Join("/", testFileName+"1"))
	req.SetBasicAuth(u.Username, u.Password)
	resp, err = httpClient.Do(req)
	assert.NoError(t, err)
	assert.Equal(t, http.StatusCreated, resp.StatusCode)
	err = resp.Body.Close()
	assert.NoError(t, err)

	err = os.Remove(testFilePath)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
}
getTestSFTPUser() + sftpUser.FsConfig.SFTPConfig.Username = localUser.Username + + for _, u := range []dataprovider.User{getTestUser(), getTestUserWithCryptFs(), sftpUser} { + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client), sftpUser.Username) + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.NoError(t, err) + httpClient := httpclient.GetHTTPClient() + propatchBody := `Wed, 04 Nov 2020 13:25:51 GMTSat, 05 Dec 2020 21:16:12 GMTWed, 04 Nov 2020 13:25:51 GMT00000000` + req, err := http.NewRequest("PROPPATCH", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(propatchBody))) + assert.NoError(t, err) + req.SetBasicAuth(u.Username, u.Password) + resp, err := httpClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusMultiStatus, resp.StatusCode) + err = resp.Body.Close() + assert.NoError(t, err) + info, err := client.Stat(testFileName) + if assert.NoError(t, err) { + expected, err := http.ParseTime("Wed, 04 Nov 2020 13:25:51 GMT") + assert.NoError(t, err) + assert.Equal(t, testFileSize, info.Size()) + assert.Equal(t, expected.Format(http.TimeFormat), info.ModTime().Format(http.TimeFormat)) + } + // wrong date + propatchBody = `Wed, 04 Nov 2020 13:25:51 GMTSat, 05 Dec 2020 21:16:12 GMTWid, 04 Nov 2020 13:25:51 GMT00000000` + req, err = http.NewRequest("PROPPATCH", fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName), bytes.NewReader([]byte(propatchBody))) + assert.NoError(t, err) + req.SetBasicAuth(u.Username, u.Password) + resp, err = httpClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusMultiStatus, resp.StatusCode) + err = 
resp.Body.Close() + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestLoginInvalidPwd(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + assert.NoError(t, checkBasicFunc(client)) + user.Password = "wrong" + client = getWebDavClient(user, false, nil) + assert.Error(t, checkBasicFunc(client)) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) +} + +func TestLoginNonExistentUser(t *testing.T) { + user := getTestUser() + client := getWebDavClient(user, true, nil) + assert.Error(t, checkBasicFunc(client)) +} + +func TestRateLimiter(t *testing.T) { + oldConfig := config.GetCommonConfig() + + cfg := config.GetCommonConfig() + cfg.RateLimitersConfig = []common.RateLimiterConfig{ + { + Average: 1, + Period: 1000, + Burst: 3, + Type: 1, + Protocols: []string{common.ProtocolWebDAV}, + }, + } + + err := common.Initialize(cfg, 0) + assert.NoError(t, err) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + assert.NoError(t, checkBasicFunc(client)) + + _, err = client.ReadDir(".") + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "429") + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + err = 
// TestDefender verifies the brute-force defender: failed logins accumulate
// score (ScoreValid per failure) until Threshold is reached, after which the
// client host is banned and even a correct password is rejected with 403.
func TestDefender(t *testing.T) {
	oldConfig := config.GetCommonConfig()

	cfg := config.GetCommonConfig()
	cfg.DefenderConfig.Enabled = true
	cfg.DefenderConfig.Threshold = 3
	cfg.DefenderConfig.ScoreLimitExceeded = 2
	cfg.DefenderConfig.ScoreValid = 1

	err := common.Initialize(cfg, 0)
	assert.NoError(t, err)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, true, nil)
	assert.NoError(t, checkBasicFunc(client))

	// one failed login: host gets score 1, no ban yet
	user.Password = "wrong_pwd"
	client = getWebDavClient(user, false, nil)
	assert.Error(t, checkBasicFunc(client))
	hosts, _, err := httpdtest.GetDefenderHosts(http.StatusOK)
	assert.NoError(t, err)
	if assert.Len(t, hosts, 1) {
		host := hosts[0]
		assert.Empty(t, host.GetBanTime())
		assert.Equal(t, 1, host.Score)
	}

	// two more failures push the score to the threshold (3)
	for i := 0; i < 2; i++ {
		client = getWebDavClient(user, false, nil)
		assert.Error(t, checkBasicFunc(client))
	}

	// the host is now banned: even the correct password gets a 403
	user.Password = defaultPassword
	client = getWebDavClient(user, true, nil)
	err = checkBasicFunc(client)
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), "403")
	}

	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)

	// restore the original common configuration
	err = common.Initialize(oldConfig, 0)
	assert.NoError(t, err)
}
// TestExternalAuthPasswordChange verifies that when the external auth hook
// returns a new password for a user, login with the old password fails and
// login with the new password succeeds.
func TestExternalAuthPasswordChange(t *testing.T) {
	if runtime.GOOS == osWindows {
		t.Skip("this test is not available on Windows")
	}
	u := getTestUser()
	// reconfigure the provider to use the external auth hook
	err := dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf := config.GetProviderConf()
	err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, defaultPassword), os.ModePerm)
	assert.NoError(t, err)
	providerConf.ExternalAuthHook = extAuthPath
	providerConf.ExternalAuthScope = 0
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	client := getWebDavClient(u, false, nil)
	assert.NoError(t, checkBasicFunc(client))
	// the hook script only knows the original username: a different one fails
	u.Username = defaultUsername + "1"
	client = getWebDavClient(u, false, nil)
	assert.Error(t, checkBasicFunc(client))
	// the hook now returns a changed password: the old one must be rejected
	err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, defaultPassword+"1"), os.ModePerm)
	assert.NoError(t, err)
	client = getWebDavClient(u, false, nil)
	assert.Error(t, checkBasicFunc(client))
	// ... and the new password must be accepted
	u.Password = defaultPassword + "1"
	client = getWebDavClient(u, false, nil)
	assert.NoError(t, checkBasicFunc(client))
	user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK)
	assert.NoError(t, err)
	assert.Equal(t, defaultUsername, user.Username)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	// the hook also created the "<defaultUsername>1" user; remove it too
	// (its home dir is not removed here — presumably it was never created;
	// TODO confirm)
	user, _, err = httpdtest.GetUserByUsername(defaultUsername+"1", http.StatusOK)
	assert.NoError(t, err)
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	// restore the provider configuration without the hook
	err = dataprovider.Close()
	assert.NoError(t, err)
	err = config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	providerConf = config.GetProviderConf()
	err = dataprovider.Initialize(providerConf, configDir, true)
	assert.NoError(t, err)
	err = os.Remove(extAuthPath)
	assert.NoError(t, err)
}
httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.True(t, user.Filters.IsAnonymous) + assert.Equal(t, []string{dataprovider.PermListItems, dataprovider.PermDownload}, user.Permissions["/"]) + assert.Equal(t, []string{common.ProtocolSSH, common.ProtocolHTTP}, user.Filters.DeniedProtocols) + assert.Equal(t, []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodPassword, + dataprovider.SSHLoginMethodKeyboardInteractive, dataprovider.SSHLoginMethodKeyAndPassword, + dataprovider.SSHLoginMethodKeyAndKeyboardInt, dataprovider.LoginMethodTLSCertificate, + dataprovider.LoginMethodTLSCertificateAndPwd}, user.Filters.DeniedLoginMethods) + + u.Password = emptyPwdPlaceholder + client = getWebDavClient(user, false, nil) + assert.NoError(t, checkBasicFunc(client)) + + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "403") + } + err = client.Mkdir("testdir", os.ModePerm) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "403") + } + + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestExternalAuthAnonymousGroupInheritance(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + g := dataprovider.Group{ + BaseGroup: sdk.BaseGroup{ + Name: "test_group", + }, + UserSettings: dataprovider.GroupUserSettings{ + BaseGroupUserSettings: sdk.BaseGroupUserSettings{ + 
Permissions: map[string][]string{ + "/": allPerms, + }, + Filters: sdk.BaseUserFilters{ + IsAnonymous: true, + }, + }, + }, + } + u := getTestUser() + u.Groups = []sdk.GroupMapping{ + { + Name: g.Name, + Type: sdk.GroupTypePrimary, + }, + } + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, ""), os.ModePerm) + assert.NoError(t, err) + providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + group, _, err := httpdtest.AddGroup(g, http.StatusCreated) + assert.NoError(t, err) + + u.Password = emptyPwdPlaceholder + client := getWebDavClient(u, false, nil) + assert.NoError(t, checkBasicFunc(client)) + + err = client.Mkdir("tdir", os.ModePerm) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), "403") + } + + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + assert.False(t, user.Filters.IsAnonymous) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveGroup(group, http.StatusOK) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestPreLoginHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + u := getTestUser() + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := 
config.GetProviderConf() + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + providerConf.PreLoginHook = preLoginPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + _, _, err = httpdtest.GetUserByUsername(defaultUsername, http.StatusNotFound) + assert.NoError(t, err) + client := getWebDavClient(u, true, nil) + assert.NoError(t, checkBasicFunc(client)) + + user, _, err := httpdtest.GetUserByUsername(defaultUsername, http.StatusOK) + assert.NoError(t, err) + // test login with an existing user + client = getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client)) + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) + assert.NoError(t, err) + // update the user to remove it from the cache + user.FsConfig.Provider = sdk.CryptedFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret(defaultPassword) + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client = getWebDavClient(user, true, nil) + assert.Error(t, checkBasicFunc(client)) + // update the user to remove it from the cache + user.FsConfig.Provider = sdk.LocalFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = nil + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + user.Status = 0 + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm) + assert.NoError(t, err) + client = getWebDavClient(user, true, nil) + assert.Error(t, checkBasicFunc(client)) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, 
err) + err = os.Remove(preLoginPath) + assert.NoError(t, err) +} + +func TestPreDownloadHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oldExecuteOn := common.Config.Actions.ExecuteOn + oldHook := common.Config.Actions.Hook + + common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} + common.Config.Actions.Hook = preDownloadPath + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + + client := getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client)) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + + err = os.WriteFile(preDownloadPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + common.Config.Actions.ExecuteOn = []string{common.OperationPreDownload} + common.Config.Actions.Hook = preDownloadPath + + 
common.Config.Actions.ExecuteOn = oldExecuteOn + common.Config.Actions.Hook = oldHook +} + +func TestPreUploadHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + oldExecuteOn := common.Config.Actions.ExecuteOn + oldHook := common.Config.Actions.Hook + + common.Config.Actions.ExecuteOn = []string{common.OperationPreUpload} + common.Config.Actions.Hook = preUploadPath + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(preUploadPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + + client := getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client)) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) + assert.NoError(t, err) + + err = os.WriteFile(preUploadPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) + assert.Error(t, err) + + err = uploadFileWithRawClient(testFilePath, testFileName+"1", user.Username, defaultPassword, + false, testFileSize, client) + assert.Error(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + common.Config.Actions.ExecuteOn = oldExecuteOn + common.Config.Actions.Hook = oldHook +} + +func TestPostConnectHook(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + 
} + common.Config.PostConnectHook = postConnectPath + + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + err = os.WriteFile(postConnectPath, getExitCodeScriptContent(0), os.ModePerm) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + assert.NoError(t, checkBasicFunc(client)) + err = os.WriteFile(postConnectPath, getExitCodeScriptContent(1), os.ModePerm) + assert.NoError(t, err) + assert.Error(t, checkBasicFunc(client)) + + common.Config.PostConnectHook = "http://127.0.0.1:8078/healthz" + assert.NoError(t, checkBasicFunc(client)) + + common.Config.PostConnectHook = "http://127.0.0.1:8078/notfound" + assert.Error(t, checkBasicFunc(client)) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + + common.Config.PostConnectHook = "" +} + +func TestMaxConnections(t *testing.T) { + oldValue := common.Config.MaxTotalConnections + common.Config.MaxTotalConnections = 1 + + assert.Eventually(t, func() bool { + return common.Connections.GetClientConnections() == 0 + }, 1000*time.Millisecond, 50*time.Millisecond) + + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client)) + // now add a fake connection + fs := vfs.NewOsFs("id", os.TempDir(), "", nil) + connection := &webdavd.Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), + } + err = common.Connections.Add(connection) + assert.NoError(t, err) + assert.Error(t, checkBasicFunc(client)) + common.Connections.Remove(connection.GetID()) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 
// TestMaxPerHostConnections verifies that per-host connection limiting
// rejects a new WebDAV request once the client's address already has the
// maximum number of registered connections.
func TestMaxPerHostConnections(t *testing.T) {
	oldValue := common.Config.MaxPerHostConnections
	common.Config.MaxPerHostConnections = 1

	// wait for connections from previous tests to drain
	assert.Eventually(t, func() bool {
		return common.Connections.GetClientConnections() == 0
	}, 1000*time.Millisecond, 50*time.Millisecond)

	user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated)
	assert.NoError(t, err)
	client := getWebDavClient(user, true, nil)
	assert.NoError(t, checkBasicFunc(client))
	// now add a fake connection
	// register every address "localhost" resolves to, so the next request
	// from this host exceeds the limit regardless of IPv4/IPv6
	addrs, err := net.LookupHost("localhost")
	assert.NoError(t, err)
	for _, addr := range addrs {
		common.Connections.AddClientConnection(addr)
	}
	assert.Error(t, checkBasicFunc(client))
	for _, addr := range addrs {
		common.Connections.RemoveClientConnection(addr)
	}
	_, err = httpdtest.RemoveUser(user, http.StatusOK)
	assert.NoError(t, err)
	err = os.RemoveAll(user.GetHomeDir())
	assert.NoError(t, err)
	assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 },
		1*time.Second, 100*time.Millisecond)
	assert.Equal(t, int32(0), common.Connections.GetTotalTransfers())

	common.Config.MaxPerHostConnections = oldValue
}
sftpClient.Create("file2") + assert.NoError(t, err) + _, err = f1.Write([]byte(" ")) + assert.NoError(t, err) + _, err = f2.Write([]byte(" ")) + assert.NoError(t, err) + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.Error(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + + err = f1.Close() + assert.NoError(t, err) + err = f2.Close() + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + common.Config.MaxPerHostConnections = oldValue +} + +func TestMustChangePasswordRequirement(t *testing.T) { + u := getTestUser() + u.Filters.RequirePasswordChange = true + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + client := getWebDavClient(user, false, nil) + assert.Error(t, checkBasicFunc(client)) + + err = dataprovider.UpdateUserPassword(user.Username, defaultPassword, "", "", "") + assert.NoError(t, err) + + client = getWebDavClient(user, false, nil) + assert.NoError(t, checkBasicFunc(client)) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestMaxSessions(t *testing.T) { + u := getTestUser() + u.MaxSessions = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + assert.NoError(t, checkBasicFunc(client)) + // now add a fake connection + fs := vfs.NewOsFs("id", os.TempDir(), 
"", nil) + connection := &webdavd.Connection{ + BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, "", "", user), + } + err = common.Connections.Add(connection) + assert.NoError(t, err) + assert.Error(t, checkBasicFunc(client)) + common.Connections.Remove(connection.GetID()) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func TestLoginWithIPilters(t *testing.T) { + u := getTestUser() + u.Filters.DeniedIP = []string{"192.167.0.0/24", "172.18.0.0/16"} + u.Filters.AllowedIP = []string{"172.19.0.0/16"} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + assert.Error(t, checkBasicFunc(client)) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestDownloadErrors(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 1 + subDir1 := "sub1" + subDir2 := "sub2" + u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems} + u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload, + dataprovider.PermDelete, dataprovider.PermDownload} + // use an unknown mime to trigger content type detection + u.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/sub2", + AllowedPatterns: []string{}, + DeniedPatterns: []string{"*.jpg", "*.zipp"}, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + testFilePath1 := filepath.Join(user.HomeDir, subDir1, "file.zipp") + testFilePath2 := filepath.Join(user.HomeDir, subDir2, "file.zipp") 
+ testFilePath3 := filepath.Join(user.HomeDir, subDir2, "file.jpg") + err = os.MkdirAll(filepath.Dir(testFilePath1), os.ModePerm) + assert.NoError(t, err) + err = os.MkdirAll(filepath.Dir(testFilePath2), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testFilePath1, []byte("file1"), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testFilePath2, []byte("file2"), os.ModePerm) + assert.NoError(t, err) + err = os.WriteFile(testFilePath3, []byte("file3"), os.ModePerm) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = downloadFile(path.Join("/", subDir1, "file.zipp"), localDownloadPath, 5, client) + assert.Error(t, err) + err = downloadFile(path.Join("/", subDir2, "file.zipp"), localDownloadPath, 5, client) + assert.Error(t, err) + err = downloadFile(path.Join("/", subDir2, "file.jpg"), localDownloadPath, 5, client) + assert.Error(t, err) + err = downloadFile(path.Join("missing.zip"), localDownloadPath, 5, client) + assert.Error(t, err) + + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUploadErrors(t *testing.T) { + u := getTestUser() + u.QuotaSize = 65535 + subDir1 := "sub1" + subDir2 := "sub2" + // we need download permission to get size since PROPFIND will open the file + u.Permissions[path.Join("/", subDir1)] = []string{dataprovider.PermListItems, dataprovider.PermDownload} + u.Permissions[path.Join("/", subDir2)] = []string{dataprovider.PermListItems, dataprovider.PermUpload, + dataprovider.PermDelete, dataprovider.PermDownload} + u.Filters.FilePatterns = []sdk.PatternsFilter{ + { + Path: "/sub2", + AllowedPatterns: []string{}, + DeniedPatterns: []string{"*.zip"}, + }, + } + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + testFilePath := 
filepath.Join(homeBasePath, testFileName) + testFileSize := user.QuotaSize + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = client.Mkdir(subDir1, os.ModePerm) + assert.NoError(t, err) + err = client.Mkdir(subDir2, os.ModePerm) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(subDir1, testFileName), user.Username, + defaultPassword, true, testFileSize, client) + assert.Error(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(subDir2, testFileName+".zip"), user.Username, + defaultPassword, true, testFileSize, client) + + assert.Error(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(subDir2, testFileName), user.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = client.Rename(path.Join(subDir2, testFileName), path.Join(subDir1, testFileName), false) + assert.Error(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(subDir2, testFileName), user.Username, + defaultPassword, true, testFileSize, client) + assert.Error(t, err) + err = uploadFileWithRawClient(testFilePath, subDir1, user.Username, + defaultPassword, true, testFileSize, client) + assert.Error(t, err) + // overquota + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) + assert.Error(t, err) + err = client.Remove(path.Join(subDir2, testFileName)) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) + assert.Error(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestDeniedLoginMethod(t *testing.T) 
{ + u := getTestUser() + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + assert.Error(t, checkBasicFunc(client)) + + user.Filters.DeniedLoginMethods = []string{dataprovider.SSHLoginMethodPublicKey, dataprovider.SSHLoginMethodKeyAndKeyboardInt} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client = getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client)) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestDeniedProtocols(t *testing.T) { + u := getTestUser() + u.Filters.DeniedProtocols = []string{common.ProtocolWebDAV} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + assert.Error(t, checkBasicFunc(client)) + + user.Filters.DeniedProtocols = []string{common.ProtocolSSH, common.ProtocolFTP} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client = getWebDavClient(user, false, nil) + assert.NoError(t, checkBasicFunc(client)) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestQuotaLimits(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 1 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaFiles = 1 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + testFileSize := int64(65536) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + testFileSize1 := int64(131072) 
+ testFileName1 := "test_file1.dat" + testFilePath1 := filepath.Join(homeBasePath, testFileName1) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + testFileSize2 := int64(32768) + testFileName2 := "test_file2.dat" + testFilePath2 := filepath.Join(homeBasePath, testFileName2) + err = createTestFile(testFilePath2, testFileSize2) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + // test quota files + err = uploadFileWithRawClient(testFilePath, testFileName+".quota", user.Username, defaultPassword, false, //nolint:goconst + testFileSize, client) + if !assert.NoError(t, err, "username: %v", user.Username) { + info, err := os.Stat(testFilePath) + if assert.NoError(t, err) { + fmt.Printf("local file size: %v\n", info.Size()) + } + printLatestLogs(20) + } + err = uploadFileWithRawClient(testFilePath, testFileName+".quota1", user.Username, defaultPassword, + false, testFileSize, client) + assert.Error(t, err, "username: %v", user.Username) + err = client.Rename(testFileName+".quota", testFileName, false) + assert.NoError(t, err) + files, err := client.ReadDir("/") + assert.NoError(t, err) + assert.Len(t, files, 1) + // test quota size + user.QuotaSize = testFileSize - 1 + user.QuotaFiles = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName+".quota", user.Username, defaultPassword, + false, testFileSize, client) + assert.Error(t, err) + err = client.Rename(testFileName, testFileName+".quota", false) + assert.NoError(t, err) + // now test quota limits while uploading the current file, we have 1 bytes remaining + user.QuotaSize = testFileSize + 1 + user.QuotaFiles = 0 + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath1, testFileName1, user.Username, defaultPassword, + false, testFileSize1, client) + assert.Error(t, err) + _, err = 
client.Stat(testFileName1) + assert.Error(t, err) + err = client.Rename(testFileName+".quota", testFileName, false) + assert.NoError(t, err) + // overwriting an existing file will work if the resulting size is lesser or equal than the current one + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath2, testFileName, user.Username, defaultPassword, + false, testFileSize2, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath1, testFileName, user.Username, defaultPassword, + false, testFileSize1, client) + assert.Error(t, err) + err = uploadFileWithRawClient(testFilePath2, testFileName, user.Username, defaultPassword, + false, testFileSize2, client) + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + err = os.Remove(testFilePath2) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + user.QuotaFiles = 0 + user.QuotaSize = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestTransferQuotaLimits(t *testing.T) { + u := getTestUser() + u.DownloadDataTransfer = 1 + u.UploadDataTransfer = 1 + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + testFilePath := filepath.Join(homeBasePath, testFileName) + 
testFileSize := int64(550000) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + client := getWebDavClient(user, false, nil) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.NoError(t, err) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + // error while download is active + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err) + // error before starting the download + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err) + // error while upload is active + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.Error(t, err) + // error before starting the upload + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.Error(t, err) + + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUploadMaxSize(t *testing.T) { + testFileSize := int64(65535) + u := getTestUser() + u.Filters.MaxUploadFileSize = testFileSize + 1 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.Filters.MaxUploadFileSize = testFileSize + 1 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + testFileSize1 := int64(131072) + testFileName1 := "test_file_dav1.dat" + testFilePath1 := 
filepath.Join(homeBasePath, testFileName1) + err = createTestFile(testFilePath1, testFileSize1) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + err = uploadFileWithRawClient(testFilePath1, testFileName1, user.Username, defaultPassword, + false, testFileSize1, client) + assert.Error(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + false, testFileSize, client) + assert.NoError(t, err) + // now test overwrite an existing file with a size bigger than the allowed one + err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName1), testFileSize1) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath1, testFileName1, user.Username, defaultPassword, + false, testFileSize1, client) + assert.Error(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(testFilePath1) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Filters.MaxUploadFileSize = 65536000 + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestClientClose(t *testing.T) { + u := getTestUser() + u.UploadBandwidth = 64 + u.DownloadBandwidth = 64 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.UploadBandwidth = 64 + u.DownloadBandwidth = 64 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + 
testFileSize := int64(1048576) + testFilePath := filepath.Join(homeBasePath, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + assert.NoError(t, checkBasicFunc(client)) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) + assert.Error(t, err) + wg.Done() + }() + + assert.Eventually(t, func() bool { + for _, stat := range common.Connections.GetStats("") { + if len(stat.Transfers) > 0 { + return true + } + } + return false + }, 1*time.Second, 50*time.Millisecond) + + for _, stat := range common.Connections.GetStats("") { + common.Connections.Close(stat.ConnectionID, "") + } + wg.Wait() + // for the sftp user a stat is done after the failed upload and + // this triggers a new connection + for _, stat := range common.Connections.GetStats("") { + common.Connections.Close(stat.ConnectionID, "") + } + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + testFilePath = filepath.Join(user.HomeDir, testFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + + wg.Add(1) + go func() { + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.Error(t, err) + wg.Done() + }() + + assert.Eventually(t, func() bool { + for _, stat := range common.Connections.GetStats("") { + if len(stat.Transfers) > 0 { + return true + } + } + return false + }, 1*time.Second, 50*time.Millisecond) + + for _, stat := range common.Connections.GetStats("") { + common.Connections.Close(stat.ConnectionID, "") + } + wg.Wait() + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 
100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) + + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + } + + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginWithDatabaseCredentials(t *testing.T) { + u := getTestUser() + u.FsConfig.Provider = sdk.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "test" + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret(`{ "type": "service_account", "private_key": " ", "client_email": "example@iam.gserviceaccount.com" }`) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + assert.Equal(t, sdkkms.SecretStatusSecretBox, user.FsConfig.GCSConfig.Credentials.GetStatus()) + assert.NotEmpty(t, user.FsConfig.GCSConfig.Credentials.GetPayload()) + assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetAdditionalData()) + assert.Empty(t, user.FsConfig.GCSConfig.Credentials.GetKey()) + + client := getWebDavClient(user, false, nil) + + err = client.Connect() + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestLoginInvalidFs(t *testing.T) { + u := getTestUser() + u.FsConfig.Provider = sdk.GCSFilesystemProvider + u.FsConfig.GCSConfig.Bucket = "test" + u.FsConfig.GCSConfig.Credentials = kms.NewPlainSecret("invalid JSON for credentials") + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + client := getWebDavClient(user, true, nil) + assert.Error(t, checkBasicFunc(client)) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestSFTPBuffered(t *testing.T) { + u := getTestUser() + 
localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaFiles = 1000 + u.HomeDir = filepath.Join(os.TempDir(), u.Username) + u.FsConfig.SFTPConfig.BufferSize = 2 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + client := getWebDavClient(sftpUser, true, nil) + assert.NoError(t, checkBasicFunc(client)) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + expectedQuotaSize := testFileSize + expectedQuotaFiles := 1 + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, + true, testFileSize, client) + assert.NoError(t, err) + // overwrite an existing file + err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, + true, testFileSize, client) + assert.NoError(t, err) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + + user, _, err := httpdtest.GetUserByUsername(sftpUser.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + + fileContent := []byte("test file contents") + err = os.WriteFile(testFilePath, fileContent, os.ModePerm) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, defaultPassword, + true, int64(len(fileContent)), client) + assert.NoError(t, err) + remotePath := fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName) + req, err := http.NewRequest(http.MethodGet, remotePath, nil) + assert.NoError(t, err) + httpClient := httpclient.GetHTTPClient() + req.SetBasicAuth(user.Username, defaultPassword) + req.Header.Set("Range", "bytes=5-") + resp, err := httpClient.Do(req) + if 
assert.NoError(t, err) { + defer resp.Body.Close() + assert.Equal(t, http.StatusPartialContent, resp.StatusCode) + bodyBytes, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "file contents", string(bodyBytes)) + } + req.Header.Set("Range", "bytes=5-8") + resp, err = httpClient.Do(req) + if assert.NoError(t, err) { + defer resp.Body.Close() + assert.Equal(t, http.StatusPartialContent, resp.StatusCode) + bodyBytes, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "file", string(bodyBytes)) + } + + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(sftpUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestBytesRangeRequests(t *testing.T) { + u := getTestUser() + u.Username = u.Username + "1" + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + sftpUser := getTestSFTPUser() + sftpUser.FsConfig.SFTPConfig.Username = localUser.Username + + for _, u := range []dataprovider.User{getTestUser(), getTestUserWithCryptFs(), sftpUser} { + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + testFileName := "test_file.txt" + testFilePath := filepath.Join(homeBasePath, testFileName) + fileContent := []byte("test file contents") + err = os.WriteFile(testFilePath, fileContent, os.ModePerm) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, int64(len(fileContent)), client) + assert.NoError(t, err) + remotePath := fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName) + req, err := http.NewRequest(http.MethodGet, remotePath, nil) + if assert.NoError(t, err) { + httpClient := 
httpclient.GetHTTPClient() + req.SetBasicAuth(user.Username, defaultPassword) + req.Header.Set("Range", "bytes=5-") + resp, err := httpClient.Do(req) + if assert.NoError(t, err) { + defer resp.Body.Close() + assert.Equal(t, http.StatusPartialContent, resp.StatusCode) + bodyBytes, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "file contents", string(bodyBytes)) + } + req.Header.Set("Range", "bytes=5-8") + resp, err = httpClient.Do(req) + if assert.NoError(t, err) { + defer resp.Body.Close() + assert.Equal(t, http.StatusPartialContent, resp.StatusCode) + bodyBytes, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "file", string(bodyBytes)) + } + } + // seek on a missing file + remotePath = fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName+"_missing") + req, err = http.NewRequest(http.MethodGet, remotePath, nil) + if assert.NoError(t, err) { + httpClient := httpclient.GetHTTPClient() + req.SetBasicAuth(user.Username, defaultPassword) + req.Header.Set("Range", "bytes=5-") + resp, err := httpClient.Do(req) + if assert.NoError(t, err) { + defer resp.Body.Close() + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + } + } + + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + } + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestContentTypeGET(t *testing.T) { + user, _, err := httpdtest.AddUser(getTestUser(), http.StatusCreated) + assert.NoError(t, err) + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(64) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + err = uploadFileWithRawClient(testFilePath, testFileName+".sftpgo", user.Username, 
defaultPassword, + true, testFileSize, client) + assert.NoError(t, err) + remotePath := fmt.Sprintf("http://%v/%v", webDavServerAddr, testFileName+".sftpgo") + req, err := http.NewRequest(http.MethodGet, remotePath, nil) + if assert.NoError(t, err) { + httpClient := httpclient.GetHTTPClient() + req.SetBasicAuth(user.Username, defaultPassword) + resp, err := httpClient.Do(req) + if assert.NoError(t, err) { + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "application/sftpgo", resp.Header.Get("Content-Type")) + } + } + + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestHEAD(t *testing.T) { + u := getTestUser() + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + rootPath := fmt.Sprintf("http://%v", webDavServerAddr) + httpClient := httpclient.GetHTTPClient() + req, err := http.NewRequest(http.MethodHead, rootPath, nil) + if assert.NoError(t, err) { + req.SetBasicAuth(u.Username, u.Password) + resp, err := httpClient.Do(req) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusMultiStatus, resp.StatusCode) + assert.Equal(t, "text/xml; charset=utf-8", resp.Header.Get("Content-Type")) + resp.Body.Close() + } + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestGETAsPROPFIND(t *testing.T) { + u := getTestUser() + subDir1 := "/sub1" + u.Permissions[subDir1] = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + rootPath := fmt.Sprintf("http://%v/", webDavServerAddr) + httpClient := httpclient.GetHTTPClient() + req, err := http.NewRequest(http.MethodGet, rootPath, nil) + if assert.NoError(t, err) { + 
req.SetBasicAuth(u.Username, u.Password) + resp, err := httpClient.Do(req) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusMultiStatus, resp.StatusCode) + resp.Body.Close() + } + } + client := getWebDavClient(user, false, nil) + err = client.MkdirAll(path.Join(subDir1, "sub", "sub1"), os.ModePerm) + assert.NoError(t, err) + subPath := fmt.Sprintf("http://%v/%v", webDavServerAddr, subDir1) + req, err = http.NewRequest(http.MethodGet, subPath, nil) + if assert.NoError(t, err) { + req.SetBasicAuth(u.Username, u.Password) + resp, err := httpClient.Do(req) + if assert.NoError(t, err) { + // before the performance patch we have a 500 here, now we have 207 but an empty list + //assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + assert.Equal(t, http.StatusMultiStatus, resp.StatusCode) + resp.Body.Close() + } + } + // we cannot stat the sub at all + subPath1 := fmt.Sprintf("http://%v/%v", webDavServerAddr, path.Join(subDir1, "sub")) + req, err = http.NewRequest(http.MethodGet, subPath1, nil) + if assert.NoError(t, err) { + req.SetBasicAuth(u.Username, u.Password) + resp, err := httpClient.Do(req) + if assert.NoError(t, err) { + // here the stat will fail, so the request will not be changed in propfind + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + resp.Body.Close() + } + } + + // we have no permission, we get an empty list + files, err := client.ReadDir(subDir1) + assert.NoError(t, err) + assert.Len(t, files, 0) + // if we grant the permissions the files are listed + user.Permissions[subDir1] = []string{dataprovider.PermDownload, dataprovider.PermListItems} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + files, err = client.ReadDir(subDir1) + assert.NoError(t, err) + assert.Len(t, files, 1) + // PROPFIND with infinity depth is forbidden + req, err = http.NewRequest(http.MethodGet, rootPath, nil) + if assert.NoError(t, err) { + req.SetBasicAuth(u.Username, u.Password) + req.Header.Set("Depth", 
"infinity") + resp, err := httpClient.Do(req) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + resp.Body.Close() + } + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestStat(t *testing.T) { + u := getTestUser() + u.Permissions["/subdir"] = []string{dataprovider.PermUpload, dataprovider.PermListItems, dataprovider.PermDownload} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, true, nil) + subDir := "subdir" + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = client.Mkdir(subDir, os.ModePerm) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, testFileName, user.Username, defaultPassword, + true, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join("/", subDir, testFileName), user.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + user.Permissions["/subdir"] = []string{dataprovider.PermUpload, dataprovider.PermDownload} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + _, err = client.Stat(testFileName) + assert.NoError(t, err) + _, err = client.Stat(path.Join("/", subDir, testFileName)) + assert.Error(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestUploadOverwriteVfolder(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 1000 + vdir := "/vdir" + mappedPath := filepath.Join(os.TempDir(), "mappedDir") + folderName := filepath.Base(mappedPath) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: 
mappedPath, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdir, + QuotaSize: -1, + QuotaFiles: -1, + }) + err = os.MkdirAll(mappedPath, os.ModePerm) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + files, err := client.ReadDir(".") + assert.NoError(t, err) + vdirFound := false + for _, info := range files { + if info.Name() == path.Base(vdir) { + vdirFound = true + break + } + } + assert.True(t, vdirFound) + info, err := client.Stat(vdir) + if assert.NoError(t, err) { + assert.Equal(t, path.Base(vdir), info.Name()) + } + + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(vdir, testFileName), user.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + folder, _, err := httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + assert.Equal(t, 0, folder.UsedQuotaFiles) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, testFileSize, user.UsedQuotaSize) + assert.Equal(t, 1, user.UsedQuotaFiles) + + err = uploadFileWithRawClient(testFilePath, path.Join(vdir, testFileName), user.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + folder, _, err = httpdtest.GetFolderByName(folderName, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, int64(0), folder.UsedQuotaSize) + assert.Equal(t, 0, folder.UsedQuotaFiles) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) 
+ assert.Equal(t, testFileSize, user.UsedQuotaSize) + assert.Equal(t, 1, user.UsedQuotaFiles) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestOsErrors(t *testing.T) { + u := getTestUser() + vdir := "/vdir" + mappedPath := filepath.Join(os.TempDir(), "mappedDir") + folderName := filepath.Base(mappedPath) + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdir, + QuotaSize: -1, + QuotaFiles: -1, + }) + f := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + client := getWebDavClient(user, false, nil) + files, err := client.ReadDir(".") + assert.NoError(t, err) + assert.Len(t, files, 1) + info, err := client.Stat(vdir) + assert.NoError(t, err) + assert.True(t, info.IsDir()) + // now remove the folder mapped to vdir. 
It still appear in directory listing + // virtual folders are automatically added + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) + files, err = client.ReadDir(".") + assert.NoError(t, err) + assert.Len(t, files, 1) + err = createTestFile(filepath.Join(user.GetHomeDir(), testFileName), 32768) + assert.NoError(t, err) + files, err = client.ReadDir(".") + assert.NoError(t, err) + if assert.Len(t, files, 2) { + var names []string + for _, info := range files { + names = append(names, info.Name()) + } + assert.Contains(t, names, testFileName) + assert.Contains(t, names, "vdir") + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + +func TestMiscCommands(t *testing.T) { + u := getTestUser() + u.QuotaFiles = 100 + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + u.QuotaFiles = 100 + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + for _, user := range []dataprovider.User{localUser, sftpUser} { + dir := "testDir" + client := getWebDavClient(user, true, nil) + err = client.MkdirAll(path.Join(dir, "sub1", "sub2"), os.ModePerm) + assert.NoError(t, err) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(dir, testFileName), user.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(dir, "sub1", testFileName), user.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, 
path.Join(dir, "sub1", "sub2", testFileName), user.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = client.Copy(dir, dir+"_copy", false) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 6, user.UsedQuotaFiles) + assert.Equal(t, 6*testFileSize, user.UsedQuotaSize) + err = client.Copy(dir, dir+"_copy1", false) //nolint:goconst + assert.NoError(t, err) + err = client.Copy(dir+"_copy", dir+"_copy1", false) + assert.Error(t, err) + err = client.Copy(dir+"_copy", dir+"_copy1", true) + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 9, user.UsedQuotaFiles) + assert.Equal(t, 9*testFileSize, user.UsedQuotaSize) + err = client.Rename(dir+"_copy1", dir+"_copy2", false) + assert.NoError(t, err) + err = client.Remove(path.Join(dir+"_copy", testFileName)) + assert.NoError(t, err) + err = client.Rename(dir+"_copy2", dir+"_copy", true) + assert.NoError(t, err) + err = client.Copy(dir+"_copy", dir+"_copy1", false) + assert.NoError(t, err) + err = client.RemoveAll(dir + "_copy1") + assert.NoError(t, err) + user, _, err = httpdtest.GetUserByUsername(user.Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 6, user.UsedQuotaFiles) + assert.Equal(t, 6*testFileSize, user.UsedQuotaSize) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + if user.Username == defaultUsername { + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + user.Password = defaultPassword + user.ID = 0 + user.CreatedAt = 0 + user.QuotaFiles = 0 + _, resp, err := httpdtest.AddUser(user, http.StatusCreated) + assert.NoError(t, err, string(resp)) + } + } + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, 
http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) +} + +func TestClientCertificateAuthRevokedCert(t *testing.T) { + u := getTestUser() + u.Username = tlsClient2Username + u.Filters.TLSUsername = sdk.TLSUsernameCN + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client2Crt), []byte(client2Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + client := getWebDavClient(user, true, tlsConfig) + err = checkBasicFunc(client) + if assert.Error(t, err) { + if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "broken pipe") { + t.Errorf("unexpected error: %v", err) + } + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestClientCertificateAuth(t *testing.T) { + u := getTestUser() + u.Username = tlsClient1Username + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword, dataprovider.LoginMethodTLSCertificateAndPwd} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + // TLS username is not enabled, mutual TLS should fail + resp, err := getTLSHTTPClient(tlsConfig).Get(fmt.Sprintf("https://%v/", webDavTLSServerAddr)) + if assert.NoError(t, err) { + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 
http.StatusUnauthorized, resp.StatusCode, string(body)) + } + + user.Filters.TLSUsername = sdk.TLSUsernameCN + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client := getWebDavClient(user, true, tlsConfig) + err = checkBasicFunc(client) + assert.NoError(t, err) + user.Filters.TLSUsername = sdk.TLSUsernameNone + user.Filters.TLSCerts = []string{client1Crt} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client = getWebDavClient(user, true, tlsConfig) + err = checkBasicFunc(client) + assert.NoError(t, err) + + user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword, dataprovider.LoginMethodTLSCertificate} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client = getWebDavClient(user, true, tlsConfig) + err = checkBasicFunc(client) + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestWrongClientCertificate(t *testing.T) { + u := getTestUser() + u.Username = tlsClient2Username + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificateAndPwd} + u.Filters.TLSUsername = sdk.TLSUsernameCN + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + + // the certificate common name is client1 and it does not exists + resp, err := getTLSHTTPClient(tlsConfig).Get(fmt.Sprintf("https://%v/", webDavTLSServerAddr)) + if assert.NoError(t, err) { + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 
http.StatusUnauthorized, resp.StatusCode, string(body)) + } + + user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword, dataprovider.LoginMethodTLSCertificate} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + // now create client1 + u = getTestUser() + u.Username = tlsClient1Username + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodPassword, dataprovider.LoginMethodTLSCertificate} + u.Filters.TLSUsername = sdk.TLSUsernameCN + user1, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + resp, err = getTLSHTTPClient(tlsConfig).Get(fmt.Sprintf("https://%v:%v@%v/", tlsClient2Username, defaultPassword, + webDavTLSServerAddr)) + if assert.NoError(t, err) { + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, string(body)) + assert.Contains(t, string(body), "invalid credentials") + } + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) +} + +func TestClientCertificateAuthCachedUser(t *testing.T) { + u := getTestUser() + u.Username = tlsClient1Username + u.Filters.TLSUsername = sdk.TLSUsernameCN + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificateAndPwd} + user, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + client := getWebDavClient(user, true, tlsConfig) + err = 
checkBasicFunc(client) + assert.NoError(t, err) + // the user is now cached without a password, try a simple password login with and without TLS + client = getWebDavClient(user, true, nil) + err = checkBasicFunc(client) + assert.NoError(t, err) + + client = getWebDavClient(user, false, nil) + err = checkBasicFunc(client) + assert.NoError(t, err) + + // and now with a wrong password + user.Password = "wrong" + client = getWebDavClient(user, false, nil) + err = checkBasicFunc(client) + assert.Error(t, err) + + // allow cert+password only + user.Password = "" + user.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificate} + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + + client = getWebDavClient(user, true, tlsConfig) + err = checkBasicFunc(client) + assert.NoError(t, err) + // the user is now cached + client = getWebDavClient(user, true, tlsConfig) + err = checkBasicFunc(client) + assert.NoError(t, err) + // password auth should work too + client = getWebDavClient(user, false, nil) + err = checkBasicFunc(client) + assert.NoError(t, err) + + client = getWebDavClient(user, true, nil) + err = checkBasicFunc(client) + assert.NoError(t, err) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) +} + +func TestExternatAuthWithClientCert(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + u := getTestUser() + u.Username = tlsClient1Username + u.Filters.TLSUsername = sdk.TLSUsernameCN + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificate, dataprovider.LoginMethodPassword} + err := dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(extAuthPath, getExtAuthScriptContent(u, ""), os.ModePerm) + assert.NoError(t, err) + 
providerConf.ExternalAuthHook = extAuthPath + providerConf.ExternalAuthScope = 0 + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + client := getWebDavClient(u, true, tlsConfig) + assert.NoError(t, checkBasicFunc(client)) + + resp, err := getTLSHTTPClient(tlsConfig).Get(fmt.Sprintf("https://%v:%v@%v/", tlsClient2Username, defaultPassword, + webDavTLSServerAddr)) + if assert.NoError(t, err) { + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, string(body)) + assert.Contains(t, string(body), "invalid credentials") + } + + user, _, err := httpdtest.GetUserByUsername(tlsClient1Username, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, tlsClient1Username, user.Username) + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(extAuthPath) + assert.NoError(t, err) +} + +func TestPreLoginHookWithClientCert(t *testing.T) { + if runtime.GOOS == osWindows { + t.Skip("this test is not available on Windows") + } + u := getTestUser() + u.Username = tlsClient1Username + u.Filters.TLSUsername = sdk.TLSUsernameCN + u.Filters.DeniedLoginMethods = []string{dataprovider.LoginMethodTLSCertificate, dataprovider.LoginMethodPassword} + err := dataprovider.Close() + assert.NoError(t, err) + 
err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf := config.GetProviderConf() + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(u, false), os.ModePerm) + assert.NoError(t, err) + providerConf.PreLoginHook = preLoginPath + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + _, _, err = httpdtest.GetUserByUsername(tlsClient1Username, http.StatusNotFound) + assert.NoError(t, err) + tlsConfig := &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + tlsCert, err := tls.X509KeyPair([]byte(client1Crt), []byte(client1Key)) + assert.NoError(t, err) + tlsConfig.Certificates = append(tlsConfig.Certificates, tlsCert) + client := getWebDavClient(u, true, tlsConfig) + assert.NoError(t, checkBasicFunc(client)) + + user, _, err := httpdtest.GetUserByUsername(tlsClient1Username, http.StatusOK) + assert.NoError(t, err) + // test login with an existing user + client = getWebDavClient(user, true, tlsConfig) + assert.NoError(t, checkBasicFunc(client)) + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, true), os.ModePerm) + assert.NoError(t, err) + // update the user to remove it from the cache + user.Password = defaultPassword + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + client = getWebDavClient(user, true, tlsConfig) + assert.Error(t, checkBasicFunc(client)) + // update the user to remove it from the cache + user.Password = defaultPassword + user, _, err = httpdtest.UpdateUser(user, http.StatusOK, "") + assert.NoError(t, err) + user.Status = 0 + err = os.WriteFile(preLoginPath, getPreLoginScriptContent(user, false), os.ModePerm) + assert.NoError(t, err) + client = getWebDavClient(user, true, tlsConfig) + assert.Error(t, checkBasicFunc(client)) + + _, err = httpdtest.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + 
assert.NoError(t, err) + err = dataprovider.Close() + assert.NoError(t, err) + err = config.LoadConfig(configDir, "") + assert.NoError(t, err) + providerConf = config.GetProviderConf() + err = dataprovider.Initialize(providerConf, configDir, true) + assert.NoError(t, err) + err = os.Remove(preLoginPath) + assert.NoError(t, err) +} + +func TestSFTPLoopVirtualFolders(t *testing.T) { + user1 := getTestUser() + user2 := getTestUser() + user1.Username += "1" + user2.Username += "2" + // user1 is a local account with a virtual SFTP folder to user2 + // user2 has user1 as SFTP fs + folderName := "sftp" + user1.VirtualFolders = append(user1.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: "/vdir", + }) + user2.FsConfig.Provider = sdk.SFTPFilesystemProvider + user2.FsConfig.SFTPConfig = vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user1.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + } + f := vfs.BaseVirtualFolder{ + Name: folderName, + FsConfig: vfs.Filesystem{ + Provider: sdk.SFTPFilesystemProvider, + SFTPConfig: vfs.SFTPFsConfig{ + BaseSFTPFsConfig: sdk.BaseSFTPFsConfig{ + Endpoint: sftpServerAddr, + Username: user2.Username, + }, + Password: kms.NewPlainSecret(defaultPassword), + }, + }, + } + _, _, err := httpdtest.AddFolder(f, http.StatusCreated) + assert.NoError(t, err) + + user1, resp, err := httpdtest.AddUser(user1, http.StatusCreated) + assert.NoError(t, err, string(resp)) + user2, resp, err = httpdtest.AddUser(user2, http.StatusCreated) + assert.NoError(t, err, string(resp)) + + client := getWebDavClient(user1, true, nil) + + testDir := "tdir" + err = client.Mkdir(testDir, os.ModePerm) + assert.NoError(t, err) + + contents, err := client.ReadDir("/") + assert.NoError(t, err) + if assert.Len(t, contents, 2) { + expected := 0 + for _, info := range contents { + switch info.Name() { + case testDir, "vdir": + assert.True(t, 
info.IsDir()) + expected++ + default: + t.Errorf("unexpected file/dir %q", info.Name()) + } + } + assert.Equal(t, expected, 2) + } + + _, err = httpdtest.RemoveUser(user1, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user1.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(user2, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user2.GetHomeDir()) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) +} + +func TestNestedVirtualFolders(t *testing.T) { + u := getTestUser() + localUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + u = getTestSFTPUser() + mappedPathCrypt := filepath.Join(os.TempDir(), "crypt") + folderNameCrypt := filepath.Base(mappedPathCrypt) + vdirCryptPath := "/vdir/crypt" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + }, + VirtualPath: vdirCryptPath, + }) + mappedPath := filepath.Join(os.TempDir(), "local") + folderName := filepath.Base(mappedPath) + vdirPath := "/vdir/local" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderName, + }, + VirtualPath: vdirPath, + }) + mappedPathNested := filepath.Join(os.TempDir(), "nested") + folderNameNested := filepath.Base(mappedPathNested) + vdirNestedPath := "/vdir/crypt/nested" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + Name: folderNameNested, + }, + VirtualPath: vdirNestedPath, + QuotaFiles: -1, + QuotaSize: -1, + }) + f1 := vfs.BaseVirtualFolder{ + Name: folderNameCrypt, + FsConfig: vfs.Filesystem{ + Provider: sdk.CryptedFilesystemProvider, + CryptConfig: vfs.CryptFsConfig{ + Passphrase: kms.NewPlainSecret(defaultPassword), + }, + }, + MappedPath: mappedPathCrypt, + } + _, _, err = httpdtest.AddFolder(f1, 
http.StatusCreated) + assert.NoError(t, err) + f2 := vfs.BaseVirtualFolder{ + Name: folderName, + MappedPath: mappedPath, + } + _, _, err = httpdtest.AddFolder(f2, http.StatusCreated) + assert.NoError(t, err) + f3 := vfs.BaseVirtualFolder{ + Name: folderNameNested, + MappedPath: mappedPathNested, + } + _, _, err = httpdtest.AddFolder(f3, http.StatusCreated) + assert.NoError(t, err) + sftpUser, _, err := httpdtest.AddUser(u, http.StatusCreated) + assert.NoError(t, err) + + client := getWebDavClient(sftpUser, true, nil) + assert.NoError(t, checkBasicFunc(client)) + testFilePath := filepath.Join(homeBasePath, testFileName) + testFileSize := int64(65535) + localDownloadPath := filepath.Join(homeBasePath, testDLFileName) + err = createTestFile(testFilePath, testFileSize) + assert.NoError(t, err) + + err = uploadFileWithRawClient(testFilePath, testFileName, sftpUser.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = downloadFile(testFileName, localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join("/vdir", testFileName), sftpUser.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = downloadFile(path.Join("/vdir", testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(vdirPath, testFileName), sftpUser.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = downloadFile(path.Join(vdirPath, testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, path.Join(vdirCryptPath, testFileName), sftpUser.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = downloadFile(path.Join(vdirCryptPath, testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + err = uploadFileWithRawClient(testFilePath, 
path.Join(vdirNestedPath, testFileName), sftpUser.Username, + defaultPassword, true, testFileSize, client) + assert.NoError(t, err) + err = downloadFile(path.Join(vdirNestedPath, testFileName), localDownloadPath, testFileSize, client) + assert.NoError(t, err) + + err = os.Remove(testFilePath) + assert.NoError(t, err) + err = os.Remove(localDownloadPath) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(sftpUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveUser(localUser, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameCrypt}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderName}, http.StatusOK) + assert.NoError(t, err) + _, err = httpdtest.RemoveFolder(vfs.BaseVirtualFolder{Name: folderNameNested}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathCrypt) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) + err = os.RemoveAll(mappedPathNested) + assert.NoError(t, err) + err = os.RemoveAll(localUser.GetHomeDir()) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return len(common.Connections.GetStats("")) == 0 }, + 1*time.Second, 100*time.Millisecond) + assert.Equal(t, int32(0), common.Connections.GetTotalTransfers()) +} + +func checkBasicFunc(client *gowebdav.Client) error { + err := client.Connect() + if err != nil { + return err + } + _, err = client.ReadDir("/") + return err +} + +func checkFileSize(remoteDestPath string, expectedSize int64, client *gowebdav.Client) error { + info, err := client.Stat(remoteDestPath) + if err != nil { + return err + } + if info.Size() != expectedSize { + return fmt.Errorf("uploaded file size does not match, actual: %v, expected: %v", info.Size(), expectedSize) + } + return nil +} + +func uploadFileWithRawClient(localSourcePath string, remoteDestPath string, username, password string, + useTLS bool, expectedSize 
int64, client *gowebdav.Client, headers ...dataprovider.KeyValue, +) error { + srcFile, err := os.Open(localSourcePath) + if err != nil { + return err + } + defer srcFile.Close() + + var tlsConfig *tls.Config + rootPath := fmt.Sprintf("http://%v/", webDavServerAddr) + if useTLS { + rootPath = fmt.Sprintf("https://%v/", webDavTLSServerAddr) + tlsConfig = &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + } + req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("%v%v", rootPath, remoteDestPath), srcFile) + if err != nil { + return err + } + req.SetBasicAuth(username, password) + for _, kv := range headers { + req.Header.Set(kv.Key, kv.Value) + } + httpClient := &http.Client{Timeout: 10 * time.Second} + if tlsConfig != nil { + customTransport := http.DefaultTransport.(*http.Transport).Clone() + customTransport.TLSClientConfig = tlsConfig + httpClient.Transport = customTransport + } + defer httpClient.CloseIdleConnections() + resp, err := httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusCreated { + return fmt.Errorf("unexpected status code: %v", resp.StatusCode) + } + if expectedSize > 0 { + return checkFileSize(remoteDestPath, expectedSize, client) + } + return nil +} + +// This method is buggy. I have to find time to better investigate and eventually report the issue upstream. 
+// For now we upload using the uploadFileWithRawClient method +/*func uploadFile(localSourcePath string, remoteDestPath string, expectedSize int64, client *gowebdav.Client) error { + srcFile, err := os.Open(localSourcePath) + if err != nil { + return err + } + defer srcFile.Close() + err = client.WriteStream(remoteDestPath, srcFile, os.ModePerm) + if err != nil { + return err + } + if expectedSize > 0 { + return checkFileSize(remoteDestPath, expectedSize, client) + } + return nil +}*/ + +func downloadFile(remoteSourcePath string, localDestPath string, expectedSize int64, client *gowebdav.Client) error { + downloadDest, err := os.Create(localDestPath) + if err != nil { + return err + } + defer downloadDest.Close() + + reader, err := client.ReadStream(remoteSourcePath) + if err != nil { + return err + } + defer reader.Close() + written, err := io.Copy(downloadDest, reader) + if err != nil { + return err + } + if written != expectedSize { + return fmt.Errorf("downloaded file size does not match, actual: %v, expected: %v", written, expectedSize) + } + return nil +} + +func getTLSHTTPClient(tlsConfig *tls.Config) *http.Client { + customTransport := http.DefaultTransport.(*http.Transport).Clone() + customTransport.TLSClientConfig = tlsConfig + + return &http.Client{ + Timeout: 5 * time.Second, + Transport: customTransport, + } +} + +func getWebDavClient(user dataprovider.User, useTLS bool, tlsConfig *tls.Config) *gowebdav.Client { + rootPath := fmt.Sprintf("http://%v/", webDavServerAddr) + if useTLS { + rootPath = fmt.Sprintf("https://%v/", webDavTLSServerAddr) + if tlsConfig == nil { + tlsConfig = &tls.Config{ + ServerName: "localhost", + InsecureSkipVerify: true, // use this for tests only + MinVersion: tls.VersionTLS12, + } + } + } + pwd := defaultPassword + if user.Password != "" { + if user.Password == emptyPwdPlaceholder { + pwd = "" + } else { + pwd = user.Password + } + } + client := gowebdav.NewClient(rootPath, user.Username, pwd) + client.SetTimeout(10 * 
time.Second) + if tlsConfig != nil { + customTransport := http.DefaultTransport.(*http.Transport).Clone() + customTransport.TLSClientConfig = tlsConfig + client.SetTransport(customTransport) + } + return client +} + +func waitTCPListening(address string) { + for { + conn, err := net.Dial("tcp", address) + if err != nil { + logger.WarnToConsole("tcp server %v not listening: %v", address, err) + time.Sleep(100 * time.Millisecond) + continue + } + logger.InfoToConsole("tcp server %v now listening", address) + conn.Close() + break + } +} + +func getTestUser() dataprovider.User { + user := dataprovider.User{ + BaseUser: sdk.BaseUser{ + Username: defaultUsername, + Password: defaultPassword, + HomeDir: filepath.Join(homeBasePath, defaultUsername), + Status: 1, + ExpirationDate: 0, + }, + } + user.Permissions = make(map[string][]string) + user.Permissions["/"] = allPerms + return user +} + +func getTestSFTPUser() dataprovider.User { + u := getTestUser() + u.Username = u.Username + "_sftp" + u.FsConfig.Provider = sdk.SFTPFilesystemProvider + u.FsConfig.SFTPConfig.Endpoint = sftpServerAddr + u.FsConfig.SFTPConfig.Username = defaultUsername + u.FsConfig.SFTPConfig.Password = kms.NewPlainSecret(defaultPassword) + return u +} + +func getTestUserWithCryptFs() dataprovider.User { + user := getTestUser() + user.FsConfig.Provider = sdk.CryptedFilesystemProvider + user.FsConfig.CryptConfig.Passphrase = kms.NewPlainSecret("testPassphrase") + return user +} + +func getSftpClient(user dataprovider.User) (*ssh.Client, *sftp.Client, error) { + var sftpClient *sftp.Client + config := &ssh.ClientConfig{ + User: user.Username, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + if user.Password != "" { + config.Auth = []ssh.AuthMethod{ssh.Password(user.Password)} + } else { + config.Auth = []ssh.AuthMethod{ssh.Password(defaultPassword)} + } + + conn, err := ssh.Dial("tcp", sftpServerAddr, config) + if err != nil { + return conn, sftpClient, err + } + 
sftpClient, err = sftp.NewClient(conn) + if err != nil { + conn.Close() + } + return conn, sftpClient, err +} + +func getEncryptedFileSize(size int64) (int64, error) { + encSize, err := sio.EncryptedSize(uint64(size)) + return int64(encSize) + 33, err +} + +func getExtAuthScriptContent(user dataprovider.User, password string) []byte { + extAuthContent := []byte("#!/bin/sh\n\n") + if password != "" { + extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("if test \"$SFTPGO_AUTHD_USERNAME\" = \"%s\" -a \"$SFTPGO_AUTHD_PASSWORD\" = \"%s\"; then\n", user.Username, password))...) + } else { + extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("if test \"$SFTPGO_AUTHD_USERNAME\" = \"%s\"; then\n", user.Username))...) + } + u, _ := json.Marshal(user) + extAuthContent = append(extAuthContent, []byte(fmt.Sprintf("echo '%s'\n", string(u)))...) + extAuthContent = append(extAuthContent, []byte("else\n")...) + extAuthContent = append(extAuthContent, []byte("echo '{\"username\":\"\"}'\n")...) + extAuthContent = append(extAuthContent, []byte("fi\n")...) + return extAuthContent +} + +func getPreLoginScriptContent(user dataprovider.User, nonJSONResponse bool) []byte { + content := []byte("#!/bin/sh\n\n") + if nonJSONResponse { + content = append(content, []byte("echo 'text response'\n")...) + return content + } + if len(user.Username) > 0 { + u, _ := json.Marshal(user) + content = append(content, []byte(fmt.Sprintf("echo '%v'\n", string(u)))...) + } + return content +} + +func getExitCodeScriptContent(exitCode int) []byte { + content := []byte("#!/bin/sh\n\n") + content = append(content, []byte(fmt.Sprintf("exit %v", exitCode))...) 
+ return content +} + +func createTestFile(path string, size int64) error { + baseDir := filepath.Dir(path) + if _, err := os.Stat(baseDir); errors.Is(err, fs.ErrNotExist) { + err = os.MkdirAll(baseDir, os.ModePerm) + if err != nil { + return err + } + } + content := make([]byte, size) + _, err := rand.Read(content) + if err != nil { + return err + } + + err = os.WriteFile(path, content, os.ModePerm) + if err != nil { + return err + } + fi, err := os.Stat(path) + if err != nil { + return err + } + if fi.Size() != size { + return fmt.Errorf("unexpected size %v, expected %v", fi.Size(), size) + } + return nil +} + +func printLatestLogs(maxNumberOfLines int) { + var lines []string + f, err := os.Open(logFilePath) + if err != nil { + return + } + defer f.Close() + scanner := bufio.NewScanner(f) + for scanner.Scan() { + lines = append(lines, scanner.Text()+"\r\n") + for len(lines) > maxNumberOfLines { + lines = lines[1:] + } + } + if scanner.Err() != nil { + logger.WarnToConsole("Unable to print latest logs: %v", scanner.Err()) + return + } + for _, line := range lines { + logger.DebugToConsole("%s", line) + } +} diff --git a/logger/logger.go b/logger/logger.go deleted file mode 100644 index a385d972..00000000 --- a/logger/logger.go +++ /dev/null @@ -1,116 +0,0 @@ -// Package logger provides logging capabilities. -// It is a wrapper around zerolog for logging and lumberjack for log rotation. -// Logs are written to the specified log file. -// Logging on the console is provided to print initialization info, errors and warnings. -// The package provides a request logger to log the HTTP requests for REST API too. 
-// The request logger uses chi.middleware.RequestLogger, -// chi.middleware.LogFormatter and chi.middleware.LogEntry to build a structured -// logger using zerlog -package logger - -import ( - "fmt" - "os" - "runtime" - - "github.com/rs/zerolog" - lumberjack "gopkg.in/natefinch/lumberjack.v2" -) - -const ( - dateFormat = "2006-01-02T15:04.05.000" // YYYY-MM-DDTHH:MM.SS.ZZZ -) - -var ( - logger zerolog.Logger - consoleLogger zerolog.Logger -) - -// GetLogger get the configured logger instance -func GetLogger() *zerolog.Logger { - return &logger -} - -// InitLogger configures the logger using the given parameters -func InitLogger(logFilePath string, logMaxSize int, logMaxBackups int, logMaxAge int, logCompress bool, level zerolog.Level) { - zerolog.TimeFieldFormat = dateFormat - logger = zerolog.New(&lumberjack.Logger{ - Filename: logFilePath, - MaxSize: logMaxSize, - MaxBackups: logMaxBackups, - MaxAge: logMaxAge, - Compress: logCompress, - }).With().Timestamp().Logger().Level(level) - - consoleOutput := zerolog.ConsoleWriter{ - Out: os.Stdout, - TimeFormat: dateFormat, - NoColor: runtime.GOOS == "windows", - } - consoleLogger = zerolog.New(consoleOutput).With().Timestamp().Logger().Level(level) -} - -// Debug logs at debug level for the specified sender -func Debug(sender string, format string, v ...interface{}) { - logger.Debug().Str("sender", sender).Msg(fmt.Sprintf(format, v...)) -} - -// Info logs at info level for the specified sender -func Info(sender string, format string, v ...interface{}) { - logger.Info().Str("sender", sender).Msg(fmt.Sprintf(format, v...)) -} - -// Warn logs at warn level for the specified sender -func Warn(sender string, format string, v ...interface{}) { - logger.Warn().Str("sender", sender).Msg(fmt.Sprintf(format, v...)) -} - -// Error logs at error level for the specified sender -func Error(sender string, format string, v ...interface{}) { - logger.Error().Str("sender", sender).Msg(fmt.Sprintf(format, v...)) -} - -// DebugToConsole 
logs at debug level to stdout -func DebugToConsole(format string, v ...interface{}) { - consoleLogger.Debug().Msg(fmt.Sprintf(format, v...)) -} - -// InfoToConsole logs at info level to stdout -func InfoToConsole(format string, v ...interface{}) { - consoleLogger.Info().Msg(fmt.Sprintf(format, v...)) -} - -// WarnToConsole logs at info level to stdout -func WarnToConsole(format string, v ...interface{}) { - consoleLogger.Warn().Msg(fmt.Sprintf(format, v...)) -} - -// ErrorToConsole logs at error level to stdout -func ErrorToConsole(format string, v ...interface{}) { - consoleLogger.Error().Msg(fmt.Sprintf(format, v...)) -} - -// TransferLog logs an SFTP/SCP upload or download -func TransferLog(operation string, path string, elapsed int64, size int64, user string, connectionID string, protocol string) { - logger.Info(). - Str("sender", operation). - Int64("elapsed_ms", elapsed). - Int64("size_bytes", size). - Str("username", user). - Str("file_path", path). - Str("connection_id", connectionID). - Str("protocol", protocol). - Msg("") -} - -// CommandLog logs an SFTP/SCP command -func CommandLog(command string, path string, target string, user string, connectionID string, protocol string) { - logger.Info(). - Str("sender", command). - Str("username", user). - Str("file_path", path). - Str("target_path", target). - Str("connection_id", connectionID). - Str("protocol", protocol). - Msg("") -} diff --git a/main.go b/main.go index 1604364e..1d9cf48e 100644 --- a/main.go +++ b/main.go @@ -1,13 +1,26 @@ -// Full featured and highly configurable SFTP server. -// For more details about features, installation, configuration and usage please refer to the README inside the source tree: -// https://github.com/drakkan/sftpgo/blob/master/README.md +// Copyright (C) 2019 Nicola Murino +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. 
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+// Fully featured and highly configurable SFTP server with optional
+// FTP/S and WebDAV support.
+// For more details about features, installation, configuration and usage
+// please refer to the README inside the source tree:
+// https://github.com/drakkan/sftpgo/blob/main/README.md
 package main // import "github.com/drakkan/sftpgo"
 
 import (
-	"github.com/drakkan/sftpgo/cmd"
-	_ "github.com/go-sql-driver/mysql"
-	_ "github.com/lib/pq"
-	_ "github.com/mattn/go-sqlite3"
+	"github.com/drakkan/sftpgo/v2/internal/cmd"
 )
 
 func main() {
diff --git a/openapi/httpfs.yaml b/openapi/httpfs.yaml
new file mode 100644
index 00000000..99405753
--- /dev/null
+++ b/openapi/httpfs.yaml
@@ -0,0 +1,613 @@
+openapi: 3.0.3
+tags:
+  - name: fs
+info:
+  title: SFTPGo HTTPFs
+  description: |
+    SFTPGo can use custom storage backend implementations compliant with the API defined here.
+    HTTPFs is a work in progress and makes no API stability promises.
+ version: 0.1.0 + license: + name: AGPL-3.0-only + url: 'https://www.gnu.org/licenses/agpl-3.0.en.html' +servers: +- url: /v1 +security: +- ApiKeyAuth: [] +- BasicAuth: [] +paths: + /stat/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + get: + tags: + - fs + summary: Describes the named object + operationId: stat + responses: + 200: + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/FileInfo' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /open/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + - name: offset + in: query + description: 'offset, in bytes, from the start. If not specified 0 must be assumed' + required: false + schema: + type: integer + format: int64 + get: + tags: + - fs + summary: Opens the named file for reading + operationId: open + responses: + '200': + description: successful operation + content: + '*/*': + schema: + type: string + format: binary + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /create/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + - name: flags + in: query + description: 'flags to use for opening the file, if omitted O_RDWR|O_CREATE|O_TRUNC must be assumed. 
Supported flags: https://pkg.go.dev/os#pkg-constants' + required: false + schema: + type: integer + format: int32 + - name: checks + in: query + description: 'If set to `1`, the parent directory must exist before creating the file' + required: false + schema: + type: integer + format: int32 + post: + tags: + - fs + summary: Creates or opens the named file for writing + operationId: create + requestBody: + content: + '*/*': + schema: + type: string + format: binary + required: true + responses: + 201: + $ref: '#/components/responses/OKResponse' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /rename/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + - name: target + in: query + description: target name + required: true + schema: + type: string + patch: + tags: + - fs + summary: Renames (moves) source to target + operationId: rename + responses: + 200: + $ref: '#/components/responses/OKResponse' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /remove/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + delete: + tags: + - fs + summary: Removes the named file or (empty) directory. 
+ operationId: delete + responses: + 200: + $ref: '#/components/responses/OKResponse' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /mkdir/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + post: + tags: + - fs + summary: Creates a new directory with the specified name + operationId: mkdir + responses: + 200: + $ref: '#/components/responses/OKResponse' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /chmod/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + - name: mode + in: query + required: true + schema: + type: integer + patch: + tags: + - fs + summary: Changes the mode of the named file + operationId: chmod + responses: + 200: + $ref: '#/components/responses/OKResponse' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /chtimes/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + - name: access_time + in: query + required: true + schema: + type: string + format: date-time + - name: modification_time + in: query + required: true + 
schema: + type: string + format: date-time + patch: + tags: + - fs + summary: Changes the access and modification time of the named file + operationId: chtimes + responses: + 200: + $ref: '#/components/responses/OKResponse' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /truncate/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + - name: size + in: query + required: true + description: 'new file size in bytes' + schema: + type: integer + format: int64 + patch: + tags: + - fs + summary: Changes the size of the named file + operationId: truncate + responses: + 200: + $ref: '#/components/responses/OKResponse' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /readdir/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + get: + tags: + - fs + summary: Reads the named directory and returns the contents + operationId: readdir + responses: + 200: + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/FileInfo' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: 
'#/components/responses/DefaultResponse' + /dirsize/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + get: + tags: + - fs + summary: Returns the number of files and the size for the named directory including any sub-directory + operationId: dirsize + responses: + 200: + description: successful operation + content: + application/json: + schema: + type: object + properties: + files: + type: integer + description: 'Total number of files' + size: + type: integer + format: int64 + description: 'Total size of files' + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /mimetype/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + get: + tags: + - fs + summary: Returns the mime type for the named file + operationId: mimetype + responses: + 200: + description: successful operation + content: + application/json: + schema: + type: object + properties: + mime: + type: string + 401: + $ref: '#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' + /statvfs/{name}: + parameters: + - name: name + in: path + description: object name + required: true + schema: + type: string + get: + tags: + - fs + summary: Returns the VFS stats for the specified path + operationId: statvfs + responses: + 200: + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/StatVFS' + 401: + $ref: 
'#/components/responses/Unauthorized' + 403: + $ref: '#/components/responses/Forbidden' + 404: + $ref: '#/components/responses/NotFound' + 500: + $ref: '#/components/responses/InternalServerError' + 501: + $ref: '#/components/responses/NotImplemented' + default: + $ref: '#/components/responses/DefaultResponse' +components: + responses: + OKResponse: + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + BadRequest: + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + Unauthorized: + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + Forbidden: + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + NotFound: + description: Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + NotImplemented: + description: Not Implemented + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + Conflict: + description: Conflict + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + RequestEntityTooLarge: + description: Request Entity Too Large, max allowed size exceeded + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + InternalServerError: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + DefaultResponse: + description: Unexpected Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + schemas: + ApiResponse: + type: object + properties: + message: + type: string + description: 'message, can be empty' + error: + type: string + description: error description if any + FileInfo: + type: object + properties: + name: + type: string + description: base name of the file + size: + type: integer + format: 
int64 + description: length in bytes for regular files; system-dependent for others + mode: + type: integer + description: | + File mode and permission bits. More details here: https://golang.org/pkg/io/fs/#FileMode. + Let's see some examples: + - for a directory mode&2147483648 != 0 + - for a symlink mode&134217728 != 0 + - for a regular file mode&2401763328 == 0 + last_modified: + type: string + format: date-time + StatVFS: + type: object + properties: + bsize: + type: integer + description: file system block size + frsize: + type: integer + description: fundamental fs block size + blocks: + type: integer + description: number of blocks + bfree: + type: integer + description: free blocks in file system + bavail: + type: integer + description: free blocks for non-root + files: + type: integer + description: total file inodes + ffree: + type: integer + description: free file inodes + favail: + type: integer + description: free file inodes for non-root + fsid: + type: integer + description: file system id + flag: + type: integer + description: bit mask of f_flag values + namemax: + type: integer + description: maximum filename length + securitySchemes: + BasicAuth: + type: http + scheme: basic + ApiKeyAuth: + type: apiKey + in: header + name: X-API-KEY \ No newline at end of file diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml new file mode 100644 index 00000000..32dc3dc7 --- /dev/null +++ b/openapi/openapi.yaml @@ -0,0 +1,7409 @@ +openapi: 3.0.3 +tags: + - name: healthcheck + - name: token + - name: maintenance + - name: admins + - name: API keys + - name: connections + - name: IP Lists + - name: defender + - name: quota + - name: folders + - name: groups + - name: roles + - name: users + - name: data retention + - name: events + - name: metadata + - name: user APIs + - name: public shares + - name: event manager +info: + title: SFTPGo + description: | + SFTPGo allows you to securely share your files over SFTP and optionally over HTTP/S, FTP/S and WebDAV 
as well. + Several storage backends are supported and they are configurable per-user, so you can serve a local directory for a user and an S3 bucket (or part of it) for another one. + SFTPGo also supports virtual folders, a virtual folder can use any of the supported storage backends. So you can have, for example, a user with the S3 backend mapping a Google Cloud Storage bucket (or part of it) on a specified path and an encrypted local filesystem on another one. + Virtual folders can be private or shared among multiple users, for shared virtual folders you can define different quota limits for each user. + SFTPGo supports groups to simplify the administration of multiple accounts by letting you assign settings once to a group, instead of multiple times to each individual user. + The SFTPGo WebClient allows end users to change their credentials, browse and manage their files in the browser and setup two-factor authentication which works with Authy, Google Authenticator and other compatible apps. + From the WebClient each authorized user can also create HTTP/S links to externally share files and folders securely, by setting limits to the number of downloads/uploads, protecting the share with a password, limiting access by source IP address, setting an automatic expiration date. 
+ version: v2.7.0 + contact: + name: API support + url: 'https://github.com/drakkan/sftpgo' + license: + name: AGPL-3.0-only + url: 'https://www.gnu.org/licenses/agpl-3.0.en.html' +servers: + - url: /api/v2 +security: + - BearerAuth: [] + - APIKeyAuth: [] +paths: + /healthz: + get: + security: [] + servers: + - url: / + tags: + - healthcheck + summary: health check + description: This endpoint can be used to check if the application is running and responding to requests + operationId: healthz + responses: + '200': + description: successful operation + content: + text/plain; charset=utf-8: + schema: + type: string + example: ok + /shares/{id}: + parameters: + - name: id + in: path + description: the share id + required: true + schema: + type: string + get: + security: + - BasicAuth: [] + tags: + - public shares + summary: Download shared files and folders as a single zip file + description: A zip file, containing the shared files and folders, will be generated on the fly and returned as response body. Only folders and regular files will be included in the zip. 
The share must be defined with the read scope and the associated user must have list and download permissions + operationId: get_share + parameters: + - in: query + name: compress + schema: + type: boolean + default: true + required: false + responses: + '200': + description: successful operation + content: + '*/*': + schema: + type: string + format: binary + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + security: + - BasicAuth: [] + tags: + - public shares + summary: Upload one or more files to the shared path + description: The share must be defined with the write scope and the associated user must have the upload permission + operationId: upload_to_share + requestBody: + content: + multipart/form-data: + schema: + type: object + properties: + filenames: + type: array + items: + type: string + format: binary + minItems: 1 + uniqueItems: true + required: true + responses: + '201': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '413': + $ref: '#/components/responses/RequestEntityTooLarge' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /shares/{id}/files: + parameters: + - name: id + in: path + description: the share id + required: true + schema: + type: string + get: + security: + - BasicAuth: [] + tags: + - public shares + summary: Download a single file + description: Returns the file contents as response body. 
The share must have exactly one path defined and it must be a directory for this to work + operationId: download_share_file + parameters: + - in: query + name: path + required: true + description: Path to the file to download. It must be URL encoded, for example the path "my dir/àdir/file.txt" must be sent as "my%20dir%2F%C3%A0dir%2Ffile.txt" + schema: + type: string + - in: query + name: inline + required: false + description: 'If set, the response will not have the Content-Disposition header set to `attachment`' + schema: + type: string + responses: + '200': + description: successful operation + content: + '*/*': + schema: + type: string + format: binary + '206': + description: successful operation + content: + '*/*': + schema: + type: string + format: binary + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /shares/{id}/dirs: + parameters: + - name: id + in: path + description: the share id + required: true + schema: + type: string + get: + security: + - BasicAuth: [] + tags: + - public shares + summary: Read directory contents + description: Returns the contents of the specified directory for the specified share. The share must have exactly one path defined and it must be a directory for this to work + operationId: get_share_dir_contents + parameters: + - in: query + name: path + description: Path to the folder to read. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir". If empty or missing the user's start directory is assumed. 
If relative, the user's start directory is used as the base + schema: + type: string + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/DirEntry' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /shares/{id}/{fileName}: + parameters: + - name: id + in: path + description: the share id + required: true + schema: + type: string + - name: fileName + in: path + description: the name of the new file. It must be path encoded. Sub directories are not accepted + required: true + schema: + type: string + - name: X-SFTPGO-MTIME + in: header + schema: + type: integer + description: File modification time as unix timestamp in milliseconds + post: + security: + - BasicAuth: [] + tags: + - public shares + summary: Upload a single file to the shared path + description: The share must be defined with the write scope and the associated user must have the upload/overwrite permissions + operationId: upload_single_to_share + requestBody: + content: + application/*: + schema: + type: string + format: binary + text/*: + schema: + type: string + format: binary + image/*: + schema: + type: string + format: binary + audio/*: + schema: + type: string + format: binary + video/*: + schema: + type: string + format: binary + required: true + responses: + '201': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '413': + $ref: '#/components/responses/RequestEntityTooLarge' + '500': + $ref: '#/components/responses/InternalServerError' + 
default: + $ref: '#/components/responses/DefaultResponse' + /token: + get: + security: + - BasicAuth: [] + tags: + - token + summary: Get a new admin access token + description: Returns an access token and its expiration + operationId: get_token + parameters: + - in: header + name: X-SFTPGO-OTP + schema: + type: string + required: false + description: 'If you have 2FA configured for the admin attempting to log in you need to set the authentication code using this header parameter' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/Token' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /logout: + get: + security: + - BearerAuth: [] + tags: + - token + summary: Invalidate an admin access token + description: Allows to invalidate an admin token before its expiration + operationId: logout + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/token: + get: + security: + - BasicAuth: [] + tags: + - token + summary: Get a new user access token + description: Returns an access token and its expiration + operationId: get_user_token + parameters: + - in: header + name: X-SFTPGO-OTP + schema: + type: string + required: false + description: 'If you have 2FA configured, for the HTTP protocol, for the user attempting to log in you need to set the authentication code using this header parameter' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: 
'#/components/schemas/Token' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/logout: + get: + security: + - BearerAuth: [] + tags: + - token + summary: Invalidate a user access token + description: Allows to invalidate a client token before its expiration + operationId: client_logout + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /version: + get: + tags: + - maintenance + summary: Get version details + description: 'Returns version details such as the version number, build date, commit hash and enabled features' + operationId: get_version + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/VersionInfo' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /admin/changepwd: + put: + security: + - BearerAuth: [] + tags: + - admins + summary: Change admin password + description: Changes the password for the logged in admin + operationId: change_admin_password + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PwdChange' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: 
'#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /admin/profile: + get: + security: + - BearerAuth: [] + tags: + - admins + summary: Get admin profile + description: 'Returns the profile for the logged in admin' + operationId: get_admin_profile + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/AdminProfile' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + security: + - BearerAuth: [] + tags: + - admins + summary: Update admin profile + description: 'Allows to update the profile for the logged in admin' + operationId: update_admin_profile + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/AdminProfile' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /admin/2fa/recoverycodes: + get: + security: + - BearerAuth: [] + tags: + - admins + summary: Get recovery codes + description: 'Returns the recovery codes for the logged in admin. Recovery codes can be used if the admin loses access to their second factor auth device. 
Recovery codes are returned unencrypted' + operationId: get_admin_recovery_codes + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/RecoveryCode' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + security: + - BearerAuth: [] + tags: + - admins + summary: Generate recovery codes + description: 'Generates new recovery codes for the logged in admin. Generating new recovery codes you automatically invalidate old ones' + operationId: generate_admin_recovery_codes + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + type: string + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /admin/totp/configs: + get: + security: + - BearerAuth: [] + tags: + - admins + summary: Get available TOTP configuration + description: Returns the available TOTP configurations for the logged in admin + operationId: get_admin_totp_configs + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/TOTPConfig' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /admin/totp/generate: + post: + security: + - BearerAuth: [] + tags: + - admins + summary: Generate a new TOTP secret + description: 'Generates a new TOTP secret, including the QR code as 
png, using the specified configuration for the logged in admin' + operationId: generate_admin_totp_secret + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + config_name: + type: string + description: 'name of the configuration to use to generate the secret' + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: object + properties: + config_name: + type: string + issuer: + type: string + secret: + type: string + url: + type: string + qr_code: + type: string + format: byte + description: 'QR code png encoded as BASE64' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /admin/totp/validate: + post: + security: + - BearerAuth: [] + tags: + - admins + summary: Validate a one time authentication code + description: 'Checks if the given authentication code can be validated using the specified secret and config name' + operationId: validate_admin_totp_secret + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + config_name: + type: string + description: 'name of the configuration to use to validate the passcode' + passcode: + type: string + description: 'passcode to validate' + secret: + type: string + description: 'secret to use to validate the passcode' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Passcode successfully validated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: 
'#/components/responses/DefaultResponse' + /admin/totp/save: + post: + security: + - BearerAuth: [] + tags: + - admins + summary: Save a TOTP config + description: 'Saves the specified TOTP config for the logged in admin' + operationId: save_admin_totp_config + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/AdminTOTPConfig' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: TOTP configuration saved + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /connections: + get: + tags: + - connections + summary: Get connections details + description: Returns the active users and info about their current uploads/downloads + operationId: get_connections + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ConnectionStatus' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/connections/{connectionID}': + delete: + tags: + - connections + summary: Close connection + description: Terminates an active connection + operationId: close_connection + parameters: + - name: connectionID + in: path + description: ID of the connection to close + required: true + schema: + type: string + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Connection closed + '401': + $ref: '#/components/responses/Unauthorized' + 
'403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /iplists/{type}: + parameters: + - name: type + in: path + description: IP list type + required: true + schema: + $ref: '#/components/schemas/IPListType' + get: + tags: + - IP Lists + summary: Get IP list entries + description: Returns an array with one or more IP list entries + operationId: get_ip_list_entries + parameters: + - in: query + name: filter + schema: + type: string + description: restrict results to ipornet matching or starting with this filter + - in: query + name: from + schema: + type: string + description: ipornet to start from + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering entries by ipornet field. 
Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/IPListEntry' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - IP Lists + summary: Add a new IP list entry + description: Add an IP address or a CIDR network to a supported list + operationId: add_ip_list_entry + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/IPListEntry' + responses: + '201': + description: successful operation + headers: + Location: + schema: + type: string + description: 'URI of the newly created object' + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Entry added + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /iplists/{type}/{ipornet}: + parameters: + - name: type + in: path + description: IP list type + required: true + schema: + $ref: '#/components/schemas/IPListType' + - name: ipornet + in: path + required: true + schema: + type: string + get: + tags: + - IP Lists + summary: Find entry by ipornet + description: Returns the entry with the given ipornet if it exists. 
+ operationId: get_ip_list_by_ipornet + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/IPListEntry' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + tags: + - IP Lists + summary: Update IP list entry + description: Updates an existing IP list entry + operationId: update_ip_list_entry + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/IPListEntry' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Entry updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - IP Lists + summary: Delete IP list entry + description: Deletes an existing IP list entry + operationId: delete_ip_list_entry + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Entry deleted + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /defender/hosts: + get: + tags: + - defender + 
summary: Get hosts + description: Returns hosts that are banned or for which some violations have been detected + operationId: get_defender_hosts + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/DefenderEntry' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /defender/hosts/{id}: + parameters: + - name: id + in: path + description: host id + required: true + schema: + type: string + get: + tags: + - defender + summary: Get host by id + description: Returns the host with the given id, if it exists + operationId: get_defender_host_by_id + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/DefenderEntry' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - defender + summary: Removes a host from the defender lists + description: Unbans the specified host or clears its violations + operationId: delete_defender_host_by_id + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /retention/users/checks: + get: + tags: + - data retention + summary: Get retention 
checks + description: Returns the active retention checks + operationId: get_users_retention_checks + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/RetentionCheck' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /quotas/users/scans: + get: + tags: + - quota + summary: Get active user quota scans + description: Returns the active user quota scans + operationId: get_users_quota_scans + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/QuotaScan' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /quotas/users/{username}/scan: + parameters: + - name: username + in: path + description: the username + required: true + schema: + type: string + post: + tags: + - quota + summary: Start a user quota scan + description: Starts a new quota scan for the given user. 
A quota scan updates the number of files and their total size for the specified user and the virtual folders, if any, included in his quota + operationId: start_user_quota_scan + responses: + '202': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Scan started + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '409': + $ref: '#/components/responses/Conflict' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /quotas/users/{username}/usage: + parameters: + - name: username + in: path + description: the username + required: true + schema: + type: string + - in: query + name: mode + required: false + description: the update mode specifies if the given quota usage values should be added or replace the current ones + schema: + type: string + enum: + - add + - reset + description: | + Update type: + * `add` - add the specified quota limits to the current used ones + * `reset` - reset the values to the specified ones. 
This is the default + example: reset + put: + tags: + - quota + summary: Update disk quota usage limits + description: Sets the current used quota limits for the given user + operationId: user_quota_update_usage + requestBody: + required: true + description: 'If used_quota_size and used_quota_files are missing they will default to 0, this means that if mode is "add" the current value, for the missing field, will remain unchanged, if mode is "reset" the missing field is set to 0' + content: + application/json: + schema: + $ref: '#/components/schemas/QuotaUsage' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Quota updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '409': + $ref: '#/components/responses/Conflict' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /quotas/users/{username}/transfer-usage: + parameters: + - name: username + in: path + description: the username + required: true + schema: + type: string + - in: query + name: mode + required: false + description: the update mode specifies if the given quota usage values should be added or replace the current ones + schema: + type: string + enum: + - add + - reset + description: | + Update type: + * `add` - add the specified quota limits to the current used ones + * `reset` - reset the values to the specified ones. 
This is the default + example: reset + put: + tags: + - quota + summary: Update transfer quota usage limits + description: Sets the current used transfer quota limits for the given user + operationId: user_transfer_quota_update_usage + requestBody: + required: true + description: 'If used_upload_data_transfer and used_download_data_transfer are missing they will default to 0, this means that if mode is "add" the current value, for the missing field, will remain unchanged, if mode is "reset" the missing field is set to 0' + content: + application/json: + schema: + $ref: '#/components/schemas/TransferQuotaUsage' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Quota updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '409': + $ref: '#/components/responses/Conflict' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /quotas/folders/scans: + get: + tags: + - quota + summary: Get active folder quota scans + description: Returns the active folder quota scans + operationId: get_folders_quota_scans + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/FolderQuotaScan' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /quotas/folders/{name}/scan: + parameters: + - name: name + in: path + description: folder name + required: true + schema: + type: string + post: + tags: + - quota + summary: Start a folder quota scan + description: Starts a new quota 
scan for the given folder. A quota scan updates the number of files and their total size for the specified folder + operationId: start_folder_quota_scan + responses: + '202': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Scan started + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '409': + $ref: '#/components/responses/Conflict' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /quotas/folders/{name}/usage: + parameters: + - name: name + in: path + description: folder name + required: true + schema: + type: string + - in: query + name: mode + required: false + description: the update mode specifies if the given quota usage values should be added or replace the current ones + schema: + type: string + enum: + - add + - reset + description: | + Update type: + * `add` - add the specified quota limits to the current used ones + * `reset` - reset the values to the specified ones. This is the default + example: reset + put: + tags: + - quota + summary: Update folder quota usage limits + description: Sets the current used quota limits for the given folder + operationId: folder_quota_update_usage + parameters: + - in: query + name: mode + required: false + description: the update mode specifies if the given quota usage values should be added or replace the current ones + schema: + type: string + enum: + - add + - reset + description: | + Update type: + * `add` - add the specified quota limits to the current used ones + * `reset` - reset the values to the specified ones. 
This is the default + example: reset + requestBody: + required: true + description: 'If used_quota_size and used_quota_files are missing they will default to 0, this means that if mode is "add" the current value, for the missing field, will remain unchanged, if mode is "reset" the missing field is set to 0' + content: + application/json: + schema: + $ref: '#/components/schemas/QuotaUsage' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Quota updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '409': + $ref: '#/components/responses/Conflict' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /folders: + get: + tags: + - folders + summary: Get folders + description: Returns an array with one or more folders + operationId: get_folders + parameters: + - in: query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering folders by name. 
Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/BaseVirtualFolder' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - folders + summary: Add folder + operationId: add_folder + description: Adds a new folder. A quota scan is required to update the used files/size + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/BaseVirtualFolder' + responses: + '201': + description: successful operation + headers: + Location: + schema: + type: string + description: 'URI of the newly created object' + content: + application/json: + schema: + $ref: '#/components/schemas/BaseVirtualFolder' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/folders/{name}': + parameters: + - name: name + in: path + description: folder name + required: true + schema: + type: string + get: + tags: + - folders + summary: Find folders by name + description: Returns the folder with the given name if it exists. 
+ operationId: get_folder_by_name + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/BaseVirtualFolder' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + tags: + - folders + summary: Update folder + description: Updates an existing folder + operationId: update_folder + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/BaseVirtualFolder' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Folder updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - folders + summary: Delete folder + description: Deletes an existing folder + operationId: delete_folder + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Folder deleted + '400': + $ref: 
'#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /groups: + get: + tags: + - groups + summary: Get groups + description: Returns an array with one or more groups + operationId: get_groups + parameters: + - in: query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering groups by name. Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Group' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - groups + summary: Add group + operationId: add_group + description: Adds a new group + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
+ requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Group' + responses: + '201': + description: successful operation + headers: + Location: + schema: + type: string + description: 'URI of the newly created object' + content: + application/json: + schema: + $ref: '#/components/schemas/Group' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/groups/{name}': + parameters: + - name: name + in: path + description: group name + required: true + schema: + type: string + get: + tags: + - groups + summary: Find groups by name + description: Returns the group with the given name if it exists. + operationId: get_group_by_name + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
+ responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/Group' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + tags: + - groups + summary: Update group + description: Updates an existing group + operationId: update_group + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Group' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Group updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - groups + summary: Delete group + description: Deletes an existing group + operationId: delete_group + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Group deleted + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /roles: + get: + tags: + - roles + summary: Get roles + description: Returns an array with one or more roles + operationId: get_roles + parameters: + - in: 
query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering roles by name. Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Role' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - roles + summary: Add role + operationId: add_role + description: Adds a new role + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Role' + responses: + '201': + description: successful operation + headers: + Location: + schema: + type: string + description: 'URI of the newly created object' + content: + application/json: + schema: + $ref: '#/components/schemas/Role' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/roles/{name}': + parameters: + - name: name + in: path + description: role name + required: true + schema: + type: string + get: + tags: + - roles + summary: Find roles by name + description: Returns the role with the given name if it exists. 
+ operationId: get_role_by_name + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/Role' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + tags: + - roles + summary: Update role + description: Updates an existing role + operationId: update_role + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Role' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Role updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - roles + summary: Delete role + description: Deletes an existing role + operationId: delete_role + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Role deleted + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /eventactions: + get: + tags: + - event manager + summary: Get event actions + description: Returns an array with one or more event 
actions + operationId: get_event_actions + parameters: + - in: query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering actions by name. Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/BaseEventAction' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - event manager + summary: Add event action + operationId: add_event_action + description: Adds a new event action + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
+ requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/BaseEventAction' + responses: + '201': + description: successful operation + headers: + Location: + schema: + type: string + description: 'URI of the newly created object' + content: + application/json: + schema: + $ref: '#/components/schemas/BaseEventAction' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/eventactions/{name}': + parameters: + - name: name + in: path + description: action name + required: true + schema: + type: string + get: + tags: + - event manager + summary: Find event actions by name + description: Returns the event action with the given name if it exists. + operationId: get_event_action_by_name + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
+ responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/BaseEventAction' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + tags: + - event manager + summary: Update event action + description: Updates an existing event action + operationId: update_event_action + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/BaseEventAction' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Event action updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - event manager + summary: Delete event action + description: Deletes an existing event action + operationId: delete_event_action + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Event action deleted + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /eventrules: + get: + tags: + - event manager + summary: Get 
event rules + description: Returns an array with one or more event rules + operationId: get_event_rules + parameters: + - in: query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering rules by name. Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/EventRule' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - event manager + summary: Add event rule + operationId: add_event_rule + description: Adds a new event rule + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
+ requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/EventRuleMinimal' + responses: + '201': + description: successful operation + headers: + Location: + schema: + type: string + description: 'URI of the newly created object' + content: + application/json: + schema: + $ref: '#/components/schemas/EventRule' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/eventrules/{name}': + parameters: + - name: name + in: path + description: rule name + required: true + schema: + type: string + get: + tags: + - event manager + summary: Find event rules by name + description: Returns the event rule with the given name if it exists. + operationId: get_event_rule_by_name + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
+ responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/EventRule' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + tags: + - event manager + summary: Update event rule + description: Updates an existing event rule + operationId: update_event_rule + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/EventRuleMinimal' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Event rules updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - event manager + summary: Delete event rule + description: Deletes an existing event rule + operationId: delete_event_rule + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Event rules deleted + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/eventrules/run/{name}': + parameters: + - name: name + in: path + description: 
on-demand rule name + required: true + schema: + type: string + post: + tags: + - event manager + summary: Run an on-demand event rule + description: The rule's actions will run in background. SFTPGo will not monitor any concurrency and such. If you want to be notified at the end of the execution please add an appropriate action + operationId: run_event_rule + responses: + '202': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Event rule started + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /events/fs: + get: + tags: + - events + summary: Get filesystem events + description: 'Returns an array with one or more filesystem events applying the specified filters. This API is only available if you configure an "eventsearcher" plugin' + operationId: get_fs_events + parameters: + - in: query + name: start_timestamp + schema: + type: integer + format: int64 + minimum: 0 + default: 0 + required: false + description: 'the event timestamp, unix timestamp in nanoseconds, must be greater than or equal to the specified one. 0 or missing means omit this filter' + - in: query + name: end_timestamp + schema: + type: integer + format: int64 + minimum: 0 + default: 0 + required: false + description: 'the event timestamp, unix timestamp in nanoseconds, must be less than or equal to the specified one. 0 or missing means omit this filter' + - in: query + name: actions + schema: + type: array + items: + $ref: '#/components/schemas/FsEventAction' + description: 'the event action must be included among those specified. Empty or missing means omit this filter. 
Actions must be specified comma separated' + explode: false + required: false + - in: query + name: username + schema: + type: string + description: 'the event username must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: ip + schema: + type: string + description: 'the event IP must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: ssh_cmd + schema: + type: string + description: 'the event SSH command must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: fs_provider + schema: + $ref: '#/components/schemas/FsProviders' + description: 'the event filesystem provider must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: bucket + schema: + type: string + description: 'the bucket must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: endpoint + schema: + type: string + description: 'the endpoint must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: protocols + schema: + type: array + items: + $ref: '#/components/schemas/EventProtocols' + description: 'the event protocol must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' + explode: false + required: false + - in: query + name: statuses + schema: + type: array + items: + $ref: '#/components/schemas/FsEventStatus' + description: 'the event status must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' + explode: false + required: false + - in: query + name: instance_ids + schema: + type: array + items: + type: string + description: 'the event instance id must be included among those specified. 
Empty or missing means omit this filter. Values must be specified comma separated' + explode: false + required: false + - in: query + name: from_id + schema: + type: string + description: 'the event id to start from. This is useful for cursor based pagination. Empty or missing means omit this filter.' + required: false + - in: query + name: role + schema: + type: string + description: 'User role. Empty or missing means omit this filter. Ignored if the admin has a role' + required: false + - in: query + name: csv_export + schema: + type: boolean + default: false + required: false + description: 'If enabled, events are exported as a CSV file' + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 1000 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 1000, default is 100' + - in: query + name: order + required: false + description: Ordering events by timestamp. Default DESC + schema: + type: string + enum: + - ASC + - DESC + example: DESC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/FsEvent' + text/csv: + schema: + type: string + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /events/provider: + get: + tags: + - events + summary: Get provider events + description: 'Returns an array with one or more provider events applying the specified filters. 
This API is only available if you configure an "eventsearcher" plugin' + operationId: get_provider_events + parameters: + - in: query + name: start_timestamp + schema: + type: integer + format: int64 + minimum: 0 + default: 0 + required: false + description: 'the event timestamp, unix timestamp in nanoseconds, must be greater than or equal to the specified one. 0 or missing means omit this filter' + - in: query + name: end_timestamp + schema: + type: integer + format: int64 + minimum: 0 + default: 0 + required: false + description: 'the event timestamp, unix timestamp in nanoseconds, must be less than or equal to the specified one. 0 or missing means omit this filter' + - in: query + name: actions + schema: + type: array + items: + $ref: '#/components/schemas/ProviderEventAction' + description: 'the event action must be included among those specified. Empty or missing means omit this filter. Actions must be specified comma separated' + explode: false + required: false + - in: query + name: username + schema: + type: string + description: 'the event username must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: ip + schema: + type: string + description: 'the event IP must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: object_name + schema: + type: string + description: 'the event object name must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: object_types + schema: + type: array + items: + $ref: '#/components/schemas/ProviderEventObjectType' + description: 'the event object type must be included among those specified. Empty or missing means omit this filter. 
Values must be specified comma separated' + explode: false + required: false + - in: query + name: instance_ids + schema: + type: array + items: + type: string + description: 'the event instance id must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' + explode: false + required: false + - in: query + name: from_id + schema: + type: string + description: 'the event id to start from. This is useful for cursor based pagination. Empty or missing means omit this filter.' + required: false + - in: query + name: role + schema: + type: string + description: 'Admin role. Empty or missing means omit this filter. Ignored if the admin has a role' + required: false + - in: query + name: csv_export + schema: + type: boolean + default: false + required: false + description: 'If enabled, events are exported as a CSV file' + - in: query + name: omit_object_data + schema: + type: boolean + default: false + required: false + description: 'If enabled, returned events will not contain the `object_data` field' + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 1000 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 1000, default is 100' + - in: query + name: order + required: false + description: Ordering events by timestamp. 
Default DESC + schema: + type: string + enum: + - ASC + - DESC + example: DESC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProviderEvent' + text/csv: + schema: + type: string + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /events/logs: + get: + tags: + - events + summary: Get log events + description: 'Returns an array with one or more log events applying the specified filters. This API is only available if you configure an "eventsearcher" plugin' + operationId: get_log_events + parameters: + - in: query + name: start_timestamp + schema: + type: integer + format: int64 + minimum: 0 + default: 0 + required: false + description: 'the event timestamp, unix timestamp in nanoseconds, must be greater than or equal to the specified one. 0 or missing means omit this filter' + - in: query + name: end_timestamp + schema: + type: integer + format: int64 + minimum: 0 + default: 0 + required: false + description: 'the event timestamp, unix timestamp in nanoseconds, must be less than or equal to the specified one. 0 or missing means omit this filter' + - in: query + name: events + schema: + type: array + items: + $ref: '#/components/schemas/LogEventType' + description: 'the log events must be included among those specified. Empty or missing means omit this filter. Events must be specified comma separated' + explode: false + required: false + - in: query + name: username + schema: + type: string + description: 'the event username must be the same as the one specified. Empty or missing means omit this filter' + required: false + - in: query + name: ip + schema: + type: string + description: 'the event IP must be the same as the one specified. 
Empty or missing means omit this filter' + required: false + - in: query + name: protocols + schema: + type: array + items: + $ref: '#/components/schemas/EventProtocols' + description: 'the event protocol must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' + explode: false + required: false + - in: query + name: instance_ids + schema: + type: array + items: + type: string + description: 'the event instance id must be included among those specified. Empty or missing means omit this filter. Values must be specified comma separated' + explode: false + required: false + - in: query + name: from_id + schema: + type: string + description: 'the event id to start from. This is useful for cursor based pagination. Empty or missing means omit this filter.' + required: false + - in: query + name: role + schema: + type: string + description: 'User role. Empty or missing means omit this filter. Ignored if the admin has a role' + required: false + - in: query + name: csv_export + schema: + type: boolean + default: false + required: false + description: 'If enabled, events are exported as a CSV file' + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 1000 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 1000, default is 100' + - in: query + name: order + required: false + description: Ordering events by timestamp. 
Default DESC + schema: + type: string + enum: + - ASC + - DESC + example: DESC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/LogEvent' + text/csv: + schema: + type: string + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /apikeys: + get: + security: + - BearerAuth: [] + tags: + - API keys + summary: Get API keys + description: Returns an array with one or more API keys. For security reasons hashed keys are omitted in the response + operationId: get_api_keys + parameters: + - in: query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering API keys by id. 
Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/APIKey' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + security: + - BearerAuth: [] + tags: + - API keys + summary: Add API key + description: Adds a new API key + operationId: add_api_key + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/APIKey' + responses: + '201': + description: successful operation + headers: + X-Object-ID: + schema: + type: string + description: ID for the new created API key + Location: + schema: + type: string + description: URI to retrieve the details for the new created API key + content: + application/json: + schema: + type: object + properties: + message: + type: string + example: 'API key created. This is the only time the API key is visible, please save it.' + key: + type: string + description: 'generated API key' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/apikeys/{id}': + parameters: + - name: id + in: path + description: the key id + required: true + schema: + type: string + get: + security: + - BearerAuth: [] + tags: + - API keys + summary: Find API key by id + description: Returns the API key with the given id, if it exists. 
For security reasons the hashed key is omitted in the response + operationId: get_api_key_by_id + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/APIKey' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + security: + - BearerAuth: [] + tags: + - API keys + summary: Update API key + description: Updates an existing API key. You cannot update the key itself, the creation date and the last use + operationId: update_api_key + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/APIKey' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: API key updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + security: + - BearerAuth: [] + tags: + - API keys + summary: Delete API key + description: Deletes an existing API key + operationId: delete_api_key + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: API key deleted + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: 
'#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /admins: + get: + tags: + - admins + summary: Get admins + description: Returns an array with one or more admins. For security reasons hashed passwords are omitted in the response + operationId: get_admins + parameters: + - in: query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering admins by username. Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Admin' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - admins + summary: Add admin + description: 'Adds a new admin. 
Recovery codes and TOTP configuration cannot be set using this API: each admin must use the specific APIs' + operationId: add_admin + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Admin' + responses: + '201': + description: successful operation + headers: + Location: + schema: + type: string + description: 'URI of the newly created object' + content: + application/json: + schema: + $ref: '#/components/schemas/Admin' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/admins/{username}': + parameters: + - name: username + in: path + description: the admin username + required: true + schema: + type: string + get: + tags: + - admins + summary: Find admins by username + description: 'Returns the admin with the given username, if it exists. For security reasons the hashed password is omitted in the response' + operationId: get_admin_by_username + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/Admin' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + tags: + - admins + summary: Update admin + description: 'Updates an existing admin. Recovery codes and TOTP configuration cannot be set/updated using this API: each admin must use the specific APIs. 
You are not allowed to update the admin impersonated using an API key' + operationId: update_admin + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Admin' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Admin updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - admins + summary: Delete admin + description: Deletes an existing admin + operationId: delete_admin + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Admin deleted + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/admins/{username}/2fa/disable': + parameters: + - name: username + in: path + description: the admin username + required: true + schema: + type: string + put: + tags: + - admins + summary: Disable second factor authentication + description: 'Disables second factor authentication for the given admin. 
This API must be used if the admin loses access to their second factor auth device and has no recovery codes' + operationId: disable_admin_2fa + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: 2FA disabled + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/admins/{username}/forgot-password': + parameters: + - name: username + in: path + description: the admin username + required: true + schema: + type: string + post: + security: [] + tags: + - admins + summary: Send a password reset code by email + description: 'You must set up an SMTP server and the account must have a valid email address, in which case SFTPGo will send a code via email to reset the password. 
If the specified admin does not exist, the request will be silently ignored (a success response will be returned) to avoid disclosing existing admins' + operationId: admin_forgot_password + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/admins/{username}/reset-password': + parameters: + - name: username + in: path + description: the admin username + required: true + schema: + type: string + post: + security: [] + tags: + - admins + summary: Reset the password + description: 'Set a new password using the code received via email' + operationId: admin_reset_password + requestBody: + content: + application/json: + schema: + type: object + properties: + code: + type: string + password: + type: string + required: true + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /users: + get: + tags: + - users + summary: Get users + description: Returns an array with one or more users. 
For security reasons hashed passwords are omitted in the response + operationId: get_users + parameters: + - in: query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering users by username. Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/User' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - users + summary: Add user + description: 'Adds a new user. Recovery codes and TOTP configuration cannot be set using this API: each user must use the specific APIs' + operationId: add_user + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the hash of the password and the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
+ requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/User' + responses: + '201': + description: successful operation + headers: + Location: + schema: + type: string + description: 'URI of the newly created object' + content: + application/json: + schema: + $ref: '#/components/schemas/User' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/users/{username}': + parameters: + - name: username + in: path + description: the username + required: true + schema: + type: string + get: + tags: + - users + summary: Find users by username + description: Returns the user with the given username if it exists. For security reasons the hashed password is omitted in the response + operationId: get_user_by_username + parameters: + - in: query + name: confidential_data + schema: + type: integer + description: 'If set to 1 confidential data will not be hidden. This means that the response will contain the hash of the password and the key and additional data for secrets. If a master key is not set or an external KMS is used, the data returned are enough to get the secrets in cleartext. Ignored if the * permission is not granted.' 
+ responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/User' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + tags: + - users + summary: Update user + description: 'Updates an existing user and optionally disconnects it, if connected, to apply the new settings. The current password will be preserved if the password field is omitted in the request body. Recovery codes and TOTP configuration cannot be set/updated using this API: each user must use the specific APIs' + operationId: update_user + parameters: + - in: query + name: disconnect + schema: + type: integer + enum: + - 0 + - 1 + description: | + Disconnect: + * `0` The user will not be disconnected and it will continue to use the old configuration until connected. This is the default + * `1` The user will be disconnected after a successful update. 
It must login again and so it will be forced to use the new configuration + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/User' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: User updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - users + summary: Delete user + description: Deletes an existing user + operationId: delete_user + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: User deleted + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/users/{username}/2fa/disable': + parameters: + - name: username + in: path + description: the username + required: true + schema: + type: string + put: + tags: + - users + summary: Disable second factor authentication + description: 'Disables second factor authentication for the given user. 
This API must be used if the user loses access to their second factor auth device and has no recovery codes' + operationId: disable_user_2fa + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: 2FA disabled + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/users/{username}/forgot-password': + parameters: + - name: username + in: path + description: the username + required: true + schema: + type: string + post: + security: [] + tags: + - users + summary: Send a password reset code by email + description: 'You must configure an SMTP server, the account must have a valid email address and must not have the "reset-password-disabled" restriction, in which case SFTPGo will send a code via email to reset the password. 
If the specified user does not exist, the request will be silently ignored (a success response will be returned) to avoid disclosing existing users' + operationId: user_forgot_password + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/users/{username}/reset-password': + parameters: + - name: username + in: path + description: the username + required: true + schema: + type: string + post: + security: [] + tags: + - users + summary: Reset the password + description: 'Set a new password using the code received via email' + operationId: user_reset_password + requestBody: + content: + application/json: + schema: + type: object + properties: + code: + type: string + password: + type: string + required: true + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /status: + get: + tags: + - maintenance + summary: Get status + description: Retrieves the status of the active services + operationId: get_status + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ServicesStatus' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: 
'#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /dumpdata: + get: + tags: + - maintenance + summary: Dump data + description: 'Backups data as data provider independent JSON. The backup can be saved in a local file on the server, to avoid exposing sensitive data over the network, or returned as response body. The output of dumpdata can be used as input for loaddata' + operationId: dumpdata + parameters: + - in: query + name: output-file + schema: + type: string + description: Path for the file to write the JSON serialized data to. This path is relative to the configured "backups_path". If this file already exists it will be overwritten. To return the backup as response body set `output_data` to true instead. + - in: query + name: output-data + schema: + type: integer + enum: + - 0 + - 1 + description: | + output data: + * `0` or any other value != 1, the backup will be saved to a file on the server, `output_file` is required + * `1` the backup will be returned as response body + - in: query + name: indent + schema: + type: integer + enum: + - 0 + - 1 + description: | + indent: + * `0` no indentation. This is the default + * `1` format the output JSON + - in: query + name: scopes + schema: + type: array + items: + $ref: '#/components/schemas/DumpDataScopes' + description: 'You can limit the dump contents to the specified scopes. Empty or missing means any supported scope. 
Scopes must be specified comma separated' + explode: false + required: false + responses: + '200': + description: successful operation + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/ApiResponse' + - $ref: '#/components/schemas/BackupData' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /loaddata: + parameters: + - in: query + name: scan-quota + schema: + type: integer + enum: + - 0 + - 1 + - 2 + description: | + Quota scan: + * `0` no quota scan is done, the imported users/folders will have used_quota_size and used_quota_files = 0 or the existing values if they already exists. This is the default + * `1` scan quota + * `2` scan quota if the user has quota restrictions + required: false + - in: query + name: mode + schema: + type: integer + enum: + - 0 + - 1 + - 2 + description: | + Mode: + * `0` New objects are added, existing ones are updated. This is the default + * `1` New objects are added, existing ones are not modified + * `2` New objects are added, existing ones are updated and connected users are disconnected and so forced to use the new configuration + get: + tags: + - maintenance + summary: Load data from path + description: 'Restores SFTPGo data from a JSON backup file on the server. Objects will be restored one by one and the restore is stopped if a object cannot be added or updated, so it could happen a partial restore' + operationId: loaddata_from_file + parameters: + - in: query + name: input-file + schema: + type: string + required: true + description: Path for the file to read the JSON serialized data from. This can be an absolute path or a path relative to the configured "backups_path". 
The max allowed file size is 10MB + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Data restored + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - maintenance + summary: Load data + description: 'Restores SFTPGo data from a JSON backup. Objects will be restored one by one and the restore is stopped if a object cannot be added or updated, so it could happen a partial restore' + operationId: loaddata_from_request_body + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/BackupData' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Data restored + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/changepwd: + put: + security: + - BearerAuth: [] + tags: + - user APIs + summary: Change user password + description: Changes the password for the logged in user + operationId: change_user_password + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PwdChange' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + 
default: + $ref: '#/components/responses/DefaultResponse' + /user/profile: + get: + security: + - BearerAuth: [] + tags: + - user APIs + summary: Get user profile + description: 'Returns the profile for the logged in user' + operationId: get_user_profile + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/UserProfile' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + security: + - BearerAuth: [] + tags: + - user APIs + summary: Update user profile + description: 'Allows to update the profile for the logged in user' + operationId: update_user_profile + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UserProfile' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/2fa/recoverycodes: + get: + security: + - BearerAuth: [] + tags: + - user APIs + summary: Get recovery codes + description: 'Returns the recovery codes for the logged in user. Recovery codes can be used if the user loses access to their second factor auth device. 
Recovery codes are returned unencrypted' + operationId: get_user_recovery_codes + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/RecoveryCode' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + security: + - BearerAuth: [] + tags: + - user APIs + summary: Generate recovery codes + description: 'Generates new recovery codes for the logged in user. Generating new recovery codes you automatically invalidate old ones' + operationId: generate_user_recovery_codes + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + type: string + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/totp/configs: + get: + security: + - BearerAuth: [] + tags: + - user APIs + summary: Get available TOTP configuration + description: Returns the available TOTP configurations for the logged in user + operationId: get_user_totp_configs + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/TOTPConfig' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/totp/generate: + post: + security: + - BearerAuth: [] + tags: + - user APIs + summary: Generate a new TOTP secret + description: 'Generates a new TOTP secret, including the QR code as 
png, using the specified configuration for the logged in user' + operationId: generate_user_totp_secret + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + config_name: + type: string + description: 'name of the configuration to use to generate the secret' + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: object + properties: + config_name: + type: string + issuer: + type: string + secret: + type: string + qr_code: + type: string + format: byte + description: 'QR code png encoded as BASE64' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/totp/validate: + post: + security: + - BearerAuth: [] + tags: + - user APIs + summary: Validate a one time authentication code + description: 'Checks if the given authentication code can be validated using the specified secret and config name' + operationId: validate_user_totp_secret + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + config_name: + type: string + description: 'name of the configuration to use to validate the passcode' + passcode: + type: string + description: 'passcode to validate' + secret: + type: string + description: 'secret to use to validate the passcode' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Passcode successfully validated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: 
'#/components/responses/DefaultResponse' + /user/totp/save: + post: + security: + - BearerAuth: [] + tags: + - user APIs + summary: Save a TOTP config + description: 'Saves the specified TOTP config for the logged in user' + operationId: save_user_totp_config + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/UserTOTPConfig' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: TOTP configuration saved + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/shares: + get: + tags: + - user APIs + summary: List user shares + description: Returns the share for the logged in user + operationId: get_user_shares + parameters: + - in: query + name: offset + schema: + type: integer + minimum: 0 + default: 0 + required: false + - in: query + name: limit + schema: + type: integer + minimum: 1 + maximum: 500 + default: 100 + required: false + description: 'The maximum number of items to return. Max value is 500, default is 100' + - in: query + name: order + required: false + description: Ordering shares by ID. 
Default ASC + schema: + type: string + enum: + - ASC + - DESC + example: ASC + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Share' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - user APIs + summary: Add a share + operationId: add_share + description: 'Adds a new share. The share id will be auto-generated' + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Share' + responses: + '201': + description: successful operation + headers: + X-Object-ID: + schema: + type: string + description: ID for the new created share + Location: + schema: + type: string + description: URI to retrieve the details for the new created share + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + '/user/shares/{id}': + parameters: + - name: id + in: path + description: the share id + required: true + schema: + type: string + get: + tags: + - user APIs + summary: Get share by id + description: Returns a share by id for the logged in user + operationId: get_user_share_by_id + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/Share' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: 
'#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + put: + tags: + - user APIs + summary: Update share + description: 'Updates an existing share belonging to the logged in user' + operationId: update_user_share + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Share' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Share updated + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - user APIs + summary: Delete share + description: 'Deletes an existing share belonging to the logged in user' + operationId: delete_user_share + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + example: + message: Share deleted + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/file-actions/copy: + parameters: + - in: query + name: path + description: Path to the file/folder to copy. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" + schema: + type: string + required: true + - in: query + name: target + description: New name. 
It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" + schema: + type: string + required: true + post: + tags: + - user APIs + summary: 'Copy a file or a directory' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/file-actions/move: + parameters: + - in: query + name: path + description: Path to the file/folder to rename. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" + schema: + type: string + required: true + - in: query + name: target + description: New name. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" + schema: + type: string + required: true + post: + tags: + - user APIs + summary: 'Move (rename) a file or a directory' + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/dirs: + get: + tags: + - user APIs + summary: Read directory contents + description: Returns the contents of the specified directory for the logged in user + operationId: get_user_dir_contents + parameters: + - in: query + name: path + description: Path to the folder to read. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir". 
If empty or missing the user's start directory is assumed. If relative, the user's start directory is used as the base + schema: + type: string + responses: + '200': + description: successful operation + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/DirEntry' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - user APIs + summary: Create a directory + description: Create a directory for the logged in user + operationId: create_user_dir + parameters: + - in: query + name: path + description: Path to the folder to create. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" + schema: + type: string + required: true + - in: query + name: mkdir_parents + description: Create parent directories if they do not exist? + schema: + type: boolean + required: false + responses: + '201': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + patch: + tags: + - user APIs + deprecated: true + summary: 'Rename a directory. Deprecated, use "file-actions/move"' + description: Rename a directory for the logged in user. The rename is allowed for empty directory or for non empty local directories, with no virtual folders inside + operationId: rename_user_dir + parameters: + - in: query + name: path + description: Path to the folder to rename. 
It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" + schema: + type: string + required: true + - in: query + name: target + description: New name. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" + schema: + type: string + required: true + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - user APIs + summary: Delete a directory + description: Delete a directory and any children it contains for the logged in user + operationId: delete_user_dir + parameters: + - in: query + name: path + description: Path to the folder to delete. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir" + schema: + type: string + required: true + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/files: + get: + tags: + - user APIs + summary: Download a single file + description: Returns the file contents as response body + operationId: download_user_file + parameters: + - in: query + name: path + required: true + description: Path to the file to download. 
It must be URL encoded, for example the path "my dir/àdir/file.txt" must be sent as "my%20dir%2F%C3%A0dir%2Ffile.txt" + schema: + type: string + - in: query + name: inline + required: false + description: 'If set, the response will not have the Content-Disposition header set to `attachment`' + schema: + type: string + responses: + '200': + description: successful operation + content: + '*/*': + schema: + type: string + format: binary + '206': + description: successful operation + content: + '*/*': + schema: + type: string + format: binary + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + post: + tags: + - user APIs + summary: Upload files + description: Upload one or more files for the logged in user + operationId: create_user_files + parameters: + - in: query + name: path + description: Parent directory for the uploaded files. It must be URL encoded, for example the path "my dir/àdir" must be sent as "my%20dir%2F%C3%A0dir". If empty or missing the root path is assumed. If a file with the same name already exists, it will be overwritten + schema: + type: string + - in: query + name: mkdir_parents + description: Create parent directories if they do not exist? 
+ schema: + type: boolean + required: false + requestBody: + content: + multipart/form-data: + schema: + type: object + properties: + filenames: + type: array + items: + type: string + format: binary + minItems: 1 + uniqueItems: true + required: true + responses: + '201': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '413': + $ref: '#/components/responses/RequestEntityTooLarge' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + patch: + tags: + - user APIs + deprecated: true + summary: Rename a file + description: 'Rename a file for the logged in user. Deprecated, use "file-actions/move"' + operationId: rename_user_file + parameters: + - in: query + name: path + description: Path to the file to rename. It must be URL encoded + schema: + type: string + required: true + - in: query + name: target + description: New name. It must be URL encoded + schema: + type: string + required: true + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + delete: + tags: + - user APIs + summary: Delete a file + description: Delete a file for the logged in user. + operationId: delete_user_file + parameters: + - in: query + name: path + description: Path to the file to delete. 
It must be URL encoded + schema: + type: string + required: true + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/files/upload: + post: + tags: + - user APIs + summary: Upload a single file + description: 'Upload a single file for the logged in user to an existing directory. This API does not use multipart/form-data and so no temporary files are created server side but only a single file can be uploaded as POST body' + operationId: create_user_file + parameters: + - in: query + name: path + description: Full file path. It must be path encoded, for example the path "my dir/àdir/file.txt" must be sent as "my%20dir%2F%C3%A0dir%2Ffile.txt". The parent directory must exist. If a file with the same name already exists, it will be overwritten + schema: + type: string + required: true + - in: query + name: mkdir_parents + description: Create parent directories if they do not exist? 
+ schema: + type: boolean + required: false + - in: header + name: X-SFTPGO-MTIME + schema: + type: integer + description: File modification time as unix timestamp in milliseconds + requestBody: + content: + application/*: + schema: + type: string + format: binary + text/*: + schema: + type: string + format: binary + image/*: + schema: + type: string + format: binary + audio/*: + schema: + type: string + format: binary + video/*: + schema: + type: string + format: binary + required: true + responses: + '201': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '413': + $ref: '#/components/responses/RequestEntityTooLarge' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/files/metadata: + patch: + tags: + - user APIs + summary: Set metadata for a file/directory + description: 'Set supported metadata attributes for the specified file or directory' + operationId: setprops_user_file + parameters: + - in: query + name: path + description: Full file/directory path. 
It must be URL encoded, for example the path "my dir/àdir/file.txt" must be sent as "my%20dir%2F%C3%A0dir%2Ffile.txt" + schema: + type: string + required: true + requestBody: + content: + application/json: + schema: + type: object + properties: + modification_time: + type: integer + description: File modification time as unix timestamp in milliseconds + required: true + responses: + '200': + description: successful operation + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '413': + $ref: '#/components/responses/RequestEntityTooLarge' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' + /user/streamzip: + post: + tags: + - user APIs + summary: Download multiple files and folders as a single zip file + description: A zip file, containing the specified files and folders, will be generated on the fly and returned as response body. 
Only folders and regular files will be included in the zip + operationId: streamzip + requestBody: + required: true + content: + application/json: + schema: + type: array + items: + type: string + description: Absolute file or folder path + responses: + '200': + description: successful operation + content: + 'application/zip': + schema: + type: string + format: binary + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '500': + $ref: '#/components/responses/InternalServerError' + default: + $ref: '#/components/responses/DefaultResponse' +components: + responses: + BadRequest: + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + Unauthorized: + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + Forbidden: + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + NotFound: + description: Not Found + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + Conflict: + description: Conflict + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + RequestEntityTooLarge: + description: Request Entity Too Large, max allowed size exceeded + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + InternalServerError: + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + DefaultResponse: + description: Unexpected Error + content: + application/json: + schema: + $ref: '#/components/schemas/ApiResponse' + schemas: + Permission: + type: string + enum: + - '*' + - list + - download + - upload + - overwrite + - delete + - delete_files + - delete_dirs + - rename + - rename_files + - rename_dirs + - create_dirs + - create_symlinks + - chmod + - chown + 
- chtimes + - copy + description: | + Permissions: + * `*` - all permissions are granted + * `list` - list items is allowed + * `download` - download files is allowed + * `upload` - upload files is allowed + * `overwrite` - overwrite an existing file, while uploading, is allowed. upload permission is required to allow file overwrite + * `delete` - delete files or directories is allowed + * `delete_files` - delete files is allowed + * `delete_dirs` - delete directories is allowed + * `rename` - rename files or directories is allowed + * `rename_files` - rename files is allowed + * `rename_dirs` - rename directories is allowed + * `create_dirs` - create directories is allowed + * `create_symlinks` - create links is allowed + * `chmod` changing file or directory permissions is allowed + * `chown` changing file or directory owner and group is allowed + * `chtimes` changing file or directory access and modification time is allowed + * `copy`, copying files or directories is allowed + AdminPermissions: + type: string + enum: + - '*' + - add_users + - edit_users + - del_users + - view_users + - view_conns + - close_conns + - view_status + - manage_folders + - manage_groups + - quota_scans + - manage_defender + - view_defender + - view_events + - disable_mfa + description: | + Admin permissions: + * `*` - super admin permissions are granted + * `add_users` - add new users is allowed + * `edit_users` - change existing users is allowed + * `del_users` - remove users is allowed + * `view_users` - list users is allowed + * `view_conns` - list active connections is allowed + * `close_conns` - close active connections is allowed + * `view_status` - view the server status is allowed + * `manage_folders` - manage folders is allowed + * `manage_groups` - manage groups is allowed + * `quota_scans` - view and start quota scans is allowed + * `manage_defender` - remove ip from the dynamic blocklist is allowed + * `view_defender` - list the dynamic blocklist is allowed + * 
`view_events` - view and search filesystem and provider events is allowed + * `disable_mfa` - allow to disable two-factor authentication for users and admins + FsProviders: + type: integer + enum: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + description: | + Filesystem providers: + * `0` - Local filesystem + * `1` - S3 Compatible Object Storage + * `2` - Google Cloud Storage + * `3` - Azure Blob Storage + * `4` - Local filesystem encrypted + * `5` - SFTP + * `6` - HTTP filesystem + EventActionTypes: + type: integer + enum: + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 11 + - 12 + - 13 + - 14 + - 15 + description: | + Supported event action types: + * `1` - HTTP + * `2` - Command + * `3` - Email + * `4` - Backup + * `5` - User quota reset + * `6` - Folder quota reset + * `7` - Transfer quota reset + * `8` - Data retention check + * `9` - Filesystem + * `11` - Password expiration check + * `12` - User expiration check + * `13` - Identity Provider account check + * `14` - User inactivity check + * `15` - Rotate log file + FilesystemActionTypes: + type: integer + enum: + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + description: | + Supported filesystem action types: + * `1` - Rename + * `2` - Delete + * `3` - Mkdir + * `4` - Exist + * `5` - Compress + * `6` - Copy + EventTriggerTypes: + type: integer + enum: + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + description: | + Supported event trigger types: + * `1` - Filesystem event + * `2` - Provider event + * `3` - Schedule + * `4` - IP blocked + * `5` - Certificate renewal + * `6` - On demand, like schedule but executed on demand + * `7` - Identity provider login + LoginMethods: + type: string + enum: + - publickey + - password + - password-over-SSH + - keyboard-interactive + - publickey+password + - publickey+keyboard-interactive + - TLSCertificate + - TLSCertificate+password + description: | + Available login methods. 
To enable multi-step authentication you have to allow only multi-step login methods + * `publickey` + * `password`, password for all the supported protocols + * `password-over-SSH`, password over SSH protocol (SSH/SFTP/SCP) + * `keyboard-interactive` + * `publickey+password` - multi-step auth: public key and password + * `publickey+keyboard-interactive` - multi-step auth: public key and keyboard interactive + * `TLSCertificate` + * `TLSCertificate+password` - multi-step auth: TLS client certificate and password + SupportedProtocols: + type: string + enum: + - SSH + - FTP + - DAV + - HTTP + description: | + Protocols: + * `SSH` - includes both SFTP and SSH commands + * `FTP` - plain FTP and FTPES/FTPS + * `DAV` - WebDAV over HTTP/HTTPS + * `HTTP` - WebClient/REST API + MFAProtocols: + type: string + enum: + - SSH + - FTP + - HTTP + description: | + Protocols: + * `SSH` - includes both SFTP and SSH commands + * `FTP` - plain FTP and FTPES/FTPS + * `HTTP` - WebClient/REST API + EventProtocols: + type: string + enum: + - SSH + - SFTP + - SCP + - FTP + - DAV + - HTTP + - HTTPShare + - DataRetention + - EventAction + - OIDC + description: | + Protocols: + * `SSH` - SSH commands + * `SFTP` - SFTP protocol + * `SCP` - SCP protocol + * `FTP` - plain FTP and FTPES/FTPS + * `DAV` - WebDAV + * `HTTP` - WebClient/REST API + * `HTTPShare` - the event is generated in a public share + * `DataRetention` - the event is generated by a data retention check + * `EventAction` - the event is generated by an EventManager action + * `OIDC` - OpenID Connect + WebClientOptions: + type: string + enum: + - publickey-change-disabled + - tls-cert-change-disabled + - write-disabled + - mfa-disabled + - password-change-disabled + - api-key-auth-change-disabled + - info-change-disabled + - shares-disabled + - password-reset-disabled + - shares-without-password-disabled + description: | + Options: + * `publickey-change-disabled` - changing SSH public keys is not allowed + * 
`tls-cert-change-disabled` - changing TLS certificates is not allowed + * `write-disabled` - upload, rename, delete are not allowed even if the user has permissions for these actions + * `mfa-disabled` - enabling multi-factor authentication is not allowed. This option cannot be set if the user has MFA already enabled + * `password-change-disabled` - changing password is not allowed + * `api-key-auth-change-disabled` - enabling/disabling API key authentication is not allowed + * `info-change-disabled` - changing info such as email and description is not allowed + * `shares-disabled` - sharing files and directories with external users is not allowed + * `password-reset-disabled` - resetting the password is not allowed + * `shares-without-password-disabled` - creating shares without password protection is not allowed + APIKeyScope: + type: integer + enum: + - 1 + - 2 + description: | + Options: + * `1` - admin scope. The API key will be used to impersonate an SFTPGo admin + * `2` - user scope. The API key will be used to impersonate an SFTPGo user + ShareScope: + type: integer + enum: + - 1 + - 2 + description: | + Options: + * `1` - read scope + * `2` - write scope + TOTPHMacAlgo: + type: string + enum: + - sha1 + - sha256 + - sha512 + description: 'Supported HMAC algorithms for Time-based one time passwords' + UserType: + type: string + enum: + - '' + - LDAPUser + - OSUser + description: This is a hint for authentication plugins. 
It is ignored when using SFTPGo internal authentication + DumpDataScopes: + type: string + enum: + - users + - folders + - groups + - admins + - api_keys + - shares + - actions + - rules + - roles + - ip_lists + - configs + LogEventType: + type: integer + enum: + - 1 + - 2 + - 3 + - 4 + - 5 + description: > + Event status: + * `1` - Login failed + * `2` - Login failed non-existent user + * `3` - No login tried + * `4` - Algorithm negotiation failed + * `5` - Login succeeded + FsEventStatus: + type: integer + enum: + - 1 + - 2 + - 3 + description: > + Event status: + * `1` - no error + * `2` - generic error + * `3` - quota exceeded error + FsEventAction: + type: string + enum: + - download + - upload + - first-upload + - first-download + - delete + - rename + - mkdir + - rmdir + - ssh_cmd + ProviderEventAction: + type: string + enum: + - add + - update + - delete + ProviderEventObjectType: + type: string + enum: + - user + - folder + - group + - admin + - api_key + - share + - event_action + - event_rule + - role + SSHAuthentications: + type: string + enum: + - publickey + - password + - keyboard-interactive + - publickey+password + - publickey+keyboard-interactive + TLSVersions: + type: integer + enum: + - 12 + - 13 + description: > + TLS version: + * `12` - TLS 1.2 + * `13` - TLS 1.3 + IPListType: + type: integer + enum: + - 1 + - 2 + - 3 + description: > + IP List types: + * `1` - allow list + * `2` - defender + * `3` - rate limiter safe list + IPListMode: + type: integer + enum: + - 1 + - 2 + description: > + IP list modes + * `1` - allow + * `2` - deny, supported for defender list type only + TOTPConfig: + type: object + properties: + name: + type: string + issuer: + type: string + algo: + $ref: '#/components/schemas/TOTPHMacAlgo' + RecoveryCode: + type: object + properties: + secret: + $ref: '#/components/schemas/Secret' + used: + type: boolean + description: 'Recovery codes to use if the user loses access to their second factor auth device. 
Each code can only be used once, you should use these codes to login and disable or reset 2FA for your account' + BaseTOTPConfig: + type: object + properties: + enabled: + type: boolean + config_name: + type: string + description: 'This name must be defined within the "totp" section of the SFTPGo configuration file. You will be unable to save a user/admin referencing a missing config_name' + secret: + $ref: '#/components/schemas/Secret' + AdminTOTPConfig: + allOf: + - $ref: '#/components/schemas/BaseTOTPConfig' + UserTOTPConfig: + allOf: + - $ref: '#/components/schemas/BaseTOTPConfig' + - type: object + properties: + protocols: + type: array + items: + $ref: '#/components/schemas/MFAProtocols' + description: 'TOTP will be required for the specified protocols. SSH protocol (SFTP/SCP/SSH commands) will ask for the TOTP passcode if the client uses keyboard interactive authentication. FTP has no standard way to support two factor authentication, if you enable the FTP support, you have to add the TOTP passcode after the password. For example if your password is "password" and your one time passcode is "123456" you have to use "password123456" as password. WebDAV is not supported since each single request must be authenticated and a passcode cannot be reused.' + PatternsFilter: + type: object + properties: + path: + type: string + description: 'virtual path as seen by users, if no other specific filter is defined, the filter applies for sub directories too. For example if filters are defined for the paths "/" and "/sub" then the filters for "/" are applied for any file outside the "/sub" directory' + allowed_patterns: + type: array + items: + type: string + description: 'list of, case insensitive, allowed shell like patterns. 
Allowed patterns are evaluated before the denied ones' + example: + - '*.jpg' + - a*b?.png + denied_patterns: + type: array + items: + type: string + description: 'list of, case insensitive, denied shell like patterns' + example: + - '*.zip' + deny_policy: + type: integer + enum: + - 0 + - 1 + description: | + Policies for denied patterns + * `0` - default policy. Denied files/directories matching the filters are visible in directory listing but cannot be uploaded/downloaded/overwritten/renamed + * `1` - deny policy hide. This policy applies the same restrictions as the default one and denied files/directories matching the filters will also be hidden in directory listing. This mode may cause performance issues for large directories + HooksFilter: + type: object + properties: + external_auth_disabled: + type: boolean + example: false + description: If true, the external auth hook, if defined, will not be executed + pre_login_disabled: + type: boolean + example: false + description: If true, the pre-login hook, if defined, will not be executed + check_password_disabled: + type: boolean + example: false + description: If true, the check password hook, if defined, will not be executed + description: User specific hook overrides + BandwidthLimit: + type: object + properties: + sources: + type: array + items: + type: string + description: 'Source networks in CIDR notation as defined in RFC 4632 and RFC 4291 for example `192.0.2.0/24` or `2001:db8::/32`. 
The limit applies if the defined networks contain the client IP' + upload_bandwidth: + type: integer + format: int32 + description: 'Maximum upload bandwidth as KB/s, 0 means unlimited' + download_bandwidth: + type: integer + format: int32 + description: 'Maximum download bandwidth as KB/s, 0 means unlimited' + TimePeriod: + type: object + properties: + day_of_week: + type: integer + enum: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + description: Day of week, 0 Sunday, 6 Saturday + from: + type: string + description: Start time in HH:MM format + to: + type: string + description: End time in HH:MM format + BaseUserFilters: + type: object + properties: + allowed_ip: + type: array + items: + type: string + description: 'only clients connecting from these IP/Mask are allowed. IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291, for example "192.0.2.0/24" or "2001:db8::/32"' + example: + - 192.0.2.0/24 + - '2001:db8::/32' + denied_ip: + type: array + items: + type: string + description: clients connecting from these IP/Mask are not allowed. Denied rules are evaluated before allowed ones + example: + - 172.16.0.0/16 + denied_login_methods: + type: array + items: + $ref: '#/components/schemas/LoginMethods' + description: if null or empty any available login method is allowed + denied_protocols: + type: array + items: + $ref: '#/components/schemas/SupportedProtocols' + description: if null or empty any available protocol is allowed + file_patterns: + type: array + items: + $ref: '#/components/schemas/PatternsFilter' + description: 'filters based on shell like file patterns. These restrictions do not apply to files listing for performance reasons, so a denied file cannot be downloaded/overwritten/renamed but it will still be in the list of files. Please note that these restrictions can be easily bypassed' + max_upload_file_size: + type: integer + format: int64 + description: 'maximum allowed size, as bytes, for a single file upload. 
The upload will be aborted if/when the size of the file being sent exceeds this limit. 0 means unlimited' + tls_username: + type: string + description: 'defines the TLS certificate field to use as username. For FTP clients it must match the name provided using the "USER" command. For WebDAV, if no username is provided, the CN will be used as username. For WebDAV clients it must match the implicit or provided username. Ignored if mutual TLS is disabled. Currently the only supported value is `CommonName`' + hooks: + $ref: '#/components/schemas/HooksFilter' + disable_fs_checks: + type: boolean + example: false + description: Disable checks for existence and automatic creation of home directory and virtual folders. SFTPGo requires that the user's home directory, virtual folder root, and intermediate paths to virtual folders exist to work properly. If you already know that the required directories exist, disabling these checks will speed up login. You could, for example, disable these checks after the first login + web_client: + type: array + items: + $ref: '#/components/schemas/WebClientOptions' + description: WebClient/user REST API related configuration options + allow_api_key_auth: + type: boolean + description: 'API key authentication allows to impersonate this user with an API key' + user_type: + $ref: '#/components/schemas/UserType' + bandwidth_limits: + type: array + items: + $ref: '#/components/schemas/BandwidthLimit' + external_auth_cache_time: + type: integer + description: 'Defines the cache time, in seconds, for users authenticated using an external auth hook. 0 means no cache' + start_directory: + type: string + description: 'Specifies an alternate starting directory. If not set, the default is "/". This option is supported for SFTP/SCP, FTP and HTTP (WebClient/REST API) protocols. Relative paths will use this directory as base.' 
+ two_factor_protocols: + type: array + items: + $ref: '#/components/schemas/MFAProtocols' + description: 'Defines protocols that require two factor authentication' + ftp_security: + type: integer + enum: + - 0 + - 1 + description: 'Set to `1` to require TLS for both data and control connection. This setting is useful if you want to allow both encrypted and plain text FTP sessions globally and then you want to require encrypted sessions on a per-user basis. It has no effect if TLS is already required for all users in the configuration file.' + is_anonymous: + type: boolean + description: 'If enabled the user can login with any password or no password at all. Anonymous users are supported for FTP and WebDAV protocols and permissions will be automatically set to "list" and "download" (read only)' + default_shares_expiration: + type: integer + description: 'Defines the default expiration for newly created shares as number of days. 0 means no expiration' + max_shares_expiration: + type: integer + description: 'Defines the maximum allowed expiration, as a number of days, when a user creates or updates a share. 0 means no expiration' + password_expiration: + type: integer + description: 'The password expires after the defined number of days. 0 means no expiration' + password_strength: + type: integer + description: 'Defines the minimum password strength. 0 means disabled, any password will be accepted. 
Values in the 50-70 range are suggested for common use cases' + access_time: + type: array + items: + $ref: '#/components/schemas/TimePeriod' + description: Additional user options + UserFilters: + allOf: + - $ref: '#/components/schemas/BaseUserFilters' + - type: object + properties: + require_password_change: + type: boolean + description: 'User must change password from WebClient/REST API at next login' + totp_config: + $ref: '#/components/schemas/UserTOTPConfig' + recovery_codes: + type: array + items: + $ref: '#/components/schemas/RecoveryCode' + tls_certs: + type: array + items: + type: string + additional_emails: + type: array + items: + type: string + format: email + Secret: + type: object + properties: + status: + type: string + enum: + - Plain + - AES-256-GCM + - Secretbox + - GCP + - AWS + - VaultTransit + - AzureKeyVault + - Redacted + description: 'Set to "Plain" to add or update an existing secret, set to "Redacted" to preserve the existing value' + payload: + type: string + key: + type: string + additional_data: + type: string + mode: + type: integer + description: 1 means encrypted using a master key + description: The secret is encrypted before saving, so to set a new secret you must provide a payload and set the status to "Plain". The encryption key and additional data will be generated automatically. 
If you set the status to "Redacted" the existing secret will be preserved + S3Config: + type: object + properties: + bucket: + type: string + minLength: 1 + region: + type: string + minLength: 1 + access_key: + type: string + access_secret: + $ref: '#/components/schemas/Secret' + sse_customer_key: + $ref: '#/components/schemas/Secret' + role_arn: + type: string + description: 'Optional IAM Role ARN to assume' + session_token: + type: string + description: 'Optional Session token that is a part of temporary security credentials provisioned by AWS STS' + endpoint: + type: string + description: optional endpoint + storage_class: + type: string + acl: + type: string + description: 'The canned ACL to apply to uploaded objects. Leave empty to use the default ACL. For more information and available ACLs, see here: https://docs.aws.amazon.com/AmazonS3/latest/userguide/acl-overview.html#canned-acl' + upload_part_size: + type: integer + description: 'the buffer size (in MB) to use for multipart uploads. The minimum allowed part size is 5MB, and if this value is set to zero, the default value (5MB) for the AWS SDK will be used. The minimum allowed value is 5.' + upload_concurrency: + type: integer + description: 'the number of parts to upload in parallel. If this value is set to zero, the default value (5) will be used' + upload_part_max_time: + type: integer + description: 'the maximum time allowed, in seconds, to upload a single chunk (the chunk size is defined via "upload_part_size"). 0 means no timeout' + download_part_size: + type: integer + description: 'the buffer size (in MB) to use for multipart downloads. The minimum allowed part size is 5MB, and if this value is set to zero, the default value (5MB) for the AWS SDK will be used. The minimum allowed value is 5. Ignored for partial downloads' + download_concurrency: + type: integer + description: 'the number of parts to download in parallel. If this value is set to zero, the default value (5) will be used. 
Ignored for partial downloads' + download_part_max_time: + type: integer + description: 'the maximum time allowed, in seconds, to download a single chunk (the chunk size is defined via "download_part_size"). 0 means no timeout. Ignored for partial downloads.' + force_path_style: + type: boolean + description: 'Set this to "true" to force the request to use path-style addressing, i.e., "http://s3.amazonaws.com/BUCKET/KEY". By default, the S3 client will use virtual hosted bucket addressing when possible ("http://BUCKET.s3.amazonaws.com/KEY")' + key_prefix: + type: string + description: 'key_prefix is similar to a chroot directory for a local filesystem. If specified the user will only see contents that starts with this prefix and so you can restrict access to a specific virtual folder. The prefix, if not empty, must not start with "/" and must end with "/". If empty the whole bucket contents will be available' + example: folder/subfolder/ + description: S3 Compatible Object Storage configuration details + GCSConfig: + type: object + properties: + bucket: + type: string + minLength: 1 + credentials: + $ref: '#/components/schemas/Secret' + automatic_credentials: + type: integer + enum: + - 0 + - 1 + description: | + Automatic credentials: + * `0` - disabled, explicit credentials, using a JSON credentials file, must be provided. This is the default value if the field is null + * `1` - enabled, we try to use the Application Default Credentials (ADC) strategy to find your application's credentials + storage_class: + type: string + acl: + type: string + description: 'The ACL to apply to uploaded objects. Leave empty to use the default ACL. For more information and available ACLs, refer to the JSON API here: https://cloud.google.com/storage/docs/access-control/lists#predefined-acl' + key_prefix: + type: string + description: 'key_prefix is similar to a chroot directory for a local filesystem. 
If specified the user will only see contents that starts with this prefix and so you can restrict access to a specific virtual folder. The prefix, if not empty, must not start with "/" and must end with "/". If empty the whole bucket contents will be available' + example: folder/subfolder/ + upload_part_size: + type: integer + description: 'The buffer size (in MB) to use for multipart uploads. The default value is 16MB. 0 means use the default' + upload_part_max_time: + type: integer + description: 'The maximum time allowed, in seconds, to upload a single chunk. The default value is 32. 0 means use the default' + description: 'Google Cloud Storage configuration details. The "credentials" field must be populated only when adding/updating a user. It will be always omitted, since there are sensitive data, when you search/get users' + AzureBlobFsConfig: + type: object + properties: + container: + type: string + account_name: + type: string + description: 'Storage Account Name, leave blank to use SAS URL' + account_key: + $ref: '#/components/schemas/Secret' + sas_url: + $ref: '#/components/schemas/Secret' + endpoint: + type: string + description: 'optional endpoint. Default is "blob.core.windows.net". If you use the emulator the endpoint must include the protocol, for example "http://127.0.0.1:10000"' + upload_part_size: + type: integer + description: 'the buffer size (in MB) to use for multipart uploads. If this value is set to zero, the default value (5MB) will be used.' + upload_concurrency: + type: integer + description: 'the number of parts to upload in parallel. If this value is set to zero, the default value (5) will be used' + download_part_size: + type: integer + description: 'the buffer size (in MB) to use for multipart downloads. If this value is set to zero, the default value (5MB) will be used.' + download_concurrency: + type: integer + description: 'the number of parts to download in parallel. 
If this value is set to zero, the default value (5) will be used' + access_tier: + type: string + enum: + - '' + - Archive + - Hot + - Cool + key_prefix: + type: string + description: 'key_prefix is similar to a chroot directory for a local filesystem. If specified the user will only see contents that starts with this prefix and so you can restrict access to a specific virtual folder. The prefix, if not empty, must not start with "/" and must end with "/". If empty the whole container contents will be available' + example: folder/subfolder/ + use_emulator: + type: boolean + description: Azure Blob Storage configuration details + OSFsConfig: + type: object + properties: + read_buffer_size: + type: integer + minimum: 0 + maximum: 10 + description: "The read buffer size, as MB, to use for downloads. 0 means no buffering, that's fine in most use cases." + write_buffer_size: + type: integer + minimum: 0 + maximum: 10 + description: "The write buffer size, as MB, to use for uploads. 0 means no buffering, that's fine in most use cases." + CryptFsConfig: + type: object + properties: + passphrase: + $ref: '#/components/schemas/Secret' + read_buffer_size: + type: integer + minimum: 0 + maximum: 10 + description: "The read buffer size, as MB, to use for downloads. 0 means no buffering, that's fine in most use cases." + write_buffer_size: + type: integer + minimum: 0 + maximum: 10 + description: "The write buffer size, as MB, to use for uploads. 0 means no buffering, that's fine in most use cases." + description: Crypt filesystem configuration details + SFTPFsConfig: + type: object + properties: + endpoint: + type: string + description: 'remote SFTP endpoint as host:port' + username: + type: string + description: you can specify a password or private key or both. In the latter case the private key will be tried first. 
+ password: + $ref: '#/components/schemas/Secret' + private_key: + $ref: '#/components/schemas/Secret' + key_passphrase: + $ref: '#/components/schemas/Secret' + fingerprints: + type: array + items: + type: string + description: 'SHA256 fingerprints to use for host key verification. If you don''t provide any fingerprint the remote host key will not be verified, this is a security risk' + prefix: + type: string + description: Specifying a prefix you can restrict all operations to a given path within the remote SFTP server. + disable_concurrent_reads: + type: boolean + description: Concurrent reads are safe to use and disabling them will degrade performance. Some servers automatically delete files once they are downloaded. Using concurrent reads is problematic with such servers. + buffer_size: + type: integer + minimum: 0 + maximum: 16 + example: 2 + description: The size of the buffer (in MB) to use for transfers. By enabling buffering, the reads and writes, from/to the remote SFTP server, are split in multiple concurrent requests and this allows data to be transferred at a faster rate, over high latency networks, by overlapping round-trip times. With buffering enabled, resuming uploads is not supported and a file cannot be opened for both reading and writing at the same time. 0 means disabled. + equality_check_mode: + type: integer + enum: + - 0 + - 1 + description: | + Defines how to check if this config points to the same server as another config. If different configs point to the same server the renaming between the fs configs is allowed: + * `0` username and endpoint must match. This is the default + * `1` only the endpoint must match + HTTPFsConfig: + type: object + properties: + endpoint: + type: string + description: 'HTTP/S endpoint URL. 
SFTPGo will use this URL as base, for example for the `stat` API, SFTPGo will add `/stat/{name}`' + username: + type: string + password: + $ref: '#/components/schemas/Secret' + api_key: + $ref: '#/components/schemas/Secret' + skip_tls_verify: + type: boolean + equality_check_mode: + type: integer + enum: + - 0 + - 1 + description: | + Defines how to check if this config points to the same server as another config. If different configs point to the same server the renaming between the fs configs is allowed: + * `0` username and endpoint must match. This is the default + * `1` only the endpoint must match + FilesystemConfig: + type: object + properties: + provider: + $ref: '#/components/schemas/FsProviders' + osconfig: + $ref: '#/components/schemas/OSFsConfig' + s3config: + $ref: '#/components/schemas/S3Config' + gcsconfig: + $ref: '#/components/schemas/GCSConfig' + azblobconfig: + $ref: '#/components/schemas/AzureBlobFsConfig' + cryptconfig: + $ref: '#/components/schemas/CryptFsConfig' + sftpconfig: + $ref: '#/components/schemas/SFTPFsConfig' + httpconfig: + $ref: '#/components/schemas/HTTPFsConfig' + description: Storage filesystem details + BaseVirtualFolder: + type: object + properties: + id: + type: integer + format: int32 + minimum: 1 + name: + type: string + description: unique name for this virtual folder + mapped_path: + type: string + description: absolute filesystem path to use as virtual folder + description: + type: string + description: optional description + used_quota_size: + type: integer + format: int64 + used_quota_files: + type: integer + format: int32 + last_quota_update: + type: integer + format: int64 + description: Last quota update as unix timestamp in milliseconds + users: + type: array + items: + type: string + description: list of usernames associated with this virtual folder + filesystem: + $ref: '#/components/schemas/FilesystemConfig' + description: 'Defines the filesystem for the virtual folder and the used quota limits. 
The same folder can be shared among multiple users and each user can have different quota limits or a different virtual path.' + VirtualFolder: + allOf: + - $ref: '#/components/schemas/BaseVirtualFolder' + - type: object + properties: + virtual_path: + type: string + quota_size: + type: integer + format: int64 + description: 'Quota as size in bytes. 0 means unlimited, -1 means included in user quota. Please note that quota is updated if files are added/removed via SFTPGo otherwise a quota scan or a manual quota update is needed' + quota_files: + type: integer + format: int32 + description: 'Quota as number of files. 0 means unlimited, -1 means included in user quota. Please note that quota is updated if files are added/removed via SFTPGo otherwise a quota scan or a manual quota update is needed' + required: + - virtual_path + description: 'A virtual folder is a mapping between a SFTPGo virtual path and a filesystem path outside the user home directory. The specified paths must be absolute and the virtual path cannot be "/", it must be a sub directory. The parent directory for the specified virtual path must exist. SFTPGo will try to automatically create any missing parent directory for the configured virtual folders at user login.' + User: + type: object + properties: + id: + type: integer + format: int32 + minimum: 1 + status: + type: integer + enum: + - 0 + - 1 + description: | + status: + * `0` user is disabled, login is not allowed + * `1` user is enabled + username: + type: string + description: username is unique + email: + type: string + format: email + description: + type: string + description: 'optional description, for example the user full name' + expiration_date: + type: integer + format: int64 + description: expiration date as unix timestamp in milliseconds. An expired account cannot login. 
0 means no expiration + password: + type: string + format: password + description: If the password has no known hashing algo prefix it will be stored, by default, using bcrypt, argon2id is supported too. You can send a password hashed as bcrypt ($2a$ prefix), argon2id, pbkdf2 or unix crypt and it will be stored as is. For security reasons this field is omitted when you search/get users + public_keys: + type: array + items: + type: string + example: ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEUWwDwEWhTbF0MqAsp/oXK1HR2cElhM8oo1uVmL3ZeDKDiTm4ljMr92wfTgIGDqIoxmVqgYIkAOAhuykAVWBzc= user@host + description: Public keys in OpenSSH format. + has_password: + type: boolean + description: Indicates whether the password is set + home_dir: + type: string + description: path to the user home directory. The user cannot upload or download files outside this directory. SFTPGo tries to automatically create this folder if missing. Must be an absolute path + virtual_folders: + type: array + items: + $ref: '#/components/schemas/VirtualFolder' + description: mapping between virtual SFTPGo paths and virtual folders + uid: + type: integer + format: int32 + minimum: 0 + maximum: 2147483647 + description: 'if you run SFTPGo as root user, the created files and directories will be assigned to this uid. 0 means no change, the owner will be the user that runs SFTPGo. Ignored on windows' + gid: + type: integer + format: int32 + minimum: 0 + maximum: 2147483647 + description: 'if you run SFTPGo as root user, the created files and directories will be assigned to this gid. 0 means no change, the group will be the one of the user that runs SFTPGo. Ignored on windows' + max_sessions: + type: integer + format: int32 + description: Limit the sessions that a user can open. 0 means unlimited + quota_size: + type: integer + format: int64 + description: Quota as size in bytes. 0 means unlimited. 
Please note that quota is updated if files are added/removed via SFTPGo otherwise a quota scan or a manual quota update is needed + quota_files: + type: integer + format: int32 + description: Quota as number of files. 0 means unlimited. Please note that quota is updated if files are added/removed via SFTPGo otherwise a quota scan or a manual quota update is needed + permissions: + type: object + additionalProperties: + type: array + items: + $ref: '#/components/schemas/Permission' + minItems: 1 + minProperties: 1 + description: 'hash map with directory as key and an array of permissions as value. Directories must be absolute paths, permissions for root directory ("/") are required' + example: + /: + - '*' + /somedir: + - list + - download + used_quota_size: + type: integer + format: int64 + used_quota_files: + type: integer + format: int32 + last_quota_update: + type: integer + format: int64 + description: Last quota update as unix timestamp in milliseconds + upload_bandwidth: + type: integer + description: 'Maximum upload bandwidth as KB/s, 0 means unlimited' + download_bandwidth: + type: integer + description: 'Maximum download bandwidth as KB/s, 0 means unlimited' + upload_data_transfer: + type: integer + description: 'Maximum data transfer allowed for uploads as MB. 0 means no limit' + download_data_transfer: + type: integer + description: 'Maximum data transfer allowed for downloads as MB. 0 means no limit' + total_data_transfer: + type: integer + description: 'Maximum total data transfer as MB. 0 means unlimited. You can set a total data transfer instead of the individual values for uploads and downloads' + used_upload_data_transfer: + type: integer + description: 'Uploaded size, as bytes, since the last reset' + used_download_data_transfer: + type: integer + description: 'Downloaded size, as bytes, since the last reset' + created_at: + type: integer + format: int64 + description: 'creation time as unix timestamp in milliseconds. 
It will be 0 for users created before v2.2.0' + updated_at: + type: integer + format: int64 + description: last update time as unix timestamp in milliseconds + last_login: + type: integer + format: int64 + description: Last user login as unix timestamp in milliseconds. It is saved at most once every 10 minutes + first_download: + type: integer + format: int64 + description: first download time as unix timestamp in milliseconds + first_upload: + type: integer + format: int64 + description: first upload time as unix timestamp in milliseconds + last_password_change: + type: integer + format: int64 + description: last password change time as unix timestamp in milliseconds + filters: + $ref: '#/components/schemas/UserFilters' + filesystem: + $ref: '#/components/schemas/FilesystemConfig' + additional_info: + type: string + description: Free form text field for external systems + groups: + type: array + items: + $ref: '#/components/schemas/GroupMapping' + oidc_custom_fields: + type: object + additionalProperties: true + description: 'This field is passed to the pre-login hook if custom OIDC token fields have been configured. Field values can be of any type (this is a free form object) and depend on the type of the configured OIDC token fields' + role: + type: string + AdminPreferences: + type: object + properties: + hide_user_page_sections: + type: integer + description: 'Allow to hide some sections from the user page. These are not security settings and are not enforced server side in any way. They are only intended to simplify the user page in the WebAdmin UI. 1 means hide groups section, 2 means hide filesystem section, "users_base_dir" must be set in the config file otherwise this setting is ignored, 4 means hide virtual folders section, 8 means hide profile section, 16 means hide ACLs section, 32 means hide disk and bandwidth quota limits section, 64 means hide advanced settings section. 
The settings can be combined' + default_users_expiration: + type: integer + description: 'Defines the default expiration for newly created users as number of days. 0 means no expiration' + AdminFilters: + type: object + properties: + allow_list: + type: array + items: + type: string + description: 'only clients connecting from these IP/Mask are allowed. IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291, for example "192.0.2.0/24" or "2001:db8::/32"' + example: + - 192.0.2.0/24 + - '2001:db8::/32' + allow_api_key_auth: + type: boolean + description: 'API key auth allows to impersonate this administrator with an API key' + require_two_factor: + type: boolean + require_password_change: + type: boolean + totp_config: + $ref: '#/components/schemas/AdminTOTPConfig' + recovery_codes: + type: array + items: + $ref: '#/components/schemas/RecoveryCode' + preferences: + $ref: '#/components/schemas/AdminPreferences' + Admin: + type: object + properties: + id: + type: integer + format: int32 + minimum: 1 + status: + type: integer + enum: + - 0 + - 1 + description: | + status: + * `0` user is disabled, login is not allowed + * `1` user is enabled + username: + type: string + description: username is unique + description: + type: string + description: 'optional description, for example the admin full name' + password: + type: string + format: password + description: Admin password. For security reasons this field is omitted when you search/get admins + email: + type: string + format: email + permissions: + type: array + items: + $ref: '#/components/schemas/AdminPermissions' + filters: + $ref: '#/components/schemas/AdminFilters' + additional_info: + type: string + description: Free form text field + groups: + type: array + items: + $ref: '#/components/schemas/AdminGroupMapping' + description: 'Groups automatically selected for new users created by this admin. The admin will still be able to choose different groups. 
These settings are only used for this admin UI and they will be ignored in REST API/hooks.' + created_at: + type: integer + format: int64 + description: 'creation time as unix timestamp in milliseconds. It will be 0 for admins created before v2.2.0' + updated_at: + type: integer + format: int64 + description: last update time as unix timestamp in milliseconds + last_login: + type: integer + format: int64 + description: Last user login as unix timestamp in milliseconds. It is saved at most once every 10 minutes + role: + type: string + description: 'If set the admin can only administer users with the same role. Role admins cannot have the "*" permission' + AdminProfile: + type: object + properties: + email: + type: string + format: email + description: + type: string + allow_api_key_auth: + type: boolean + description: 'If enabled, you can impersonate this admin, in REST API, using an API key. If disabled admin credentials are required for impersonation' + UserProfile: + type: object + properties: + email: + type: string + format: email + description: + type: string + allow_api_key_auth: + type: boolean + description: 'If enabled, you can impersonate this user, in REST API, using an API key. If disabled user credentials are required for impersonation' + public_keys: + type: array + items: + type: string + example: ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEUWwDwEWhTbF0MqAsp/oXK1HR2cElhM8oo1uVmL3ZeDKDiTm4ljMr92wfTgIGDqIoxmVqgYIkAOAhuykAVWBzc= user@host + description: Public keys in OpenSSH format + APIKey: + type: object + properties: + id: + type: string + description: unique key identifier + name: + type: string + description: User friendly key name + key: + type: string + format: password + description: We store the hash of the key. This is just like a password. 
For security reasons this field is omitted when you search/get API keys + scope: + $ref: '#/components/schemas/APIKeyScope' + created_at: + type: integer + format: int64 + description: creation time as unix timestamp in milliseconds + updated_at: + type: integer + format: int64 + description: last update time as unix timestamp in milliseconds + last_use_at: + type: integer + format: int64 + description: last use time as unix timestamp in milliseconds. It is saved at most once every 10 minutes + expires_at: + type: integer + format: int64 + description: expiration time as unix timestamp in milliseconds + description: + type: string + description: optional description + user: + type: string + description: username associated with this API key. If empty and the scope is "user scope" the key can impersonate any user + admin: + type: string + description: admin associated with this API key. If empty and the scope is "admin scope" the key can impersonate any admin + QuotaUsage: + type: object + properties: + used_quota_size: + type: integer + format: int64 + used_quota_files: + type: integer + format: int32 + TransferQuotaUsage: + type: object + properties: + used_upload_data_transfer: + type: integer + format: int64 + description: 'The value must be specified as bytes' + used_download_data_transfer: + type: integer + format: int64 + description: 'The value must be specified as bytes' + Transfer: + type: object + properties: + operation_type: + type: string + enum: + - upload + - download + description: | + Operations: + * `upload` + * `download` + path: + type: string + description: file path for the upload/download + start_time: + type: integer + format: int64 + description: start time as unix timestamp in milliseconds + size: + type: integer + format: int64 + description: bytes transferred + ConnectionStatus: + type: object + properties: + username: + type: string + description: connected username + connection_id: + type: string + description: unique connection 
identifier + client_version: + type: string + description: client version + remote_address: + type: string + description: Remote address for the connected client + connection_time: + type: integer + format: int64 + description: connection time as unix timestamp in milliseconds + command: + type: string + description: Last SSH/FTP command or WebDAV method + last_activity: + type: integer + format: int64 + description: last client activity as unix timestamp in milliseconds + protocol: + type: string + enum: + - SFTP + - SCP + - SSH + - FTP + - DAV + active_transfers: + type: array + items: + $ref: '#/components/schemas/Transfer' + node: + type: string + description: 'Node identifier, omitted for single node installations' + FolderRetention: + type: object + properties: + path: + type: string + description: 'virtual directory path as seen by users, if no other specific retention is defined, the retention applies for sub directories too. For example if retention is defined for the paths "/" and "/sub" then the retention for "/" is applied for any file outside the "/sub" directory' + example: '/' + retention: + type: integer + description: retention time in hours. All the files with a modification time older than the defined value will be deleted. 
0 means exclude this path + example: 24 + delete_empty_dirs: + type: boolean + description: if enabled, empty directories will be deleted + RetentionCheck: + type: object + properties: + username: + type: string + description: username to which the retention check refers + folders: + type: array + items: + $ref: '#/components/schemas/FolderRetention' + start_time: + type: integer + format: int64 + description: check start time as unix timestamp in milliseconds + QuotaScan: + type: object + properties: + username: + type: string + description: username to which the quota scan refers + start_time: + type: integer + format: int64 + description: scan start time as unix timestamp in milliseconds + FolderQuotaScan: + type: object + properties: + name: + type: string + description: folder name to which the quota scan refers + start_time: + type: integer + format: int64 + description: scan start time as unix timestamp in milliseconds + DefenderEntry: + type: object + properties: + id: + type: string + ip: + type: string + score: + type: integer + description: the score increases whenever a violation is detected, such as an attempt to log in using an incorrect password or invalid username. If the score exceeds the configured threshold, the IP is banned. Omitted for banned IPs + ban_time: + type: string + format: date-time + description: date time until the IP is banned. For already banned hosts, the ban time is increased each time a new violation is detected. 
Omitted if the IP is not banned + SSHHostKey: + type: object + properties: + path: + type: string + fingerprint: + type: string + algorithms: + type: array + items: + type: string + SSHBinding: + type: object + properties: + address: + type: string + description: TCP address the server listen on + port: + type: integer + description: the port used for serving requests + apply_proxy_config: + type: boolean + description: 'apply the proxy configuration, if any' + WebDAVBinding: + type: object + properties: + address: + type: string + description: TCP address the server listen on + port: + type: integer + description: the port used for serving requests + enable_https: + type: boolean + min_tls_version: + $ref: '#/components/schemas/TLSVersions' + client_auth_type: + type: integer + description: 1 means that client certificate authentication is required in addition to HTTP basic authentication + tls_cipher_suites: + type: array + items: + type: string + description: 'List of supported cipher suites for TLS version 1.2. 
If empty a default list of secure cipher suites is used, with a preference order based on hardware performance' + prefix: + type: string + description: 'Prefix for WebDAV resources, if empty WebDAV resources will be available at the `/` URI' + proxy_allowed: + type: array + items: + type: string + description: 'List of IP addresses and IP ranges allowed to set proxy headers' + PassiveIPOverride: + type: object + properties: + networks: + type: array + items: + type: string + ip: + type: string + FTPDBinding: + type: object + properties: + address: + type: string + description: TCP address the server listen on + port: + type: integer + description: the port used for serving requests + apply_proxy_config: + type: boolean + description: 'apply the proxy configuration, if any' + tls_mode: + type: integer + enum: + - 0 + - 1 + - 2 + description: | + TLS mode: + * `0` - clear or explicit TLS + * `1` - explicit TLS required + * `2` - implicit TLS + min_tls_version: + $ref: '#/components/schemas/TLSVersions' + force_passive_ip: + type: string + description: External IP address for passive connections + passive_ip_overrides: + type: array + items: + $ref: '#/components/schemas/PassiveIPOverride' + client_auth_type: + type: integer + description: 1 means that client certificate authentication is required in addition to FTP authentication + tls_cipher_suites: + type: array + items: + type: string + description: 'List of supported cipher suites for TLS version 1.2. 
If empty a default list of secure cipher suites is used, with a preference order based on hardware performance' + passive_connections_security: + type: integer + enum: + - 0 + - 1 + description: | + Passive connections security: + * `0` - require matching peer IP addresses of control and data connection + * `1` - disable any checks + active_connections_security: + type: integer + enum: + - 0 + - 1 + description: | + Active connections security: + * `0` - require matching peer IP addresses of control and data connection + * `1` - disable any checks + ignore_ascii_transfer_type: + type: integer + enum: + - 0 + - 1 + description: | + Ignore client requests to perform ASCII translations: + * `0` - ASCII translations are enabled + * `1` - ASCII translations are silently ignored + debug: + type: boolean + description: 'If enabled any FTP command will be logged' + SSHServiceStatus: + type: object + properties: + is_active: + type: boolean + bindings: + type: array + items: + $ref: '#/components/schemas/SSHBinding' + nullable: true + host_keys: + type: array + items: + $ref: '#/components/schemas/SSHHostKey' + nullable: true + ssh_commands: + type: array + items: + type: string + authentications: + type: array + items: + $ref: '#/components/schemas/SSHAuthentications' + public_key_algorithms: + type: array + items: + type: string + macs: + type: array + items: + type: string + kex_algorithms: + type: array + items: + type: string + ciphers: + type: array + items: + type: string + FTPPassivePortRange: + type: object + properties: + start: + type: integer + end: + type: integer + FTPServiceStatus: + type: object + properties: + is_active: + type: boolean + bindings: + type: array + items: + $ref: '#/components/schemas/FTPDBinding' + nullable: true + passive_port_range: + $ref: '#/components/schemas/FTPPassivePortRange' + WebDAVServiceStatus: + type: object + properties: + is_active: + type: boolean + bindings: + type: array + items: + $ref: 
'#/components/schemas/WebDAVBinding' + nullable: true + DataProviderStatus: + type: object + properties: + is_active: + type: boolean + driver: + type: string + error: + type: string + MFAStatus: + type: object + properties: + is_active: + type: boolean + totp_configs: + type: array + items: + $ref: '#/components/schemas/TOTPConfig' + ServicesStatus: + type: object + properties: + ssh: + $ref: '#/components/schemas/SSHServiceStatus' + ftp: + $ref: '#/components/schemas/FTPServiceStatus' + webdav: + $ref: '#/components/schemas/WebDAVServiceStatus' + data_provider: + $ref: '#/components/schemas/DataProviderStatus' + defender: + type: object + properties: + is_active: + type: boolean + mfa: + $ref: '#/components/schemas/MFAStatus' + allow_list: + type: object + properties: + is_active: + type: boolean + rate_limiters: + type: object + properties: + is_active: + type: boolean + protocols: + type: array + items: + type: string + example: SSH + Share: + type: object + properties: + id: + type: string + description: auto-generated unique share identifier + name: + type: string + description: + type: string + description: optional description + scope: + $ref: '#/components/schemas/ShareScope' + paths: + type: array + items: + type: string + description: 'paths to files or directories, for share scope write this array must contain exactly one directory. 
Paths will not be validated on save so you can also create them after creating the share' + example: + - '/dir1' + - '/dir2/file.txt' + - '/dir3/subdir' + username: + type: string + created_at: + type: integer + format: int64 + description: 'creation time as unix timestamp in milliseconds' + updated_at: + type: integer + format: int64 + description: 'last update time as unix timestamp in milliseconds' + last_use_at: + type: integer + format: int64 + description: last use time as unix timestamp in milliseconds + expires_at: + type: integer + format: int64 + description: 'optional share expiration, as unix timestamp in milliseconds. 0 means no expiration' + password: + type: string + description: 'optional password to protect the share. The special value "[**redacted**]" means that a password has been set, you can use this value if you want to preserve the current password when you update a share' + max_tokens: + type: integer + description: 'maximum allowed access tokens. 0 means no limit' + used_tokens: + type: integer + allow_from: + type: array + items: + type: string + description: 'Limit the share availability to these IP/Mask. IP/Mask must be in CIDR notation as defined in RFC 4632 and RFC 4291, for example "192.0.2.0/24" or "2001:db8::/32". An empty list means no restrictions' + example: + - 192.0.2.0/24 + - '2001:db8::/32' + GroupUserSettings: + type: object + properties: + home_dir: + type: string + max_sessions: + type: integer + format: int32 + quota_size: + type: integer + format: int64 + quota_files: + type: integer + format: int32 + permissions: + type: object + additionalProperties: + type: array + items: + $ref: '#/components/schemas/Permission' + minItems: 1 + minProperties: 1 + description: 'hash map with directory as key and an array of permissions as value. 
Directories must be absolute paths, permissions for root directory ("/") are required' + example: + /: + - '*' + /somedir: + - list + - download + upload_bandwidth: + type: integer + description: 'Maximum upload bandwidth as KB/s' + download_bandwidth: + type: integer + description: 'Maximum download bandwidth as KB/s' + upload_data_transfer: + type: integer + description: 'Maximum data transfer allowed for uploads as MB' + download_data_transfer: + type: integer + description: 'Maximum data transfer allowed for downloads as MB' + total_data_transfer: + type: integer + description: 'Maximum total data transfer as MB' + expires_in: + type: integer + description: 'Account expiration in number of days from creation. 0 means no expiration' + filters: + $ref: '#/components/schemas/BaseUserFilters' + filesystem: + $ref: '#/components/schemas/FilesystemConfig' + Role: + type: object + properties: + id: + type: integer + format: int32 + minimum: 1 + name: + type: string + description: name is unique + description: + type: string + description: 'optional description' + created_at: + type: integer + format: int64 + description: creation time as unix timestamp in milliseconds + updated_at: + type: integer + format: int64 + description: last update time as unix timestamp in milliseconds + users: + type: array + items: + type: string + description: list of usernames associated with this group + admins: + type: array + items: + type: string + description: list of admins usernames associated with this group + Group: + type: object + properties: + id: + type: integer + format: int32 + minimum: 1 + name: + type: string + description: name is unique + description: + type: string + description: 'optional description' + created_at: + type: integer + format: int64 + description: creation time as unix timestamp in milliseconds + updated_at: + type: integer + format: int64 + description: last update time as unix timestamp in milliseconds + user_settings: + $ref: 
'#/components/schemas/GroupUserSettings' + virtual_folders: + type: array + items: + $ref: '#/components/schemas/VirtualFolder' + description: mapping between virtual SFTPGo paths and folders + users: + type: array + items: + type: string + description: list of usernames associated with this group + admins: + type: array + items: + type: string + description: list of admins usernames associated with this group + GroupMapping: + type: object + properties: + name: + type: string + description: group name + type: + enum: + - 1 + - 2 + - 3 + description: | + Group type: + * `1` - Primary group + * `2` - Secondary group + * `3` - Membership only, no settings are inherited from this group type + AdminGroupMappingOptions: + type: object + properties: + add_to_users_as: + enum: + - 0 + - 1 + - 2 + description: | + Add to new users as: + * `0` - the admin's group will be added as membership group for new users + * `1` - the admin's group will be added as primary group for new users + * `2` - the admin's group will be added as secondary group for new users + AdminGroupMapping: + type: object + properties: + name: + type: string + description: group name + options: + $ref: '#/components/schemas/AdminGroupMappingOptions' + BackupData: + type: object + properties: + users: + type: array + items: + $ref: '#/components/schemas/User' + folders: + type: array + items: + $ref: '#/components/schemas/BaseVirtualFolder' + groups: + type: array + items: + $ref: '#/components/schemas/Group' + admins: + type: array + items: + $ref: '#/components/schemas/Admin' + api_keys: + type: array + items: + $ref: '#/components/schemas/APIKey' + shares: + type: array + items: + $ref: '#/components/schemas/Share' + event_actions: + type: array + items: + $ref: '#/components/schemas/EventAction' + event_rules: + type: array + items: + $ref: '#/components/schemas/EventRule' + roles: + type: array + items: + $ref: '#/components/schemas/Role' + version: + type: integer + PwdChange: + type: object + 
properties: + current_password: + type: string + new_password: + type: string + DirEntry: + type: object + properties: + name: + type: string + description: name of the file (or subdirectory) described by the entry. This name is the final element of the path (the base name), not the entire path + size: + type: integer + format: int64 + description: file size, omitted for folders and non regular files + mode: + type: integer + description: | + File mode and permission bits. More details here: https://golang.org/pkg/io/fs/#FileMode. + Let's see some examples: + - for a directory mode&2147483648 != 0 + - for a symlink mode&134217728 != 0 + - for a regular file mode&2401763328 == 0 + last_modified: + type: string + format: date-time + FsEvent: + type: object + properties: + id: + type: string + timestamp: + type: integer + format: int64 + description: 'unix timestamp in nanoseconds' + action: + $ref: '#/components/schemas/FsEventAction' + username: + type: string + fs_path: + type: string + fs_target_path: + type: string + virtual_path: + type: string + virtual_target_path: + type: string + ssh_cmd: + type: string + file_size: + type: integer + format: int64 + elapsed: + type: integer + format: int64 + description: elapsed time as milliseconds + status: + $ref: '#/components/schemas/FsEventStatus' + protocol: + $ref: '#/components/schemas/EventProtocols' + ip: + type: string + session_id: + type: string + fs_provider: + $ref: '#/components/schemas/FsProviders' + bucket: + type: string + endpoint: + type: string + open_flags: + type: string + role: + type: string + instance_id: + type: string + ProviderEvent: + type: object + properties: + id: + type: string + timestamp: + type: integer + format: int64 + description: 'unix timestamp in nanoseconds' + action: + $ref: '#/components/schemas/ProviderEventAction' + username: + type: string + ip: + type: string + object_type: + $ref: '#/components/schemas/ProviderEventObjectType' + object_name: + type: string + object_data: + 
type: string + format: byte + description: 'base64 of the JSON serialized object with sensitive fields removed' + role: + type: string + instance_id: + type: string + LogEvent: + type: object + properties: + id: + type: string + timestamp: + type: integer + format: int64 + description: 'unix timestamp in nanoseconds' + event: + $ref: '#/components/schemas/LogEventType' + protocol: + $ref: '#/components/schemas/EventProtocols' + username: + type: string + ip: + type: string + message: + type: string + role: + type: string + instance_id: + type: string + KeyValue: + type: object + properties: + key: + type: string + value: + type: string + RenameConfig: + allOf: + - $ref: '#/components/schemas/KeyValue' + - type: object + properties: + update_modtime: + type: boolean + description: 'Update modification time. This setting is not recursive and only applies to storage providers that support changing modification times' + HTTPPart: + type: object + properties: + name: + type: string + headers: + type: array + items: + $ref: '#/components/schemas/KeyValue' + description: 'Additional headers. Content-Disposition header is automatically set. Content-Type header is automatically detect for files to attach' + filepath: + type: string + description: 'path to the file to be sent as an attachment' + body: + type: string + EventActionHTTPConfig: + type: object + properties: + endpoint: + type: string + description: HTTP endpoint + example: https://example.com + username: + type: string + password: + $ref: '#/components/schemas/Secret' + headers: + type: array + items: + $ref: '#/components/schemas/KeyValue' + description: headers to add + timeout: + type: integer + minimum: 1 + maximum: 180 + description: 'Ignored for multipart requests with files as attachments' + skip_tls_verify: + type: boolean + description: 'if enabled the HTTP client accepts any TLS certificate presented by the server and any host name in that certificate. 
In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.' + method: + type: string + enum: + - GET + - POST + - PUT + - DELETE + query_parameters: + type: array + items: + $ref: '#/components/schemas/KeyValue' + body: + type: string + description: HTTP POST/PUT body + parts: + type: array + items: + $ref: '#/components/schemas/HTTPPart' + description: 'Multipart requests allow to combine one or more sets of data into a single body. For each part, you can set a file path or a body as text. Placeholders are supported in file path, body, header values.' + EventActionCommandConfig: + type: object + properties: + cmd: + type: string + description: absolute path to the command to execute + args: + type: array + items: + type: string + description: 'command line arguments' + timeout: + type: integer + minimum: 1 + maximum: 120 + env_vars: + type: array + items: + $ref: '#/components/schemas/KeyValue' + EventActionEmailConfig: + type: object + properties: + recipients: + type: array + items: + type: string + bcc: + type: array + items: + type: string + subject: + type: string + body: + type: string + content_type: + type: integer + enum: + - 0 + - 1 + description: | + Content type: + * `0` text/plain + * `1` text/html + attachments: + type: array + items: + type: string + description: 'list of file paths to attach. The total size is limited to 10 MB' + EventActionDataRetentionConfig: + type: object + properties: + folders: + type: array + items: + $ref: '#/components/schemas/FolderRetention' + EventActionFsCompress: + type: object + properties: + name: + type: string + description: 'Full path to the (zip) archive to create. 
The parent dir must exist' + paths: + type: array + items: + type: string + description: 'paths to add the archive' + EventActionFilesystemConfig: + type: object + properties: + type: + $ref: '#/components/schemas/FilesystemActionTypes' + renames: + type: array + items: + $ref: '#/components/schemas/RenameConfig' + mkdirs: + type: array + items: + type: string + deletes: + type: array + items: + type: string + exist: + type: array + items: + type: string + copy: + type: array + items: + $ref: '#/components/schemas/KeyValue' + compress: + $ref: '#/components/schemas/EventActionFsCompress' + EventActionPasswordExpiration: + type: object + properties: + threshold: + type: integer + description: 'An email notification will be generated for users whose password expires in a number of days less than or equal to this threshold' + EventActionUserInactivity: + type: object + properties: + disable_threshold: + type: integer + description: 'Inactivity threshold, in days, before disabling the account' + delete_threshold: + type: integer + description: 'Inactivity threshold, in days, before deleting the account' + EventActionIDPAccountCheck: + type: object + properties: + mode: + type: integer + enum: + - 0 + - 1 + description: | + Account check mode: + * `0` Create or update the account + * `1` Create the account if it doesn't exist + template_user: + type: string + description: 'SFTPGo user template in JSON format' + template_admin: + type: string + description: 'SFTPGo admin template in JSON format' + BaseEventActionOptions: + type: object + properties: + http_config: + $ref: '#/components/schemas/EventActionHTTPConfig' + cmd_config: + $ref: '#/components/schemas/EventActionCommandConfig' + email_config: + $ref: '#/components/schemas/EventActionEmailConfig' + retention_config: + $ref: '#/components/schemas/EventActionDataRetentionConfig' + fs_config: + $ref: '#/components/schemas/EventActionFilesystemConfig' + pwd_expiration_config: + $ref: 
'#/components/schemas/EventActionPasswordExpiration' + user_inactivity_config: + $ref: '#/components/schemas/EventActionUserInactivity' + idp_config: + $ref: '#/components/schemas/EventActionIDPAccountCheck' + BaseEventAction: + type: object + properties: + id: + type: integer + format: int32 + minimum: 1 + name: + type: string + description: unique name + description: + type: string + description: optional description + type: + $ref: '#/components/schemas/EventActionTypes' + options: + $ref: '#/components/schemas/BaseEventActionOptions' + rules: + type: array + items: + type: string + description: list of event rules names associated with this action + EventActionOptions: + type: object + properties: + is_failure_action: + type: boolean + stop_on_failure: + type: boolean + execute_sync: + type: boolean + EventAction: + allOf: + - $ref: '#/components/schemas/BaseEventAction' + - type: object + properties: + order: + type: integer + description: execution order + relation_options: + $ref: '#/components/schemas/EventActionOptions' + EventActionMinimal: + type: object + properties: + name: + type: string + order: + type: integer + description: execution order + relation_options: + $ref: '#/components/schemas/EventActionOptions' + ConditionPattern: + type: object + properties: + pattern: + type: string + inverse_match: + type: boolean + ConditionOptions: + type: object + properties: + names: + type: array + items: + $ref: '#/components/schemas/ConditionPattern' + group_names: + type: array + items: + $ref: '#/components/schemas/ConditionPattern' + role_names: + type: array + items: + $ref: '#/components/schemas/ConditionPattern' + fs_paths: + type: array + items: + $ref: '#/components/schemas/ConditionPattern' + protocols: + type: array + items: + type: string + enum: + - SFTP + - SCP + - SSH + - FTP + - DAV + - HTTP + - HTTPShare + - OIDC + provider_objects: + type: array + items: + type: string + enum: + - user + - group + - admin + - api_key + - share + - 
event_action + - event_rule + min_size: + type: integer + format: int64 + max_size: + type: integer + format: int64 + event_statuses: + type: array + items: + type: integer + enum: + - 1 + - 2 + - 3 + description: | + Event status: + - `1` OK + - `2` Failed + - `3` Quota exceeded + concurrent_execution: + type: boolean + description: allow concurrent execution from multiple nodes + Schedule: + type: object + properties: + hour: + type: string + day_of_week: + type: string + day_of_month: + type: string + month: + type: string + EventConditions: + type: object + properties: + fs_events: + type: array + items: + type: string + enum: + - upload + - download + - delete + - rename + - mkdir + - rmdir + - copy + - ssh_cmd + - pre-upload + - pre-download + - pre-delete + - first-upload + - first-download + provider_events: + type: array + items: + type: string + enum: + - add + - update + - delete + schedules: + type: array + items: + $ref: '#/components/schemas/Schedule' + idp_login_event: + type: integer + enum: + - 0 + - 1 + - 2 + description: | + IDP login events: + - `0` any login event + - `1` user login event + - `2` admin login event + options: + $ref: '#/components/schemas/ConditionOptions' + BaseEventRule: + type: object + properties: + id: + type: integer + format: int32 + minimum: 1 + name: + type: string + description: unique name + status: + type: integer + enum: + - 0 + - 1 + description: | + status: + * `0` disabled + * `1` enabled + description: + type: string + description: optional description + created_at: + type: integer + format: int64 + description: creation time as unix timestamp in milliseconds + updated_at: + type: integer + format: int64 + description: last update time as unix timestamp in millisecond + trigger: + $ref: '#/components/schemas/EventTriggerTypes' + conditions: + $ref: '#/components/schemas/EventConditions' + EventRule: + allOf: + - $ref: '#/components/schemas/BaseEventRule' + - type: object + properties: + actions: + type: array + 
items: + $ref: '#/components/schemas/EventAction' + EventRuleMinimal: + allOf: + - $ref: '#/components/schemas/BaseEventRule' + - type: object + properties: + actions: + type: array + items: + $ref: '#/components/schemas/EventActionMinimal' + IPListEntry: + type: object + properties: + ipornet: + type: string + description: IP address or network in CIDR format, for example `192.168.1.2/32`, `192.168.0.0/24`, `2001:db8::/32` + description: + type: string + description: optional description + type: + $ref: '#/components/schemas/IPListType' + mode: + $ref: '#/components/schemas/IPListMode' + protocols: + type: integer + description: Defines the protocol the entry applies to. `0` means all the supported protocols, 1 SSH, 2 FTP, 4 WebDAV, 8 HTTP. Protocols can be combined, for example 3 means SSH and FTP + created_at: + type: integer + format: int64 + description: creation time as unix timestamp in milliseconds + updated_at: + type: integer + format: int64 + description: last update time as unix timestamp in millisecond + ApiResponse: + type: object + properties: + message: + type: string + description: 'message, can be empty' + error: + type: string + description: error description if any + VersionInfo: + type: object + properties: + version: + type: string + build_date: + type: string + commit_hash: + type: string + features: + type: array + items: + type: string + description: 'Features for the current build. Available features are `portable`, `bolt`, `mysql`, `sqlite`, `pgsql`, `s3`, `gcs`, `azblob`, `metrics`, `unixcrypt`. If a feature is available it has a `+` prefix, otherwise a `-` prefix' + Token: + type: object + properties: + access_token: + type: string + expires_at: + type: string + format: date-time + securitySchemes: + BasicAuth: + type: http + scheme: basic + BearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + APIKeyAuth: + type: apiKey + in: header + name: X-SFTPGO-API-KEY + description: 'API key to use for authentication. 
API key authentication is intrinsically less secure than using a short lived JWT token. You should prefer API key authentication only for machine-to-machine communications in trusted environments. If no admin/user is associated to the provided key you need to add ".username" at the end of the key. For example if your API key is "6ajKLwswLccVBGpZGv596G.ySAXc8vtp9hMiwAuaLtzof" and you want to impersonate the admin with username "myadmin" you have to use "6ajKLwswLccVBGpZGv596G.ySAXc8vtp9hMiwAuaLtzof.myadmin" as API key. When using API key authentication you cannot manage API keys, update the impersonated admin, change password or public keys for the impersonated user.' diff --git a/openapi/swagger-ui/favicon-16x16.png b/openapi/swagger-ui/favicon-16x16.png new file mode 100644 index 00000000..8b194e61 Binary files /dev/null and b/openapi/swagger-ui/favicon-16x16.png differ diff --git a/openapi/swagger-ui/favicon-32x32.png b/openapi/swagger-ui/favicon-32x32.png new file mode 100644 index 00000000..249737fe Binary files /dev/null and b/openapi/swagger-ui/favicon-32x32.png differ diff --git a/openapi/swagger-ui/index.css b/openapi/swagger-ui/index.css new file mode 100644 index 00000000..f2376fda --- /dev/null +++ b/openapi/swagger-ui/index.css @@ -0,0 +1,16 @@ +html { + box-sizing: border-box; + overflow: -moz-scrollbars-vertical; + overflow-y: scroll; +} + +*, +*:before, +*:after { + box-sizing: inherit; +} + +body { + margin: 0; + background: #fafafa; +} diff --git a/openapi/swagger-ui/index.html b/openapi/swagger-ui/index.html new file mode 100644 index 00000000..84ae62d3 --- /dev/null +++ b/openapi/swagger-ui/index.html @@ -0,0 +1,19 @@ + + + + + + Swagger UI + + + + + + + +
+ + + + + diff --git a/openapi/swagger-ui/swagger-initializer.js b/openapi/swagger-ui/swagger-initializer.js new file mode 100644 index 00000000..8ba59486 --- /dev/null +++ b/openapi/swagger-ui/swagger-initializer.js @@ -0,0 +1,20 @@ +window.onload = function() { + // + + // the following lines will be replaced by docker/configurator, when it runs in a docker-container + window.ui = SwaggerUIBundle({ + url: "../openapi.yaml", + dom_id: '#swagger-ui', + deepLinking: true, + presets: [ + SwaggerUIBundle.presets.apis, + SwaggerUIStandalonePreset + ], + plugins: [ + SwaggerUIBundle.plugins.DownloadUrl + ], + layout: "StandaloneLayout" + }); + + // +}; diff --git a/openapi/swagger-ui/swagger-ui-bundle.js b/openapi/swagger-ui/swagger-ui-bundle.js new file mode 100644 index 00000000..64a04935 --- /dev/null +++ b/openapi/swagger-ui/swagger-ui-bundle.js @@ -0,0 +1,2 @@ +/*! For license information please see swagger-ui-bundle.js.LICENSE.txt */ +!function webpackUniversalModuleDefinition(s,o){"object"==typeof exports&&"object"==typeof module?module.exports=o():"function"==typeof define&&define.amd?define([],o):"object"==typeof exports?exports.SwaggerUIBundle=o():s.SwaggerUIBundle=o()}(this,(()=>(()=>{var s={251:(s,o)=>{o.read=function(s,o,i,a,u){var _,w,x=8*u-a-1,C=(1<>1,L=-7,B=i?u-1:0,$=i?-1:1,U=s[o+B];for(B+=$,_=U&(1<<-L)-1,U>>=-L,L+=x;L>0;_=256*_+s[o+B],B+=$,L-=8);for(w=_&(1<<-L)-1,_>>=-L,L+=a;L>0;w=256*w+s[o+B],B+=$,L-=8);if(0===_)_=1-j;else{if(_===C)return w?NaN:1/0*(U?-1:1);w+=Math.pow(2,a),_-=j}return(U?-1:1)*w*Math.pow(2,_-a)},o.write=function(s,o,i,a,u,_){var 
w,x,C,j=8*_-u-1,L=(1<>1,$=23===u?Math.pow(2,-24)-Math.pow(2,-77):0,U=a?0:_-1,V=a?1:-1,z=o<0||0===o&&1/o<0?1:0;for(o=Math.abs(o),isNaN(o)||o===1/0?(x=isNaN(o)?1:0,w=L):(w=Math.floor(Math.log(o)/Math.LN2),o*(C=Math.pow(2,-w))<1&&(w--,C*=2),(o+=w+B>=1?$/C:$*Math.pow(2,1-B))*C>=2&&(w++,C/=2),w+B>=L?(x=0,w=L):w+B>=1?(x=(o*C-1)*Math.pow(2,u),w+=B):(x=o*Math.pow(2,B-1)*Math.pow(2,u),w=0));u>=8;s[i+U]=255&x,U+=V,x/=256,u-=8);for(w=w<0;s[i+U]=255&w,U+=V,w/=256,j-=8);s[i+U-V]|=128*z}},462:(s,o,i)=>{"use strict";var a=i(40975);s.exports=a},659:(s,o,i)=>{var a=i(51873),u=Object.prototype,_=u.hasOwnProperty,w=u.toString,x=a?a.toStringTag:void 0;s.exports=function getRawTag(s){var o=_.call(s,x),i=s[x];try{s[x]=void 0;var a=!0}catch(s){}var u=w.call(s);return a&&(o?s[x]=i:delete s[x]),u}},694:(s,o,i)=>{"use strict";i(91599);var a=i(37257);i(12560),s.exports=a},953:(s,o,i)=>{"use strict";s.exports=i(53375)},1733:s=>{var o=/[^\x00-\x2f\x3a-\x40\x5b-\x60\x7b-\x7f]+/g;s.exports=function asciiWords(s){return s.match(o)||[]}},1882:(s,o,i)=>{var a=i(72552),u=i(23805);s.exports=function isFunction(s){if(!u(s))return!1;var o=a(s);return"[object Function]"==o||"[object GeneratorFunction]"==o||"[object AsyncFunction]"==o||"[object Proxy]"==o}},1907:(s,o,i)=>{"use strict";var a=i(41505),u=Function.prototype,_=u.call,w=a&&u.bind.bind(_,_);s.exports=a?w:function(s){return function(){return _.apply(s,arguments)}}},2205:function(s,o,i){var a;a=void 0!==i.g?i.g:this,s.exports=function(s){if(s.CSS&&s.CSS.escape)return s.CSS.escape;var cssEscape=function(s){if(0==arguments.length)throw new TypeError("`CSS.escape` requires an argument.");for(var o,i=String(s),a=i.length,u=-1,_="",w=i.charCodeAt(0);++u=1&&o<=31||127==o||0==u&&o>=48&&o<=57||1==u&&o>=48&&o<=57&&45==w?"\\"+o.toString(16)+" ":0==u&&1==a&&45==o||!(o>=128||45==o||95==o||o>=48&&o<=57||o>=65&&o<=90||o>=97&&o<=122)?"\\"+i.charAt(u):i.charAt(u):_+="�";return _};return s.CSS||(s.CSS={}),s.CSS.escape=cssEscape,cssEscape}(a)},2209:(s,o,i)=>{"use 
strict";var a,u=i(9404),_=function productionTypeChecker(){invariant(!1,"ImmutablePropTypes type checking code is stripped in production.")};_.isRequired=_;var w=function getProductionTypeChecker(){return _};function getPropType(s){var o=typeof s;return Array.isArray(s)?"array":s instanceof RegExp?"object":s instanceof u.Iterable?"Immutable."+s.toSource().split(" ")[0]:o}function createChainableTypeChecker(s){function checkType(o,i,a,u,_,w){for(var x=arguments.length,C=Array(x>6?x-6:0),j=6;j>",null!=i[a]?s.apply(void 0,[i,a,u,_,w].concat(C)):o?new Error("Required "+_+" `"+w+"` was not specified in `"+u+"`."):void 0}var o=checkType.bind(null,!1);return o.isRequired=checkType.bind(null,!0),o}function createIterableSubclassTypeChecker(s,o){return function createImmutableTypeChecker(s,o){return createChainableTypeChecker((function validate(i,a,u,_,w){var x=i[a];if(!o(x)){var C=getPropType(x);return new Error("Invalid "+_+" `"+w+"` of type `"+C+"` supplied to `"+u+"`, expected `"+s+"`.")}return null}))}("Iterable."+s,(function(s){return u.Iterable.isIterable(s)&&o(s)}))}(a={listOf:w,mapOf:w,orderedMapOf:w,setOf:w,orderedSetOf:w,stackOf:w,iterableOf:w,recordOf:w,shape:w,contains:w,mapContains:w,orderedMapContains:w,list:_,map:_,orderedMap:_,set:_,orderedSet:_,stack:_,seq:_,record:_,iterable:_}).iterable.indexed=createIterableSubclassTypeChecker("Indexed",u.Iterable.isIndexed),a.iterable.keyed=createIterableSubclassTypeChecker("Keyed",u.Iterable.isKeyed),s.exports=a},2404:(s,o,i)=>{var a=i(60270);s.exports=function isEqual(s,o){return a(s,o)}},2523:s=>{s.exports=function baseFindIndex(s,o,i,a){for(var u=s.length,_=i+(a?1:-1);a?_--:++_{"use strict";var a=i(45951),u=Object.defineProperty;s.exports=function(s,o){try{u(a,s,{value:o,configurable:!0,writable:!0})}catch(i){a[s]=o}return o}},2694:(s,o,i)=>{"use strict";var a=i(6925);function emptyFunction(){}function emptyFunctionWithReset(){}emptyFunctionWithReset.resetWarningCache=emptyFunction,s.exports=function(){function 
shim(s,o,i,u,_,w){if(w!==a){var x=new Error("Calling PropTypes validators directly is not supported by the `prop-types` package. Use PropTypes.checkPropTypes() to call them. Read more at http://fb.me/use-check-prop-types");throw x.name="Invariant Violation",x}}function getShim(){return shim}shim.isRequired=shim;var s={array:shim,bigint:shim,bool:shim,func:shim,number:shim,object:shim,string:shim,symbol:shim,any:shim,arrayOf:getShim,element:shim,elementType:shim,instanceOf:getShim,node:shim,objectOf:getShim,oneOf:getShim,oneOfType:getShim,shape:getShim,exact:getShim,checkPropTypes:emptyFunctionWithReset,resetWarningCache:emptyFunction};return s.PropTypes=s,s}},2874:s=>{s.exports={}},2875:(s,o,i)=>{"use strict";var a=i(23045),u=i(80376);s.exports=Object.keys||function keys(s){return a(s,u)}},2955:(s,o,i)=>{"use strict";var a,u=i(65606);function _defineProperty(s,o,i){return(o=function _toPropertyKey(s){var o=function _toPrimitive(s,o){if("object"!=typeof s||null===s)return s;var i=s[Symbol.toPrimitive];if(void 0!==i){var a=i.call(s,o||"default");if("object"!=typeof a)return a;throw new TypeError("@@toPrimitive must return a primitive value.")}return("string"===o?String:Number)(s)}(s,"string");return"symbol"==typeof o?o:String(o)}(o))in s?Object.defineProperty(s,o,{value:i,enumerable:!0,configurable:!0,writable:!0}):s[o]=i,s}var _=i(86238),w=Symbol("lastResolve"),x=Symbol("lastReject"),C=Symbol("error"),j=Symbol("ended"),L=Symbol("lastPromise"),B=Symbol("handlePromise"),$=Symbol("stream");function createIterResult(s,o){return{value:s,done:o}}function readAndResolve(s){var o=s[w];if(null!==o){var i=s[$].read();null!==i&&(s[L]=null,s[w]=null,s[x]=null,o(createIterResult(i,!1)))}}function onReadable(s){u.nextTick(readAndResolve,s)}var U=Object.getPrototypeOf((function(){})),V=Object.setPrototypeOf((_defineProperty(a={get stream(){return this[$]},next:function next(){var s=this,o=this[C];if(null!==o)return Promise.reject(o);if(this[j])return 
Promise.resolve(createIterResult(void 0,!0));if(this[$].destroyed)return new Promise((function(o,i){u.nextTick((function(){s[C]?i(s[C]):o(createIterResult(void 0,!0))}))}));var i,a=this[L];if(a)i=new Promise(function wrapForNext(s,o){return function(i,a){s.then((function(){o[j]?i(createIterResult(void 0,!0)):o[B](i,a)}),a)}}(a,this));else{var _=this[$].read();if(null!==_)return Promise.resolve(createIterResult(_,!1));i=new Promise(this[B])}return this[L]=i,i}},Symbol.asyncIterator,(function(){return this})),_defineProperty(a,"return",(function _return(){var s=this;return new Promise((function(o,i){s[$].destroy(null,(function(s){s?i(s):o(createIterResult(void 0,!0))}))}))})),a),U);s.exports=function createReadableStreamAsyncIterator(s){var o,i=Object.create(V,(_defineProperty(o={},$,{value:s,writable:!0}),_defineProperty(o,w,{value:null,writable:!0}),_defineProperty(o,x,{value:null,writable:!0}),_defineProperty(o,C,{value:null,writable:!0}),_defineProperty(o,j,{value:s._readableState.endEmitted,writable:!0}),_defineProperty(o,B,{value:function value(s,o){var a=i[$].read();a?(i[L]=null,i[w]=null,i[x]=null,s(createIterResult(a,!1))):(i[w]=s,i[x]=o)},writable:!0}),o));return i[L]=null,_(s,(function(s){if(s&&"ERR_STREAM_PREMATURE_CLOSE"!==s.code){var o=i[x];return null!==o&&(i[L]=null,i[w]=null,i[x]=null,o(s)),void(i[C]=s)}var a=i[w];null!==a&&(i[L]=null,i[w]=null,i[x]=null,a(createIterResult(void 0,!0))),i[j]=!0})),s.on("readable",onReadable.bind(null,i)),i}},3110:(s,o,i)=>{const a=i(5187),u=i(85015),_=i(98023),w=i(53812),x=i(23805),C=i(85105),j=i(86804);class Namespace{constructor(s){this.elementMap={},this.elementDetection=[],this.Element=j.Element,this.KeyValuePair=j.KeyValuePair,s&&s.noDefault||this.useDefault(),this._attributeElementKeys=[],this._attributeElementArrayKeys=[]}use(s){return s.namespace&&s.namespace({base:this}),s.load&&s.load({base:this}),this}useDefault(){return 
this.register("null",j.NullElement).register("string",j.StringElement).register("number",j.NumberElement).register("boolean",j.BooleanElement).register("array",j.ArrayElement).register("object",j.ObjectElement).register("member",j.MemberElement).register("ref",j.RefElement).register("link",j.LinkElement),this.detect(a,j.NullElement,!1).detect(u,j.StringElement,!1).detect(_,j.NumberElement,!1).detect(w,j.BooleanElement,!1).detect(Array.isArray,j.ArrayElement,!1).detect(x,j.ObjectElement,!1),this}register(s,o){return this._elements=void 0,this.elementMap[s]=o,this}unregister(s){return this._elements=void 0,delete this.elementMap[s],this}detect(s,o,i){return void 0===i||i?this.elementDetection.unshift([s,o]):this.elementDetection.push([s,o]),this}toElement(s){if(s instanceof this.Element)return s;let o;for(let i=0;i{const o=s[0].toUpperCase()+s.substr(1);this._elements[o]=this.elementMap[s]}))),this._elements}get serialiser(){return new C(this)}}C.prototype.Namespace=Namespace,s.exports=Namespace},3121:(s,o,i)=>{"use strict";var a=i(65482),u=Math.min;s.exports=function(s){var o=a(s);return o>0?u(o,9007199254740991):0}},3209:(s,o,i)=>{var a=i(91596),u=i(53320),_=i(36306),w="__lodash_placeholder__",x=128,C=Math.min;s.exports=function mergeData(s,o){var i=s[1],j=o[1],L=i|j,B=L<131,$=j==x&&8==i||j==x&&256==i&&s[7].length<=o[8]||384==j&&o[7].length<=o[8]&&8==i;if(!B&&!$)return s;1&j&&(s[2]=o[2],L|=1&i?0:4);var U=o[3];if(U){var V=s[3];s[3]=V?a(V,U,o[4]):U,s[4]=V?_(s[3],w):o[4]}return(U=o[5])&&(V=s[5],s[5]=V?u(V,U,o[6]):U,s[6]=V?_(s[5],w):o[6]),(U=o[7])&&(s[7]=U),j&x&&(s[8]=null==s[8]?o[8]:C(s[8],o[8])),null==s[9]&&(s[9]=o[9]),s[0]=o[0],s[1]=L,s}},3650:(s,o,i)=>{var a=i(74335)(Object.keys,Object);s.exports=a},3656:(s,o,i)=>{s=i.nmd(s);var a=i(9325),u=i(89935),_=o&&!o.nodeType&&o,w=_&&s&&!s.nodeType&&s,x=w&&w.exports===_?a.Buffer:void 0,C=(x?x.isBuffer:void 0)||u;s.exports=C},4509:(s,o,i)=>{var a=i(12651);s.exports=function mapCacheHas(s){return 
a(this,s).has(s)}},4640:s=>{"use strict";var o=String;s.exports=function(s){try{return o(s)}catch(s){return"Object"}}},4664:(s,o,i)=>{var a=i(79770),u=i(63345),_=Object.prototype.propertyIsEnumerable,w=Object.getOwnPropertySymbols,x=w?function(s){return null==s?[]:(s=Object(s),a(w(s),(function(o){return _.call(s,o)})))}:u;s.exports=x},4901:(s,o,i)=>{var a=i(72552),u=i(30294),_=i(40346),w={};w["[object Float32Array]"]=w["[object Float64Array]"]=w["[object Int8Array]"]=w["[object Int16Array]"]=w["[object Int32Array]"]=w["[object Uint8Array]"]=w["[object Uint8ClampedArray]"]=w["[object Uint16Array]"]=w["[object Uint32Array]"]=!0,w["[object Arguments]"]=w["[object Array]"]=w["[object ArrayBuffer]"]=w["[object Boolean]"]=w["[object DataView]"]=w["[object Date]"]=w["[object Error]"]=w["[object Function]"]=w["[object Map]"]=w["[object Number]"]=w["[object Object]"]=w["[object RegExp]"]=w["[object Set]"]=w["[object String]"]=w["[object WeakMap]"]=!1,s.exports=function baseIsTypedArray(s){return _(s)&&u(s.length)&&!!w[a(s)]}},4993:(s,o,i)=>{"use strict";var a=i(16946),u=i(74239);s.exports=function(s){return a(u(s))}},5187:s=>{s.exports=function isNull(s){return null===s}},5419:s=>{s.exports=function(s,o,i,a){var u=new Blob(void 0!==a?[a,s]:[s],{type:i||"application/octet-stream"});if(void 0!==window.navigator.msSaveBlob)window.navigator.msSaveBlob(u,o);else{var _=window.URL&&window.URL.createObjectURL?window.URL.createObjectURL(u):window.webkitURL.createObjectURL(u),w=document.createElement("a");w.style.display="none",w.href=_,w.setAttribute("download",o),void 0===w.download&&w.setAttribute("target","_blank"),document.body.appendChild(w),w.click(),setTimeout((function(){document.body.removeChild(w),window.URL.revokeObjectURL(_)}),200)}}},5556:(s,o,i)=>{s.exports=i(2694)()},5861:(s,o,i)=>{var a=i(55580),u=i(68223),_=i(32804),w=i(76545),x=i(28303),C=i(72552),j=i(47473),L="[object Map]",B="[object Promise]",$="[object Set]",U="[object WeakMap]",V="[object 
DataView]",z=j(a),Y=j(u),Z=j(_),ee=j(w),ie=j(x),ae=C;(a&&ae(new a(new ArrayBuffer(1)))!=V||u&&ae(new u)!=L||_&&ae(_.resolve())!=B||w&&ae(new w)!=$||x&&ae(new x)!=U)&&(ae=function(s){var o=C(s),i="[object Object]"==o?s.constructor:void 0,a=i?j(i):"";if(a)switch(a){case z:return V;case Y:return L;case Z:return B;case ee:return $;case ie:return U}return o}),s.exports=ae},6048:s=>{s.exports=function negate(s){if("function"!=typeof s)throw new TypeError("Expected a function");return function(){var o=arguments;switch(o.length){case 0:return!s.call(this);case 1:return!s.call(this,o[0]);case 2:return!s.call(this,o[0],o[1]);case 3:return!s.call(this,o[0],o[1],o[2])}return!s.apply(this,o)}}},6188:s=>{"use strict";s.exports=Math.max},6205:s=>{s.exports={ROOT:0,GROUP:1,POSITION:2,SET:3,RANGE:4,REPETITION:5,REFERENCE:6,CHAR:7}},6233:(s,o,i)=>{const a=i(6048),u=i(10316),_=i(92340);class ArrayElement extends u{constructor(s,o,i){super(s||[],o,i),this.element="array"}primitive(){return"array"}get(s){return this.content[s]}getValue(s){const o=this.get(s);if(o)return o.toValue()}getIndex(s){return this.content[s]}set(s,o){return this.content[s]=this.refract(o),this}remove(s){const o=this.content.splice(s,1);return o.length?o[0]:null}map(s,o){return this.content.map(s,o)}flatMap(s,o){return this.map(s,o).reduce(((s,o)=>s.concat(o)),[])}compactMap(s,o){const i=[];return this.forEach((a=>{const u=s.bind(o)(a);u&&i.push(u)})),i}filter(s,o){return new _(this.content.filter(s,o))}reject(s,o){return this.filter(a(s),o)}reduce(s,o){let i,a;void 0!==o?(i=0,a=this.refract(o)):(i=1,a="object"===this.primitive()?this.first.value:this.first);for(let o=i;o{s.bind(o)(i,this.refract(a))}))}shift(){return this.content.shift()}unshift(s){this.content.unshift(this.refract(s))}push(s){return this.content.push(this.refract(s)),this}add(s){this.push(s)}findElements(s,o){const i=o||{},a=!!i.recursive,u=void 0===i.results?[]:i.results;return this.forEach(((o,i,_)=>{a&&void 
0!==o.findElements&&o.findElements(s,{results:u,recursive:a}),s(o,i,_)&&u.push(o)})),u}find(s){return new _(this.findElements(s,{recursive:!0}))}findByElement(s){return this.find((o=>o.element===s))}findByClass(s){return this.find((o=>o.classes.includes(s)))}getById(s){return this.find((o=>o.id.toValue()===s)).first}includes(s){return this.content.some((o=>o.equals(s)))}contains(s){return this.includes(s)}empty(){return new this.constructor([])}"fantasy-land/empty"(){return this.empty()}concat(s){return new this.constructor(this.content.concat(s.content))}"fantasy-land/concat"(s){return this.concat(s)}"fantasy-land/map"(s){return new this.constructor(this.map(s))}"fantasy-land/chain"(s){return this.map((o=>s(o)),this).reduce(((s,o)=>s.concat(o)),this.empty())}"fantasy-land/filter"(s){return new this.constructor(this.content.filter(s))}"fantasy-land/reduce"(s,o){return this.content.reduce(s,o)}get length(){return this.content.length}get isEmpty(){return 0===this.content.length}get first(){return this.getIndex(0)}get second(){return this.getIndex(1)}get last(){return this.getIndex(this.length-1)}}ArrayElement.empty=function empty(){return new this},ArrayElement["fantasy-land/empty"]=ArrayElement.empty,"undefined"!=typeof Symbol&&(ArrayElement.prototype[Symbol.iterator]=function symbol(){return this.content[Symbol.iterator]()}),s.exports=ArrayElement},6499:(s,o,i)=>{"use strict";var a=i(1907),u=0,_=Math.random(),w=a(1..toString);s.exports=function(s){return"Symbol("+(void 0===s?"":s)+")_"+w(++u+_,36)}},6549:s=>{"use strict";s.exports=Object.getOwnPropertyDescriptor},6925:s=>{"use strict";s.exports="SECRET_DO_NOT_PASS_THIS_OR_YOU_WILL_BE_FIRED"},7057:(s,o,i)=>{"use strict";var a=i(11470).charAt,u=i(90160),_=i(64932),w=i(60183),x=i(59550),C="String Iterator",j=_.set,L=_.getterFor(C);w(String,"String",(function(s){j(this,{type:C,string:u(s),index:0})}),(function next(){var s,o=L(this),i=o.string,u=o.index;return u>=i.length?x(void 
0,!0):(s=a(i,u),o.index+=s.length,x(s,!1))}))},7176:(s,o,i)=>{"use strict";var a,u=i(73126),_=i(75795);try{a=[].__proto__===Array.prototype}catch(s){if(!s||"object"!=typeof s||!("code"in s)||"ERR_PROTO_ACCESS"!==s.code)throw s}var w=!!a&&_&&_(Object.prototype,"__proto__"),x=Object,C=x.getPrototypeOf;s.exports=w&&"function"==typeof w.get?u([w.get]):"function"==typeof C&&function getDunder(s){return C(null==s?s:x(s))}},7309:(s,o,i)=>{var a=i(62006)(i(24713));s.exports=a},7376:s=>{"use strict";s.exports=!0},7463:(s,o,i)=>{"use strict";var a=i(98828),u=i(62250),_=/#|\.prototype\./,isForced=function(s,o){var i=x[w(s)];return i===j||i!==C&&(u(o)?a(o):!!o)},w=isForced.normalize=function(s){return String(s).replace(_,".").toLowerCase()},x=isForced.data={},C=isForced.NATIVE="N",j=isForced.POLYFILL="P";s.exports=isForced},7666:(s,o,i)=>{var a=i(84851),u=i(953);function _extends(){var o;return s.exports=_extends=a?u(o=a).call(o):function(s){for(var o=1;o{const a=i(6205);o.wordBoundary=()=>({type:a.POSITION,value:"b"}),o.nonWordBoundary=()=>({type:a.POSITION,value:"B"}),o.begin=()=>({type:a.POSITION,value:"^"}),o.end=()=>({type:a.POSITION,value:"$"})},8068:s=>{"use strict";var o=(()=>{var s=Object.defineProperty,o=Object.getOwnPropertyDescriptor,i=Object.getOwnPropertyNames,a=Object.getOwnPropertySymbols,u=Object.prototype.hasOwnProperty,_=Object.prototype.propertyIsEnumerable,__defNormalProp=(o,i,a)=>i in o?s(o,i,{enumerable:!0,configurable:!0,writable:!0,value:a}):o[i]=a,__spreadValues=(s,o)=>{for(var i in o||(o={}))u.call(o,i)&&__defNormalProp(s,i,o[i]);if(a)for(var i of a(o))_.call(o,i)&&__defNormalProp(s,i,o[i]);return s},__publicField=(s,o,i)=>__defNormalProp(s,"symbol"!=typeof o?o+"":o,i),w={};((o,i)=>{for(var a in i)s(o,a,{get:i[a],enumerable:!0})})(w,{DEFAULT_OPTIONS:()=>C,DEFAULT_UUID_LENGTH:()=>x,default:()=>B});var x=6,C={dictionary:"alphanum",shuffle:!0,debug:!1,length:x,counter:0},j=class 
_ShortUniqueId{constructor(s={}){__publicField(this,"counter"),__publicField(this,"debug"),__publicField(this,"dict"),__publicField(this,"version"),__publicField(this,"dictIndex",0),__publicField(this,"dictRange",[]),__publicField(this,"lowerBound",0),__publicField(this,"upperBound",0),__publicField(this,"dictLength",0),__publicField(this,"uuidLength"),__publicField(this,"_digit_first_ascii",48),__publicField(this,"_digit_last_ascii",58),__publicField(this,"_alpha_lower_first_ascii",97),__publicField(this,"_alpha_lower_last_ascii",123),__publicField(this,"_hex_last_ascii",103),__publicField(this,"_alpha_upper_first_ascii",65),__publicField(this,"_alpha_upper_last_ascii",91),__publicField(this,"_number_dict_ranges",{digits:[this._digit_first_ascii,this._digit_last_ascii]}),__publicField(this,"_alpha_dict_ranges",{lowerCase:[this._alpha_lower_first_ascii,this._alpha_lower_last_ascii],upperCase:[this._alpha_upper_first_ascii,this._alpha_upper_last_ascii]}),__publicField(this,"_alpha_lower_dict_ranges",{lowerCase:[this._alpha_lower_first_ascii,this._alpha_lower_last_ascii]}),__publicField(this,"_alpha_upper_dict_ranges",{upperCase:[this._alpha_upper_first_ascii,this._alpha_upper_last_ascii]}),__publicField(this,"_alphanum_dict_ranges",{digits:[this._digit_first_ascii,this._digit_last_ascii],lowerCase:[this._alpha_lower_first_ascii,this._alpha_lower_last_ascii],upperCase:[this._alpha_upper_first_ascii,this._alpha_upper_last_ascii]}),__publicField(this,"_alphanum_lower_dict_ranges",{digits:[this._digit_first_ascii,this._digit_last_ascii],lowerCase:[this._alpha_lower_first_ascii,this._alpha_lower_last_ascii]}),__publicField(this,"_alphanum_upper_dict_ranges",{digits:[this._digit_first_ascii,this._digit_last_ascii],upperCase:[this._alpha_upper_first_ascii,this._alpha_upper_last_ascii]}),__publicField(this,"_hex_dict_ranges",{decDigits:[this._digit_first_ascii,this._digit_last_ascii],alphaDigits:[this._alpha_lower_first_ascii,this._hex_last_ascii]}),__publicField(this,"_dict
_ranges",{_number_dict_ranges:this._number_dict_ranges,_alpha_dict_ranges:this._alpha_dict_ranges,_alpha_lower_dict_ranges:this._alpha_lower_dict_ranges,_alpha_upper_dict_ranges:this._alpha_upper_dict_ranges,_alphanum_dict_ranges:this._alphanum_dict_ranges,_alphanum_lower_dict_ranges:this._alphanum_lower_dict_ranges,_alphanum_upper_dict_ranges:this._alphanum_upper_dict_ranges,_hex_dict_ranges:this._hex_dict_ranges}),__publicField(this,"log",((...s)=>{const o=[...s];o[0]="[short-unique-id] ".concat(s[0]),!0!==this.debug||"undefined"==typeof console||null===console||console.log(...o)})),__publicField(this,"_normalizeDictionary",((s,o)=>{let i;if(s&&Array.isArray(s)&&s.length>1)i=s;else{i=[],this.dictIndex=0;const o="_".concat(s,"_dict_ranges"),a=this._dict_ranges[o];let u=0;for(const[,s]of Object.entries(a)){const[o,i]=s;u+=Math.abs(i-o)}i=new Array(u);let _=0;for(const[,s]of Object.entries(a)){this.dictRange=s,this.lowerBound=this.dictRange[0],this.upperBound=this.dictRange[1];const o=this.lowerBound<=this.upperBound,a=this.lowerBound,u=this.upperBound;if(o)for(let s=a;su;s--)i[_++]=String.fromCharCode(s),this.dictIndex=s}i.length=_}if(o){for(let s=i.length-1;s>0;s--){const o=Math.floor(Math.random()*(s+1));[i[s],i[o]]=[i[o],i[s]]}}return i})),__publicField(this,"setDictionary",((s,o)=>{this.dict=this._normalizeDictionary(s,o),this.dictLength=this.dict.length,this.setCounter(0)})),__publicField(this,"seq",(()=>this.sequentialUUID())),__publicField(this,"sequentialUUID",(()=>{const s=this.dictLength,o=this.dict;let i=this.counter;const a=[];do{const u=i%s;i=Math.trunc(i/s),a.push(o[u])}while(0!==i);const u=a.join("");return this.counter+=1,u})),__publicField(this,"rnd",((s=this.uuidLength||x)=>this.randomUUID(s))),__publicField(this,"randomUUID",((s=this.uuidLength||x)=>{if(null==s||s<1)throw new Error("Invalid UUID Length Provided");const o=new Array(s),i=this.dictLength,a=this.dict;for(let 
u=0;uthis.formattedUUID(s,o))),__publicField(this,"formattedUUID",((s,o)=>{const i={$r:this.randomUUID,$s:this.sequentialUUID,$t:this.stamp};return s.replace(/\$[rs]\d{0,}|\$t0|\$t[1-9]\d{1,}/g,(s=>{const a=s.slice(0,2),u=Number.parseInt(s.slice(2),10);return"$s"===a?i[a]().padStart(u,"0"):"$t"===a&&o?i[a](u,o):i[a](u)}))})),__publicField(this,"availableUUIDs",((s=this.uuidLength)=>Number.parseFloat(([...new Set(this.dict)].length**s).toFixed(0)))),__publicField(this,"_collisionCache",new Map),__publicField(this,"approxMaxBeforeCollision",((s=this.availableUUIDs(this.uuidLength))=>{const o=s,i=this._collisionCache.get(o);if(void 0!==i)return i;const a=Number.parseFloat(Math.sqrt(Math.PI/2*s).toFixed(20));return this._collisionCache.set(o,a),a})),__publicField(this,"collisionProbability",((s=this.availableUUIDs(this.uuidLength),o=this.uuidLength)=>Number.parseFloat((this.approxMaxBeforeCollision(s)/this.availableUUIDs(o)).toFixed(20)))),__publicField(this,"uniqueness",((s=this.availableUUIDs(this.uuidLength))=>{const o=Number.parseFloat((1-this.approxMaxBeforeCollision(s)/s).toFixed(20));return o>1?1:o<0?0:o})),__publicField(this,"getVersion",(()=>this.version)),__publicField(this,"stamp",((s,o)=>{const i=Math.floor(+(o||new Date)/1e3).toString(16);if("number"==typeof s&&0===s)return i;if("number"!=typeof s||s<10)throw new Error(["Param finalLength must be a number greater than or equal to 10,","or 0 if you want the raw hexadecimal timestamp"].join("\n"));const a=s-9,u=Math.round(Math.random()*(a>15?15:a)),_=this.randomUUID(a);return"".concat(_.substring(0,u)).concat(i).concat(_.substring(u)).concat(u.toString(16))})),__publicField(this,"parseStamp",((s,o)=>{if(o&&!/t0|t[1-9]\d{1,}/.test(o))throw new Error("Cannot extract date from a formated UUID with no timestamp in the format");const i=o?o.replace(/\$[rs]\d{0,}|\$t0|\$t[1-9]\d{1,}/g,(s=>{const 
o={$r:s=>[...Array(s)].map((()=>"r")).join(""),$s:s=>[...Array(s)].map((()=>"s")).join(""),$t:s=>[...Array(s)].map((()=>"t")).join("")},i=s.slice(0,2),a=Number.parseInt(s.slice(2),10);return o[i](a)})).replace(/^(.*?)(t{8,})(.*)$/g,((o,i,a)=>s.substring(i.length,i.length+a.length))):s;if(8===i.length)return new Date(1e3*Number.parseInt(i,16));if(i.length<10)throw new Error("Stamp length invalid");const a=Number.parseInt(i.substring(i.length-1),16);return new Date(1e3*Number.parseInt(i.substring(a,a+8),16))})),__publicField(this,"setCounter",(s=>{this.counter=s})),__publicField(this,"validate",((s,o)=>{const i=o?this._normalizeDictionary(o):this.dict;return s.split("").every((s=>i.includes(s)))}));const o=__spreadValues(__spreadValues({},C),s);this.counter=0,this.debug=!1,this.dict=[],this.version="5.3.2";const{dictionary:i,shuffle:a,length:u,counter:_}=o;this.uuidLength=u,this.setDictionary(i,a),this.setCounter(_),this.debug=o.debug,this.log(this.dict),this.log("Generator instantiated with Dictionary Size ".concat(this.dictLength," and counter set to ").concat(this.counter)),this.log=this.log.bind(this),this.setDictionary=this.setDictionary.bind(this),this.setCounter=this.setCounter.bind(this),this.seq=this.seq.bind(this),this.sequentialUUID=this.sequentialUUID.bind(this),this.rnd=this.rnd.bind(this),this.randomUUID=this.randomUUID.bind(this),this.fmt=this.fmt.bind(this),this.formattedUUID=this.formattedUUID.bind(this),this.availableUUIDs=this.availableUUIDs.bind(this),this.approxMaxBeforeCollision=this.approxMaxBeforeCollision.bind(this),this.collisionProbability=this.collisionProbability.bind(this),this.uniqueness=this.uniqueness.bind(this),this.getVersion=this.getVersion.bind(this),this.stamp=this.stamp.bind(this),this.parseStamp=this.parseStamp.bind(this)}};__publicField(j,"default",j);var L,B=j;return L=w,((a,_,w,x)=>{if(_&&"object"==typeof _||"function"==typeof _)for(let C of 
i(_))u.call(a,C)||C===w||s(a,C,{get:()=>_[C],enumerable:!(x=o(_,C))||x.enumerable});return a})(s({},"__esModule",{value:!0}),L)})();s.exports=o.default,"undefined"!=typeof window&&(o=o.default)},9325:(s,o,i)=>{var a=i(34840),u="object"==typeof self&&self&&self.Object===Object&&self,_=a||u||Function("return this")();s.exports=_},9404:function(s){s.exports=function(){"use strict";var s=Array.prototype.slice;function createClass(s,o){o&&(s.prototype=Object.create(o.prototype)),s.prototype.constructor=s}function Iterable(s){return isIterable(s)?s:Seq(s)}function KeyedIterable(s){return isKeyed(s)?s:KeyedSeq(s)}function IndexedIterable(s){return isIndexed(s)?s:IndexedSeq(s)}function SetIterable(s){return isIterable(s)&&!isAssociative(s)?s:SetSeq(s)}function isIterable(s){return!(!s||!s[o])}function isKeyed(s){return!(!s||!s[i])}function isIndexed(s){return!(!s||!s[a])}function isAssociative(s){return isKeyed(s)||isIndexed(s)}function isOrdered(s){return!(!s||!s[u])}createClass(KeyedIterable,Iterable),createClass(IndexedIterable,Iterable),createClass(SetIterable,Iterable),Iterable.isIterable=isIterable,Iterable.isKeyed=isKeyed,Iterable.isIndexed=isIndexed,Iterable.isAssociative=isAssociative,Iterable.isOrdered=isOrdered,Iterable.Keyed=KeyedIterable,Iterable.Indexed=IndexedIterable,Iterable.Set=SetIterable;var o="@@__IMMUTABLE_ITERABLE__@@",i="@@__IMMUTABLE_KEYED__@@",a="@@__IMMUTABLE_INDEXED__@@",u="@@__IMMUTABLE_ORDERED__@@",_="delete",w=5,x=1<>>0;if(""+i!==o||4294967295===i)return NaN;o=i}return o<0?ensureSize(s)+o:o}function returnTrue(){return!0}function wholeSlice(s,o,i){return(0===s||void 0!==i&&s<=-i)&&(void 0===o||void 0!==i&&o>=i)}function resolveBegin(s,o){return resolveIndex(s,o,0)}function resolveEnd(s,o){return resolveIndex(s,o,o)}function resolveIndex(s,o,i){return void 0===s?i:s<0?Math.max(0,o+s):void 0===o?s:Math.min(o,s)}var $=0,U=1,V=2,z="function"==typeof Symbol&&Symbol.iterator,Y="@@iterator",Z=z||Y;function Iterator(s){this.next=s}function 
iteratorValue(s,o,i,a){var u=0===s?o:1===s?i:[o,i];return a?a.value=u:a={value:u,done:!1},a}function iteratorDone(){return{value:void 0,done:!0}}function hasIterator(s){return!!getIteratorFn(s)}function isIterator(s){return s&&"function"==typeof s.next}function getIterator(s){var o=getIteratorFn(s);return o&&o.call(s)}function getIteratorFn(s){var o=s&&(z&&s[z]||s[Y]);if("function"==typeof o)return o}function isArrayLike(s){return s&&"number"==typeof s.length}function Seq(s){return null==s?emptySequence():isIterable(s)?s.toSeq():seqFromValue(s)}function KeyedSeq(s){return null==s?emptySequence().toKeyedSeq():isIterable(s)?isKeyed(s)?s.toSeq():s.fromEntrySeq():keyedSeqFromValue(s)}function IndexedSeq(s){return null==s?emptySequence():isIterable(s)?isKeyed(s)?s.entrySeq():s.toIndexedSeq():indexedSeqFromValue(s)}function SetSeq(s){return(null==s?emptySequence():isIterable(s)?isKeyed(s)?s.entrySeq():s:indexedSeqFromValue(s)).toSetSeq()}Iterator.prototype.toString=function(){return"[Iterator]"},Iterator.KEYS=$,Iterator.VALUES=U,Iterator.ENTRIES=V,Iterator.prototype.inspect=Iterator.prototype.toSource=function(){return this.toString()},Iterator.prototype[Z]=function(){return this},createClass(Seq,Iterable),Seq.of=function(){return Seq(arguments)},Seq.prototype.toSeq=function(){return this},Seq.prototype.toString=function(){return this.__toString("Seq {","}")},Seq.prototype.cacheResult=function(){return!this._cache&&this.__iterateUncached&&(this._cache=this.entrySeq().toArray(),this.size=this._cache.length),this},Seq.prototype.__iterate=function(s,o){return seqIterate(this,s,o,!0)},Seq.prototype.__iterator=function(s,o){return seqIterator(this,s,o,!0)},createClass(KeyedSeq,Seq),KeyedSeq.prototype.toKeyedSeq=function(){return this},createClass(IndexedSeq,Seq),IndexedSeq.of=function(){return IndexedSeq(arguments)},IndexedSeq.prototype.toIndexedSeq=function(){return this},IndexedSeq.prototype.toString=function(){return this.__toString("Seq 
[","]")},IndexedSeq.prototype.__iterate=function(s,o){return seqIterate(this,s,o,!1)},IndexedSeq.prototype.__iterator=function(s,o){return seqIterator(this,s,o,!1)},createClass(SetSeq,Seq),SetSeq.of=function(){return SetSeq(arguments)},SetSeq.prototype.toSetSeq=function(){return this},Seq.isSeq=isSeq,Seq.Keyed=KeyedSeq,Seq.Set=SetSeq,Seq.Indexed=IndexedSeq;var ee,ie,ae,ce="@@__IMMUTABLE_SEQ__@@";function ArraySeq(s){this._array=s,this.size=s.length}function ObjectSeq(s){var o=Object.keys(s);this._object=s,this._keys=o,this.size=o.length}function IterableSeq(s){this._iterable=s,this.size=s.length||s.size}function IteratorSeq(s){this._iterator=s,this._iteratorCache=[]}function isSeq(s){return!(!s||!s[ce])}function emptySequence(){return ee||(ee=new ArraySeq([]))}function keyedSeqFromValue(s){var o=Array.isArray(s)?new ArraySeq(s).fromEntrySeq():isIterator(s)?new IteratorSeq(s).fromEntrySeq():hasIterator(s)?new IterableSeq(s).fromEntrySeq():"object"==typeof s?new ObjectSeq(s):void 0;if(!o)throw new TypeError("Expected Array or iterable object of [k, v] entries, or keyed object: "+s);return o}function indexedSeqFromValue(s){var o=maybeIndexedSeqFromValue(s);if(!o)throw new TypeError("Expected Array or iterable object of values: "+s);return o}function seqFromValue(s){var o=maybeIndexedSeqFromValue(s)||"object"==typeof s&&new ObjectSeq(s);if(!o)throw new TypeError("Expected Array or iterable object of values, or keyed object: "+s);return o}function maybeIndexedSeqFromValue(s){return isArrayLike(s)?new ArraySeq(s):isIterator(s)?new IteratorSeq(s):hasIterator(s)?new IterableSeq(s):void 0}function seqIterate(s,o,i,a){var u=s._cache;if(u){for(var _=u.length-1,w=0;w<=_;w++){var x=u[i?_-w:w];if(!1===o(x[1],a?x[0]:w,s))return w+1}return w}return s.__iterateUncached(o,i)}function seqIterator(s,o,i,a){var u=s._cache;if(u){var _=u.length-1,w=0;return new Iterator((function(){var s=u[i?_-w:w];return w++>_?iteratorDone():iteratorValue(o,a?s[0]:w-1,s[1])}))}return 
s.__iteratorUncached(o,i)}function fromJS(s,o){return o?fromJSWith(o,s,"",{"":s}):fromJSDefault(s)}function fromJSWith(s,o,i,a){return Array.isArray(o)?s.call(a,i,IndexedSeq(o).map((function(i,a){return fromJSWith(s,i,a,o)}))):isPlainObj(o)?s.call(a,i,KeyedSeq(o).map((function(i,a){return fromJSWith(s,i,a,o)}))):o}function fromJSDefault(s){return Array.isArray(s)?IndexedSeq(s).map(fromJSDefault).toList():isPlainObj(s)?KeyedSeq(s).map(fromJSDefault).toMap():s}function isPlainObj(s){return s&&(s.constructor===Object||void 0===s.constructor)}function is(s,o){if(s===o||s!=s&&o!=o)return!0;if(!s||!o)return!1;if("function"==typeof s.valueOf&&"function"==typeof o.valueOf){if((s=s.valueOf())===(o=o.valueOf())||s!=s&&o!=o)return!0;if(!s||!o)return!1}return!("function"!=typeof s.equals||"function"!=typeof o.equals||!s.equals(o))}function deepEqual(s,o){if(s===o)return!0;if(!isIterable(o)||void 0!==s.size&&void 0!==o.size&&s.size!==o.size||void 0!==s.__hash&&void 0!==o.__hash&&s.__hash!==o.__hash||isKeyed(s)!==isKeyed(o)||isIndexed(s)!==isIndexed(o)||isOrdered(s)!==isOrdered(o))return!1;if(0===s.size&&0===o.size)return!0;var i=!isAssociative(s);if(isOrdered(s)){var a=s.entries();return o.every((function(s,o){var u=a.next().value;return u&&is(u[1],s)&&(i||is(u[0],o))}))&&a.next().done}var u=!1;if(void 0===s.size)if(void 0===o.size)"function"==typeof s.cacheResult&&s.cacheResult();else{u=!0;var _=s;s=o,o=_}var w=!0,x=o.__iterate((function(o,a){if(i?!s.has(o):u?!is(o,s.get(a,j)):!is(s.get(a,j),o))return w=!1,!1}));return w&&s.size===x}function Repeat(s,o){if(!(this instanceof Repeat))return new Repeat(s,o);if(this._value=s,this.size=void 0===o?1/0:Math.max(0,o),0===this.size){if(ie)return ie;ie=this}}function invariant(s,o){if(!s)throw new Error(o)}function Range(s,o,i){if(!(this instanceof Range))return new Range(s,o,i);if(invariant(0!==i,"Cannot step a Range by 0"),s=s||0,void 0===o&&(o=1/0),i=void 
0===i?1:Math.abs(i),oa?iteratorDone():iteratorValue(s,u,i[o?a-u++:u++])}))},createClass(ObjectSeq,KeyedSeq),ObjectSeq.prototype.get=function(s,o){return void 0===o||this.has(s)?this._object[s]:o},ObjectSeq.prototype.has=function(s){return this._object.hasOwnProperty(s)},ObjectSeq.prototype.__iterate=function(s,o){for(var i=this._object,a=this._keys,u=a.length-1,_=0;_<=u;_++){var w=a[o?u-_:_];if(!1===s(i[w],w,this))return _+1}return _},ObjectSeq.prototype.__iterator=function(s,o){var i=this._object,a=this._keys,u=a.length-1,_=0;return new Iterator((function(){var w=a[o?u-_:_];return _++>u?iteratorDone():iteratorValue(s,w,i[w])}))},ObjectSeq.prototype[u]=!0,createClass(IterableSeq,IndexedSeq),IterableSeq.prototype.__iterateUncached=function(s,o){if(o)return this.cacheResult().__iterate(s,o);var i=getIterator(this._iterable),a=0;if(isIterator(i))for(var u;!(u=i.next()).done&&!1!==s(u.value,a++,this););return a},IterableSeq.prototype.__iteratorUncached=function(s,o){if(o)return this.cacheResult().__iterator(s,o);var i=getIterator(this._iterable);if(!isIterator(i))return new Iterator(iteratorDone);var a=0;return new Iterator((function(){var o=i.next();return o.done?o:iteratorValue(s,a++,o.value)}))},createClass(IteratorSeq,IndexedSeq),IteratorSeq.prototype.__iterateUncached=function(s,o){if(o)return this.cacheResult().__iterate(s,o);for(var i,a=this._iterator,u=this._iteratorCache,_=0;_=a.length){var o=i.next();if(o.done)return o;a[u]=o.value}return iteratorValue(s,u,a[u++])}))},createClass(Repeat,IndexedSeq),Repeat.prototype.toString=function(){return 0===this.size?"Repeat []":"Repeat [ "+this._value+" "+this.size+" times ]"},Repeat.prototype.get=function(s,o){return this.has(s)?this._value:o},Repeat.prototype.includes=function(s){return is(this._value,s)},Repeat.prototype.slice=function(s,o){var i=this.size;return wholeSlice(s,o,i)?this:new Repeat(this._value,resolveEnd(o,i)-resolveBegin(s,i))},Repeat.prototype.reverse=function(){return 
this},Repeat.prototype.indexOf=function(s){return is(this._value,s)?0:-1},Repeat.prototype.lastIndexOf=function(s){return is(this._value,s)?this.size:-1},Repeat.prototype.__iterate=function(s,o){for(var i=0;i=0&&o=0&&ii?iteratorDone():iteratorValue(s,_++,w)}))},Range.prototype.equals=function(s){return s instanceof Range?this._start===s._start&&this._end===s._end&&this._step===s._step:deepEqual(this,s)},createClass(Collection,Iterable),createClass(KeyedCollection,Collection),createClass(IndexedCollection,Collection),createClass(SetCollection,Collection),Collection.Keyed=KeyedCollection,Collection.Indexed=IndexedCollection,Collection.Set=SetCollection;var le="function"==typeof Math.imul&&-2===Math.imul(4294967295,2)?Math.imul:function imul(s,o){var i=65535&(s|=0),a=65535&(o|=0);return i*a+((s>>>16)*a+i*(o>>>16)<<16>>>0)|0};function smi(s){return s>>>1&1073741824|3221225471&s}function hash(s){if(!1===s||null==s)return 0;if("function"==typeof s.valueOf&&(!1===(s=s.valueOf())||null==s))return 0;if(!0===s)return 1;var o=typeof s;if("number"===o){if(s!=s||s===1/0)return 0;var i=0|s;for(i!==s&&(i^=4294967295*s);s>4294967295;)i^=s/=4294967295;return smi(i)}if("string"===o)return s.length>Se?cachedHashString(s):hashString(s);if("function"==typeof s.hashCode)return s.hashCode();if("object"===o)return hashJSObj(s);if("function"==typeof s.toString)return hashString(s.toString());throw new Error("Value type "+o+" cannot be hashed.")}function cachedHashString(s){var o=Pe[s];return void 0===o&&(o=hashString(s),xe===we&&(xe=0,Pe={}),xe++,Pe[s]=o),o}function hashString(s){for(var o=0,i=0;i0)switch(s.nodeType){case 1:return s.uniqueID;case 9:return s.documentElement&&s.documentElement.uniqueID}}var fe,ye="function"==typeof WeakMap;ye&&(fe=new WeakMap);var be=0,_e="__immutablehash__";"function"==typeof Symbol&&(_e=Symbol(_e));var Se=16,we=255,xe=0,Pe={};function assertNotInfinite(s){invariant(s!==1/0,"Cannot perform this action with an infinite size.")}function Map(s){return 
null==s?emptyMap():isMap(s)&&!isOrdered(s)?s:emptyMap().withMutations((function(o){var i=KeyedIterable(s);assertNotInfinite(i.size),i.forEach((function(s,i){return o.set(i,s)}))}))}function isMap(s){return!(!s||!s[Re])}createClass(Map,KeyedCollection),Map.of=function(){var o=s.call(arguments,0);return emptyMap().withMutations((function(s){for(var i=0;i=o.length)throw new Error("Missing value for key: "+o[i]);s.set(o[i],o[i+1])}}))},Map.prototype.toString=function(){return this.__toString("Map {","}")},Map.prototype.get=function(s,o){return this._root?this._root.get(0,void 0,s,o):o},Map.prototype.set=function(s,o){return updateMap(this,s,o)},Map.prototype.setIn=function(s,o){return this.updateIn(s,j,(function(){return o}))},Map.prototype.remove=function(s){return updateMap(this,s,j)},Map.prototype.deleteIn=function(s){return this.updateIn(s,(function(){return j}))},Map.prototype.update=function(s,o,i){return 1===arguments.length?s(this):this.updateIn([s],o,i)},Map.prototype.updateIn=function(s,o,i){i||(i=o,o=void 0);var a=updateInDeepMap(this,forceIterator(s),o,i);return a===j?void 0:a},Map.prototype.clear=function(){return 0===this.size?this:this.__ownerID?(this.size=0,this._root=null,this.__hash=void 0,this.__altered=!0,this):emptyMap()},Map.prototype.merge=function(){return mergeIntoMapWith(this,void 0,arguments)},Map.prototype.mergeWith=function(o){return mergeIntoMapWith(this,o,s.call(arguments,1))},Map.prototype.mergeIn=function(o){var i=s.call(arguments,1);return this.updateIn(o,emptyMap(),(function(s){return"function"==typeof s.merge?s.merge.apply(s,i):i[i.length-1]}))},Map.prototype.mergeDeep=function(){return mergeIntoMapWith(this,deepMerger,arguments)},Map.prototype.mergeDeepWith=function(o){var i=s.call(arguments,1);return mergeIntoMapWith(this,deepMergerWith(o),i)},Map.prototype.mergeDeepIn=function(o){var i=s.call(arguments,1);return this.updateIn(o,emptyMap(),(function(s){return"function"==typeof 
s.mergeDeep?s.mergeDeep.apply(s,i):i[i.length-1]}))},Map.prototype.sort=function(s){return OrderedMap(sortFactory(this,s))},Map.prototype.sortBy=function(s,o){return OrderedMap(sortFactory(this,o,s))},Map.prototype.withMutations=function(s){var o=this.asMutable();return s(o),o.wasAltered()?o.__ensureOwner(this.__ownerID):this},Map.prototype.asMutable=function(){return this.__ownerID?this:this.__ensureOwner(new OwnerID)},Map.prototype.asImmutable=function(){return this.__ensureOwner()},Map.prototype.wasAltered=function(){return this.__altered},Map.prototype.__iterator=function(s,o){return new MapIterator(this,s,o)},Map.prototype.__iterate=function(s,o){var i=this,a=0;return this._root&&this._root.iterate((function(o){return a++,s(o[1],o[0],i)}),o),a},Map.prototype.__ensureOwner=function(s){return s===this.__ownerID?this:s?makeMap(this.size,this._root,s,this.__hash):(this.__ownerID=s,this.__altered=!1,this)},Map.isMap=isMap;var Te,Re="@@__IMMUTABLE_MAP__@@",$e=Map.prototype;function ArrayMapNode(s,o){this.ownerID=s,this.entries=o}function BitmapIndexedNode(s,o,i){this.ownerID=s,this.bitmap=o,this.nodes=i}function HashArrayMapNode(s,o,i){this.ownerID=s,this.count=o,this.nodes=i}function HashCollisionNode(s,o,i){this.ownerID=s,this.keyHash=o,this.entries=i}function ValueNode(s,o,i){this.ownerID=s,this.keyHash=o,this.entry=i}function MapIterator(s,o,i){this._type=o,this._reverse=i,this._stack=s._root&&mapIteratorFrame(s._root)}function mapIteratorValue(s,o){return iteratorValue(s,o[0],o[1])}function mapIteratorFrame(s,o){return{node:s,index:0,__prev:o}}function makeMap(s,o,i,a){var u=Object.create($e);return u.size=s,u._root=o,u.__ownerID=i,u.__hash=a,u.__altered=!1,u}function emptyMap(){return Te||(Te=makeMap(0))}function updateMap(s,o,i){var a,u;if(s._root){var _=MakeRef(L),w=MakeRef(B);if(a=updateNode(s._root,s.__ownerID,0,void 0,o,i,_,w),!w.value)return s;u=s.size+(_.value?i===j?-1:1:0)}else{if(i===j)return s;u=1,a=new ArrayMapNode(s.__ownerID,[[o,i]])}return 
s.__ownerID?(s.size=u,s._root=a,s.__hash=void 0,s.__altered=!0,s):a?makeMap(u,a):emptyMap()}function updateNode(s,o,i,a,u,_,w,x){return s?s.update(o,i,a,u,_,w,x):_===j?s:(SetRef(x),SetRef(w),new ValueNode(o,a,[u,_]))}function isLeafNode(s){return s.constructor===ValueNode||s.constructor===HashCollisionNode}function mergeIntoNode(s,o,i,a,u){if(s.keyHash===a)return new HashCollisionNode(o,a,[s.entry,u]);var _,x=(0===i?s.keyHash:s.keyHash>>>i)&C,j=(0===i?a:a>>>i)&C;return new BitmapIndexedNode(o,1<>>=1)w[C]=1&i?o[_++]:void 0;return w[a]=u,new HashArrayMapNode(s,_+1,w)}function mergeIntoMapWith(s,o,i){for(var a=[],u=0;u>1&1431655765))+(s>>2&858993459))+(s>>4)&252645135,s+=s>>8,127&(s+=s>>16)}function setIn(s,o,i,a){var u=a?s:arrCopy(s);return u[o]=i,u}function spliceIn(s,o,i,a){var u=s.length+1;if(a&&o+1===u)return s[o]=i,s;for(var _=new Array(u),w=0,x=0;x=qe)return createNodes(s,C,a,u);var U=s&&s===this.ownerID,V=U?C:arrCopy(C);return $?x?L===B-1?V.pop():V[L]=V.pop():V[L]=[a,u]:V.push([a,u]),U?(this.entries=V,this):new ArrayMapNode(s,V)}},BitmapIndexedNode.prototype.get=function(s,o,i,a){void 0===o&&(o=hash(i));var u=1<<((0===s?o:o>>>s)&C),_=this.bitmap;return _&u?this.nodes[popCount(_&u-1)].get(s+w,o,i,a):a},BitmapIndexedNode.prototype.update=function(s,o,i,a,u,_,x){void 0===i&&(i=hash(a));var L=(0===o?i:i>>>o)&C,B=1<=ze)return expandNodes(s,z,$,L,Z);if(U&&!Z&&2===z.length&&isLeafNode(z[1^V]))return z[1^V];if(U&&Z&&1===z.length&&isLeafNode(Z))return Z;var ee=s&&s===this.ownerID,ie=U?Z?$:$^B:$|B,ae=U?Z?setIn(z,V,Z,ee):spliceOut(z,V,ee):spliceIn(z,V,Z,ee);return ee?(this.bitmap=ie,this.nodes=ae,this):new BitmapIndexedNode(s,ie,ae)},HashArrayMapNode.prototype.get=function(s,o,i,a){void 0===o&&(o=hash(i));var u=(0===s?o:o>>>s)&C,_=this.nodes[u];return _?_.get(s+w,o,i,a):a},HashArrayMapNode.prototype.update=function(s,o,i,a,u,_,x){void 0===i&&(i=hash(a));var L=(0===o?i:i>>>o)&C,B=u===j,$=this.nodes,U=$[L];if(B&&!U)return this;var 
V=updateNode(U,s,o+w,i,a,u,_,x);if(V===U)return this;var z=this.count;if(U){if(!V&&--z0&&a=0&&s>>o&C;if(a>=this.array.length)return new VNode([],s);var u,_=0===a;if(o>0){var x=this.array[a];if((u=x&&x.removeBefore(s,o-w,i))===x&&_)return this}if(_&&!u)return this;var j=editableVNode(this,s);if(!_)for(var L=0;L>>o&C;if(u>=this.array.length)return this;if(o>0){var _=this.array[u];if((a=_&&_.removeAfter(s,o-w,i))===_&&u===this.array.length-1)return this}var x=editableVNode(this,s);return x.array.splice(u+1),a&&(x.array[u]=a),x};var Xe,Qe,et={};function iterateList(s,o){var i=s._origin,a=s._capacity,u=getTailOffset(a),_=s._tail;return iterateNodeOrLeaf(s._root,s._level,0);function iterateNodeOrLeaf(s,o,i){return 0===o?iterateLeaf(s,i):iterateNode(s,o,i)}function iterateLeaf(s,w){var C=w===u?_&&_.array:s&&s.array,j=w>i?0:i-w,L=a-w;return L>x&&(L=x),function(){if(j===L)return et;var s=o?--L:j++;return C&&C[s]}}function iterateNode(s,u,_){var C,j=s&&s.array,L=_>i?0:i-_>>u,B=1+(a-_>>u);return B>x&&(B=x),function(){for(;;){if(C){var s=C();if(s!==et)return s;C=null}if(L===B)return et;var i=o?--B:L++;C=iterateNodeOrLeaf(j&&j[i],u-w,_+(i<=s.size||o<0)return s.withMutations((function(s){o<0?setListBounds(s,o).set(0,i):setListBounds(s,0,o+1).set(o,i)}));o+=s._origin;var a=s._tail,u=s._root,_=MakeRef(B);return o>=getTailOffset(s._capacity)?a=updateVNode(a,s.__ownerID,0,o,i,_):u=updateVNode(u,s.__ownerID,s._level,o,i,_),_.value?s.__ownerID?(s._root=u,s._tail=a,s.__hash=void 0,s.__altered=!0,s):makeList(s._origin,s._capacity,s._level,u,a):s}function updateVNode(s,o,i,a,u,_){var x,j=a>>>i&C,L=s&&j0){var B=s&&s.array[j],$=updateVNode(B,o,i-w,a,u,_);return $===B?s:((x=editableVNode(s,o)).array[j]=$,x)}return L&&s.array[j]===u?s:(SetRef(_),x=editableVNode(s,o),void 0===u&&j===x.array.length-1?x.array.pop():x.array[j]=u,x)}function editableVNode(s,o){return o&&s&&o===s.ownerID?s:new VNode(s?s.array.slice():[],o)}function listNodeFor(s,o){if(o>=getTailOffset(s._capacity))return 
s._tail;if(o<1<0;)i=i.array[o>>>a&C],a-=w;return i}}function setListBounds(s,o,i){void 0!==o&&(o|=0),void 0!==i&&(i|=0);var a=s.__ownerID||new OwnerID,u=s._origin,_=s._capacity,x=u+o,j=void 0===i?_:i<0?_+i:u+i;if(x===u&&j===_)return s;if(x>=j)return s.clear();for(var L=s._level,B=s._root,$=0;x+$<0;)B=new VNode(B&&B.array.length?[void 0,B]:[],a),$+=1<<(L+=w);$&&(x+=$,u+=$,j+=$,_+=$);for(var U=getTailOffset(_),V=getTailOffset(j);V>=1<U?new VNode([],a):z;if(z&&V>U&&x<_&&z.array.length){for(var Z=B=editableVNode(B,a),ee=L;ee>w;ee-=w){var ie=U>>>ee&C;Z=Z.array[ie]=editableVNode(Z.array[ie],a)}Z.array[U>>>w&C]=z}if(j<_&&(Y=Y&&Y.removeAfter(a,0,j)),x>=V)x-=V,j-=V,L=w,B=null,Y=Y&&Y.removeBefore(a,0,x);else if(x>u||V>>L&C;if(ae!==V>>>L&C)break;ae&&($+=(1<u&&(B=B.removeBefore(a,L,x-$)),B&&Vu&&(u=x.size),isIterable(w)||(x=x.map((function(s){return fromJS(s)}))),a.push(x)}return u>s.size&&(s=s.setSize(u)),mergeIntoCollectionWith(s,o,a)}function getTailOffset(s){return s>>w<=x&&w.size>=2*_.size?(a=(u=w.filter((function(s,o){return void 0!==s&&C!==o}))).toKeyedSeq().map((function(s){return s[0]})).flip().toMap(),s.__ownerID&&(a.__ownerID=u.__ownerID=s.__ownerID)):(a=_.remove(o),u=C===w.size-1?w.pop():w.set(C,void 0))}else if(L){if(i===w.get(C)[1])return s;a=_,u=w.set(C,[o,i])}else a=_.set(o,w.size),u=w.set(w.size,[o,i]);return s.__ownerID?(s.size=a.size,s._map=a,s._list=u,s.__hash=void 0,s):makeOrderedMap(a,u)}function ToKeyedSequence(s,o){this._iter=s,this._useKeys=o,this.size=s.size}function ToIndexedSequence(s){this._iter=s,this.size=s.size}function ToSetSequence(s){this._iter=s,this.size=s.size}function FromEntriesSequence(s){this._iter=s,this.size=s.size}function flipFactory(s){var o=makeSequence(s);return o._iter=s,o.size=s.size,o.flip=function(){return s},o.reverse=function(){var o=s.reverse.apply(this);return o.flip=function(){return s.reverse()},o},o.has=function(o){return s.includes(o)},o.includes=function(o){return 
s.has(o)},o.cacheResult=cacheResultThrough,o.__iterateUncached=function(o,i){var a=this;return s.__iterate((function(s,i){return!1!==o(i,s,a)}),i)},o.__iteratorUncached=function(o,i){if(o===V){var a=s.__iterator(o,i);return new Iterator((function(){var s=a.next();if(!s.done){var o=s.value[0];s.value[0]=s.value[1],s.value[1]=o}return s}))}return s.__iterator(o===U?$:U,i)},o}function mapFactory(s,o,i){var a=makeSequence(s);return a.size=s.size,a.has=function(o){return s.has(o)},a.get=function(a,u){var _=s.get(a,j);return _===j?u:o.call(i,_,a,s)},a.__iterateUncached=function(a,u){var _=this;return s.__iterate((function(s,u,w){return!1!==a(o.call(i,s,u,w),u,_)}),u)},a.__iteratorUncached=function(a,u){var _=s.__iterator(V,u);return new Iterator((function(){var u=_.next();if(u.done)return u;var w=u.value,x=w[0];return iteratorValue(a,x,o.call(i,w[1],x,s),u)}))},a}function reverseFactory(s,o){var i=makeSequence(s);return i._iter=s,i.size=s.size,i.reverse=function(){return s},s.flip&&(i.flip=function(){var o=flipFactory(s);return o.reverse=function(){return s.flip()},o}),i.get=function(i,a){return s.get(o?i:-1-i,a)},i.has=function(i){return s.has(o?i:-1-i)},i.includes=function(o){return s.includes(o)},i.cacheResult=cacheResultThrough,i.__iterate=function(o,i){var a=this;return s.__iterate((function(s,i){return o(s,i,a)}),!i)},i.__iterator=function(o,i){return s.__iterator(o,!i)},i}function filterFactory(s,o,i,a){var u=makeSequence(s);return a&&(u.has=function(a){var u=s.get(a,j);return u!==j&&!!o.call(i,u,a,s)},u.get=function(a,u){var _=s.get(a,j);return _!==j&&o.call(i,_,a,s)?_:u}),u.__iterateUncached=function(u,_){var w=this,x=0;return s.__iterate((function(s,_,C){if(o.call(i,s,_,C))return x++,u(s,a?_:x-1,w)}),_),x},u.__iteratorUncached=function(u,_){var w=s.__iterator(V,_),x=0;return new Iterator((function(){for(;;){var _=w.next();if(_.done)return _;var C=_.value,j=C[0],L=C[1];if(o.call(i,L,j,s))return iteratorValue(u,a?j:x++,L,_)}}))},u}function 
countByFactory(s,o,i){var a=Map().asMutable();return s.__iterate((function(u,_){a.update(o.call(i,u,_,s),0,(function(s){return s+1}))})),a.asImmutable()}function groupByFactory(s,o,i){var a=isKeyed(s),u=(isOrdered(s)?OrderedMap():Map()).asMutable();s.__iterate((function(_,w){u.update(o.call(i,_,w,s),(function(s){return(s=s||[]).push(a?[w,_]:_),s}))}));var _=iterableClass(s);return u.map((function(o){return reify(s,_(o))}))}function sliceFactory(s,o,i,a){var u=s.size;if(void 0!==o&&(o|=0),void 0!==i&&(i===1/0?i=u:i|=0),wholeSlice(o,i,u))return s;var _=resolveBegin(o,u),w=resolveEnd(i,u);if(_!=_||w!=w)return sliceFactory(s.toSeq().cacheResult(),o,i,a);var x,C=w-_;C==C&&(x=C<0?0:C);var j=makeSequence(s);return j.size=0===x?x:s.size&&x||void 0,!a&&isSeq(s)&&x>=0&&(j.get=function(o,i){return(o=wrapIndex(this,o))>=0&&ox)return iteratorDone();var s=u.next();return a||o===U?s:iteratorValue(o,C-1,o===$?void 0:s.value[1],s)}))},j}function takeWhileFactory(s,o,i){var a=makeSequence(s);return a.__iterateUncached=function(a,u){var _=this;if(u)return this.cacheResult().__iterate(a,u);var w=0;return s.__iterate((function(s,u,x){return o.call(i,s,u,x)&&++w&&a(s,u,_)})),w},a.__iteratorUncached=function(a,u){var _=this;if(u)return this.cacheResult().__iterator(a,u);var w=s.__iterator(V,u),x=!0;return new Iterator((function(){if(!x)return iteratorDone();var s=w.next();if(s.done)return s;var u=s.value,C=u[0],j=u[1];return o.call(i,j,C,_)?a===V?s:iteratorValue(a,C,j,s):(x=!1,iteratorDone())}))},a}function skipWhileFactory(s,o,i,a){var u=makeSequence(s);return u.__iterateUncached=function(u,_){var w=this;if(_)return this.cacheResult().__iterate(u,_);var x=!0,C=0;return s.__iterate((function(s,_,j){if(!x||!(x=o.call(i,s,_,j)))return C++,u(s,a?_:C-1,w)})),C},u.__iteratorUncached=function(u,_){var w=this;if(_)return this.cacheResult().__iterator(u,_);var x=s.__iterator(V,_),C=!0,j=0;return new Iterator((function(){var s,_,L;do{if((s=x.next()).done)return 
a||u===U?s:iteratorValue(u,j++,u===$?void 0:s.value[1],s);var B=s.value;_=B[0],L=B[1],C&&(C=o.call(i,L,_,w))}while(C);return u===V?s:iteratorValue(u,_,L,s)}))},u}function concatFactory(s,o){var i=isKeyed(s),a=[s].concat(o).map((function(s){return isIterable(s)?i&&(s=KeyedIterable(s)):s=i?keyedSeqFromValue(s):indexedSeqFromValue(Array.isArray(s)?s:[s]),s})).filter((function(s){return 0!==s.size}));if(0===a.length)return s;if(1===a.length){var u=a[0];if(u===s||i&&isKeyed(u)||isIndexed(s)&&isIndexed(u))return u}var _=new ArraySeq(a);return i?_=_.toKeyedSeq():isIndexed(s)||(_=_.toSetSeq()),(_=_.flatten(!0)).size=a.reduce((function(s,o){if(void 0!==s){var i=o.size;if(void 0!==i)return s+i}}),0),_}function flattenFactory(s,o,i){var a=makeSequence(s);return a.__iterateUncached=function(a,u){var _=0,w=!1;function flatDeep(s,x){var C=this;s.__iterate((function(s,u){return(!o||x0}function zipWithFactory(s,o,i){var a=makeSequence(s);return a.size=new ArraySeq(i).map((function(s){return s.size})).min(),a.__iterate=function(s,o){for(var i,a=this.__iterator(U,o),u=0;!(i=a.next()).done&&!1!==s(i.value,u++,this););return u},a.__iteratorUncached=function(s,a){var u=i.map((function(s){return s=Iterable(s),getIterator(a?s.reverse():s)})),_=0,w=!1;return new Iterator((function(){var i;return w||(i=u.map((function(s){return s.next()})),w=i.some((function(s){return s.done}))),w?iteratorDone():iteratorValue(s,_++,o.apply(null,i.map((function(s){return s.value}))))}))},a}function reify(s,o){return isSeq(s)?o:s.constructor(o)}function validateEntry(s){if(s!==Object(s))throw new TypeError("Expected [K, V] tuple: "+s)}function resolveSize(s){return assertNotInfinite(s.size),ensureSize(s)}function iterableClass(s){return isKeyed(s)?KeyedIterable:isIndexed(s)?IndexedIterable:SetIterable}function makeSequence(s){return Object.create((isKeyed(s)?KeyedSeq:isIndexed(s)?IndexedSeq:SetSeq).prototype)}function cacheResultThrough(){return 
this._iter.cacheResult?(this._iter.cacheResult(),this.size=this._iter.size,this):Seq.prototype.cacheResult.call(this)}function defaultComparator(s,o){return s>o?1:s=0;i--)o={value:arguments[i],next:o};return this.__ownerID?(this.size=s,this._head=o,this.__hash=void 0,this.__altered=!0,this):makeStack(s,o)},Stack.prototype.pushAll=function(s){if(0===(s=IndexedIterable(s)).size)return this;assertNotInfinite(s.size);var o=this.size,i=this._head;return s.reverse().forEach((function(s){o++,i={value:s,next:i}})),this.__ownerID?(this.size=o,this._head=i,this.__hash=void 0,this.__altered=!0,this):makeStack(o,i)},Stack.prototype.pop=function(){return this.slice(1)},Stack.prototype.unshift=function(){return this.push.apply(this,arguments)},Stack.prototype.unshiftAll=function(s){return this.pushAll(s)},Stack.prototype.shift=function(){return this.pop.apply(this,arguments)},Stack.prototype.clear=function(){return 0===this.size?this:this.__ownerID?(this.size=0,this._head=void 0,this.__hash=void 0,this.__altered=!0,this):emptyStack()},Stack.prototype.slice=function(s,o){if(wholeSlice(s,o,this.size))return this;var i=resolveBegin(s,this.size);if(resolveEnd(o,this.size)!==this.size)return IndexedCollection.prototype.slice.call(this,s,o);for(var a=this.size-i,u=this._head;i--;)u=u.next;return this.__ownerID?(this.size=a,this._head=u,this.__hash=void 0,this.__altered=!0,this):makeStack(a,u)},Stack.prototype.__ensureOwner=function(s){return s===this.__ownerID?this:s?makeStack(this.size,this._head,s,this.__hash):(this.__ownerID=s,this.__altered=!1,this)},Stack.prototype.__iterate=function(s,o){if(o)return this.reverse().__iterate(s);for(var i=0,a=this._head;a&&!1!==s(a.value,i++,this);)a=a.next;return i},Stack.prototype.__iterator=function(s,o){if(o)return this.reverse().__iterator(s);var i=0,a=this._head;return new Iterator((function(){if(a){var o=a.value;return a=a.next,iteratorValue(s,i++,o)}return iteratorDone()}))},Stack.isStack=isStack;var 
at,ct="@@__IMMUTABLE_STACK__@@",lt=Stack.prototype;function makeStack(s,o,i,a){var u=Object.create(lt);return u.size=s,u._head=o,u.__ownerID=i,u.__hash=a,u.__altered=!1,u}function emptyStack(){return at||(at=makeStack(0))}function mixin(s,o){var keyCopier=function(i){s.prototype[i]=o[i]};return Object.keys(o).forEach(keyCopier),Object.getOwnPropertySymbols&&Object.getOwnPropertySymbols(o).forEach(keyCopier),s}lt[ct]=!0,lt.withMutations=$e.withMutations,lt.asMutable=$e.asMutable,lt.asImmutable=$e.asImmutable,lt.wasAltered=$e.wasAltered,Iterable.Iterator=Iterator,mixin(Iterable,{toArray:function(){assertNotInfinite(this.size);var s=new Array(this.size||0);return this.valueSeq().__iterate((function(o,i){s[i]=o})),s},toIndexedSeq:function(){return new ToIndexedSequence(this)},toJS:function(){return this.toSeq().map((function(s){return s&&"function"==typeof s.toJS?s.toJS():s})).__toJS()},toJSON:function(){return this.toSeq().map((function(s){return s&&"function"==typeof s.toJSON?s.toJSON():s})).__toJS()},toKeyedSeq:function(){return new ToKeyedSequence(this,!0)},toMap:function(){return Map(this.toKeyedSeq())},toObject:function(){assertNotInfinite(this.size);var s={};return this.__iterate((function(o,i){s[i]=o})),s},toOrderedMap:function(){return OrderedMap(this.toKeyedSeq())},toOrderedSet:function(){return OrderedSet(isKeyed(this)?this.valueSeq():this)},toSet:function(){return Set(isKeyed(this)?this.valueSeq():this)},toSetSeq:function(){return new ToSetSequence(this)},toSeq:function(){return isIndexed(this)?this.toIndexedSeq():isKeyed(this)?this.toKeyedSeq():this.toSetSeq()},toStack:function(){return Stack(isKeyed(this)?this.valueSeq():this)},toList:function(){return List(isKeyed(this)?this.valueSeq():this)},toString:function(){return"[Iterable]"},__toString:function(s,o){return 0===this.size?s+o:s+" "+this.toSeq().map(this.__toStringMapper).join(", ")+" "+o},concat:function(){return reify(this,concatFactory(this,s.call(arguments,0)))},includes:function(s){return 
this.some((function(o){return is(o,s)}))},entries:function(){return this.__iterator(V)},every:function(s,o){assertNotInfinite(this.size);var i=!0;return this.__iterate((function(a,u,_){if(!s.call(o,a,u,_))return i=!1,!1})),i},filter:function(s,o){return reify(this,filterFactory(this,s,o,!0))},find:function(s,o,i){var a=this.findEntry(s,o);return a?a[1]:i},forEach:function(s,o){return assertNotInfinite(this.size),this.__iterate(o?s.bind(o):s)},join:function(s){assertNotInfinite(this.size),s=void 0!==s?""+s:",";var o="",i=!0;return this.__iterate((function(a){i?i=!1:o+=s,o+=null!=a?a.toString():""})),o},keys:function(){return this.__iterator($)},map:function(s,o){return reify(this,mapFactory(this,s,o))},reduce:function(s,o,i){var a,u;return assertNotInfinite(this.size),arguments.length<2?u=!0:a=o,this.__iterate((function(o,_,w){u?(u=!1,a=o):a=s.call(i,a,o,_,w)})),a},reduceRight:function(s,o,i){var a=this.toKeyedSeq().reverse();return a.reduce.apply(a,arguments)},reverse:function(){return reify(this,reverseFactory(this,!0))},slice:function(s,o){return reify(this,sliceFactory(this,s,o,!0))},some:function(s,o){return!this.every(not(s),o)},sort:function(s){return reify(this,sortFactory(this,s))},values:function(){return this.__iterator(U)},butLast:function(){return this.slice(0,-1)},isEmpty:function(){return void 0!==this.size?0===this.size:!this.some((function(){return!0}))},count:function(s,o){return ensureSize(s?this.toSeq().filter(s,o):this)},countBy:function(s,o){return countByFactory(this,s,o)},equals:function(s){return deepEqual(this,s)},entrySeq:function(){var s=this;if(s._cache)return new ArraySeq(s._cache);var o=s.toSeq().map(entryMapper).toIndexedSeq();return o.fromEntrySeq=function(){return s.toSeq()},o},filterNot:function(s,o){return this.filter(not(s),o)},findEntry:function(s,o,i){var a=i;return this.__iterate((function(i,u,_){if(s.call(o,i,u,_))return a=[u,i],!1})),a},findKey:function(s,o){var i=this.findEntry(s,o);return 
i&&i[0]},findLast:function(s,o,i){return this.toKeyedSeq().reverse().find(s,o,i)},findLastEntry:function(s,o,i){return this.toKeyedSeq().reverse().findEntry(s,o,i)},findLastKey:function(s,o){return this.toKeyedSeq().reverse().findKey(s,o)},first:function(){return this.find(returnTrue)},flatMap:function(s,o){return reify(this,flatMapFactory(this,s,o))},flatten:function(s){return reify(this,flattenFactory(this,s,!0))},fromEntrySeq:function(){return new FromEntriesSequence(this)},get:function(s,o){return this.find((function(o,i){return is(i,s)}),void 0,o)},getIn:function(s,o){for(var i,a=this,u=forceIterator(s);!(i=u.next()).done;){var _=i.value;if((a=a&&a.get?a.get(_,j):j)===j)return o}return a},groupBy:function(s,o){return groupByFactory(this,s,o)},has:function(s){return this.get(s,j)!==j},hasIn:function(s){return this.getIn(s,j)!==j},isSubset:function(s){return s="function"==typeof s.includes?s:Iterable(s),this.every((function(o){return s.includes(o)}))},isSuperset:function(s){return(s="function"==typeof s.isSubset?s:Iterable(s)).isSubset(this)},keyOf:function(s){return this.findKey((function(o){return is(o,s)}))},keySeq:function(){return this.toSeq().map(keyMapper).toIndexedSeq()},last:function(){return this.toSeq().reverse().first()},lastKeyOf:function(s){return this.toKeyedSeq().reverse().keyOf(s)},max:function(s){return maxFactory(this,s)},maxBy:function(s,o){return maxFactory(this,o,s)},min:function(s){return maxFactory(this,s?neg(s):defaultNegComparator)},minBy:function(s,o){return maxFactory(this,o?neg(o):defaultNegComparator,s)},rest:function(){return this.slice(1)},skip:function(s){return this.slice(Math.max(0,s))},skipLast:function(s){return reify(this,this.toSeq().reverse().skip(s).reverse())},skipWhile:function(s,o){return reify(this,skipWhileFactory(this,s,o,!0))},skipUntil:function(s,o){return this.skipWhile(not(s),o)},sortBy:function(s,o){return reify(this,sortFactory(this,o,s))},take:function(s){return 
this.slice(0,Math.max(0,s))},takeLast:function(s){return reify(this,this.toSeq().reverse().take(s).reverse())},takeWhile:function(s,o){return reify(this,takeWhileFactory(this,s,o))},takeUntil:function(s,o){return this.takeWhile(not(s),o)},valueSeq:function(){return this.toIndexedSeq()},hashCode:function(){return this.__hash||(this.__hash=hashIterable(this))}});var ut=Iterable.prototype;ut[o]=!0,ut[Z]=ut.values,ut.__toJS=ut.toArray,ut.__toStringMapper=quoteString,ut.inspect=ut.toSource=function(){return this.toString()},ut.chain=ut.flatMap,ut.contains=ut.includes,mixin(KeyedIterable,{flip:function(){return reify(this,flipFactory(this))},mapEntries:function(s,o){var i=this,a=0;return reify(this,this.toSeq().map((function(u,_){return s.call(o,[_,u],a++,i)})).fromEntrySeq())},mapKeys:function(s,o){var i=this;return reify(this,this.toSeq().flip().map((function(a,u){return s.call(o,a,u,i)})).flip())}});var pt=KeyedIterable.prototype;function keyMapper(s,o){return o}function entryMapper(s,o){return[o,s]}function not(s){return function(){return!s.apply(this,arguments)}}function neg(s){return function(){return-s.apply(this,arguments)}}function quoteString(s){return"string"==typeof s?JSON.stringify(s):String(s)}function defaultZipper(){return arrCopy(arguments)}function defaultNegComparator(s,o){return so?-1:0}function hashIterable(s){if(s.size===1/0)return 0;var o=isOrdered(s),i=isKeyed(s),a=o?1:0;return murmurHashOfSize(s.__iterate(i?o?function(s,o){a=31*a+hashMerge(hash(s),hash(o))|0}:function(s,o){a=a+hashMerge(hash(s),hash(o))|0}:o?function(s){a=31*a+hash(s)|0}:function(s){a=a+hash(s)|0}),a)}function murmurHashOfSize(s,o){return o=le(o,3432918353),o=le(o<<15|o>>>-15,461845907),o=le(o<<13|o>>>-13,5),o=le((o=o+3864292196^s)^o>>>16,2246822507),o=smi((o=le(o^o>>>13,3266489909))^o>>>16)}function hashMerge(s,o){return s^o+2654435769+(s<<6)+(s>>2)}return pt[i]=!0,pt[Z]=ut.entries,pt.__toJS=ut.toObject,pt.__toStringMapper=function(s,o){return JSON.stringify(o)+": 
"+quoteString(s)},mixin(IndexedIterable,{toKeyedSeq:function(){return new ToKeyedSequence(this,!1)},filter:function(s,o){return reify(this,filterFactory(this,s,o,!1))},findIndex:function(s,o){var i=this.findEntry(s,o);return i?i[0]:-1},indexOf:function(s){var o=this.keyOf(s);return void 0===o?-1:o},lastIndexOf:function(s){var o=this.lastKeyOf(s);return void 0===o?-1:o},reverse:function(){return reify(this,reverseFactory(this,!1))},slice:function(s,o){return reify(this,sliceFactory(this,s,o,!1))},splice:function(s,o){var i=arguments.length;if(o=Math.max(0|o,0),0===i||2===i&&!o)return this;s=resolveBegin(s,s<0?this.count():this.size);var a=this.slice(0,s);return reify(this,1===i?a:a.concat(arrCopy(arguments,2),this.slice(s+o)))},findLastIndex:function(s,o){var i=this.findLastEntry(s,o);return i?i[0]:-1},first:function(){return this.get(0)},flatten:function(s){return reify(this,flattenFactory(this,s,!1))},get:function(s,o){return(s=wrapIndex(this,s))<0||this.size===1/0||void 0!==this.size&&s>this.size?o:this.find((function(o,i){return i===s}),void 0,o)},has:function(s){return(s=wrapIndex(this,s))>=0&&(void 0!==this.size?this.size===1/0||s{"use strict";i(71340);var a=i(92046);s.exports=a.Object.assign},9957:(s,o,i)=>{"use strict";var a=Function.prototype.call,u=Object.prototype.hasOwnProperty,_=i(66743);s.exports=_.call(a,u)},9999:(s,o,i)=>{var a=i(37217),u=i(83729),_=i(16547),w=i(74733),x=i(43838),C=i(93290),j=i(23007),L=i(92271),B=i(48948),$=i(50002),U=i(83349),V=i(5861),z=i(76189),Y=i(77199),Z=i(35529),ee=i(56449),ie=i(3656),ae=i(87730),ce=i(23805),le=i(38440),pe=i(95950),de=i(37241),fe="[object Arguments]",ye="[object Function]",be="[object Object]",_e={};_e[fe]=_e["[object Array]"]=_e["[object ArrayBuffer]"]=_e["[object DataView]"]=_e["[object Boolean]"]=_e["[object Date]"]=_e["[object Float32Array]"]=_e["[object Float64Array]"]=_e["[object Int8Array]"]=_e["[object Int16Array]"]=_e["[object Int32Array]"]=_e["[object Map]"]=_e["[object Number]"]=_e[be]=_e["[object 
RegExp]"]=_e["[object Set]"]=_e["[object String]"]=_e["[object Symbol]"]=_e["[object Uint8Array]"]=_e["[object Uint8ClampedArray]"]=_e["[object Uint16Array]"]=_e["[object Uint32Array]"]=!0,_e["[object Error]"]=_e[ye]=_e["[object WeakMap]"]=!1,s.exports=function baseClone(s,o,i,Se,we,xe){var Pe,Te=1&o,Re=2&o,$e=4&o;if(i&&(Pe=we?i(s,Se,we,xe):i(s)),void 0!==Pe)return Pe;if(!ce(s))return s;var qe=ee(s);if(qe){if(Pe=z(s),!Te)return j(s,Pe)}else{var ze=V(s),We=ze==ye||"[object GeneratorFunction]"==ze;if(ie(s))return C(s,Te);if(ze==be||ze==fe||We&&!we){if(Pe=Re||We?{}:Z(s),!Te)return Re?B(s,x(Pe,s)):L(s,w(Pe,s))}else{if(!_e[ze])return we?s:{};Pe=Y(s,ze,Te)}}xe||(xe=new a);var He=xe.get(s);if(He)return He;xe.set(s,Pe),le(s)?s.forEach((function(a){Pe.add(baseClone(a,o,i,a,s,xe))})):ae(s)&&s.forEach((function(a,u){Pe.set(u,baseClone(a,o,i,u,s,xe))}));var Ye=qe?void 0:($e?Re?U:$:Re?de:pe)(s);return u(Ye||s,(function(a,u){Ye&&(a=s[u=a]),_(Pe,u,baseClone(a,o,i,u,s,xe))})),Pe}},10023:(s,o,i)=>{const 
a=i(6205),INTS=()=>[{type:a.RANGE,from:48,to:57}],WORDS=()=>[{type:a.CHAR,value:95},{type:a.RANGE,from:97,to:122},{type:a.RANGE,from:65,to:90}].concat(INTS()),WHITESPACE=()=>[{type:a.CHAR,value:9},{type:a.CHAR,value:10},{type:a.CHAR,value:11},{type:a.CHAR,value:12},{type:a.CHAR,value:13},{type:a.CHAR,value:32},{type:a.CHAR,value:160},{type:a.CHAR,value:5760},{type:a.RANGE,from:8192,to:8202},{type:a.CHAR,value:8232},{type:a.CHAR,value:8233},{type:a.CHAR,value:8239},{type:a.CHAR,value:8287},{type:a.CHAR,value:12288},{type:a.CHAR,value:65279}];o.words=()=>({type:a.SET,set:WORDS(),not:!1}),o.notWords=()=>({type:a.SET,set:WORDS(),not:!0}),o.ints=()=>({type:a.SET,set:INTS(),not:!1}),o.notInts=()=>({type:a.SET,set:INTS(),not:!0}),o.whitespace=()=>({type:a.SET,set:WHITESPACE(),not:!1}),o.notWhitespace=()=>({type:a.SET,set:WHITESPACE(),not:!0}),o.anyChar=()=>({type:a.SET,set:[{type:a.CHAR,value:10},{type:a.CHAR,value:13},{type:a.CHAR,value:8232},{type:a.CHAR,value:8233}],not:!0})},10043:(s,o,i)=>{"use strict";var a=i(54018),u=String,_=TypeError;s.exports=function(s){if(a(s))return s;throw new _("Can't set "+u(s)+" as a prototype")}},10076:s=>{"use strict";s.exports=Function.prototype.call},10124:(s,o,i)=>{var a=i(9325);s.exports=function(){return a.Date.now()}},10300:(s,o,i)=>{"use strict";var a=i(13930),u=i(82159),_=i(36624),w=i(4640),x=i(73448),C=TypeError;s.exports=function(s,o){var i=arguments.length<2?x(s):o;if(u(i))return _(a(i,s));throw new C(w(s)+" is not iterable")}},10316:(s,o,i)=>{const a=i(2404),u=i(55973),_=i(92340);class Element{constructor(s,o,i){o&&(this.meta=o),i&&(this.attributes=i),this.content=s}freeze(){Object.isFrozen(this)||(this._meta&&(this.meta.parent=this,this.meta.freeze()),this._attributes&&(this.attributes.parent=this,this.attributes.freeze()),this.children.forEach((s=>{s.parent=this,s.freeze()}),this),this.content&&Array.isArray(this.content)&&Object.freeze(this.content),Object.freeze(this))}primitive(){}clone(){const s=new 
this.constructor;return s.element=this.element,this.meta.length&&(s._meta=this.meta.clone()),this.attributes.length&&(s._attributes=this.attributes.clone()),this.content?this.content.clone?s.content=this.content.clone():Array.isArray(this.content)?s.content=this.content.map((s=>s.clone())):s.content=this.content:s.content=this.content,s}toValue(){return this.content instanceof Element?this.content.toValue():this.content instanceof u?{key:this.content.key.toValue(),value:this.content.value?this.content.value.toValue():void 0}:this.content&&this.content.map?this.content.map((s=>s.toValue()),this):this.content}toRef(s){if(""===this.id.toValue())throw Error("Cannot create reference to an element that does not contain an ID");const o=new this.RefElement(this.id.toValue());return s&&(o.path=s),o}findRecursive(...s){if(arguments.length>1&&!this.isFrozen)throw new Error("Cannot find recursive with multiple element names without first freezing the element. Call `element.freeze()`");const o=s.pop();let i=new _;const append=(s,o)=>(s.push(o),s),checkElement=(s,i)=>{i.element===o&&s.push(i);const a=i.findRecursive(o);return a&&a.reduce(append,s),i.content instanceof u&&(i.content.key&&checkElement(s,i.content.key),i.content.value&&checkElement(s,i.content.value)),s};return this.content&&(this.content.element&&checkElement(i,this.content),Array.isArray(this.content)&&this.content.reduce(checkElement,i)),s.isEmpty||(i=i.filter((o=>{let i=o.parents.map((s=>s.element));for(const o in s){const a=s[o],u=i.indexOf(a);if(-1===u)return!1;i=i.splice(0,u)}return!0}))),i}set(s){return this.content=s,this}equals(s){return a(this.toValue(),s)}getMetaProperty(s,o){if(!this.meta.hasKey(s)){if(this.isFrozen){const s=this.refract(o);return s.freeze(),s}this.meta.set(s,o)}return this.meta.get(s)}setMetaProperty(s,o){this.meta.set(s,o)}get element(){return this._storedElement||"element"}set element(s){this._storedElement=s}get content(){return this._content}set content(s){if(s instanceof 
Element)this._content=s;else if(s instanceof _)this.content=s.elements;else if("string"==typeof s||"number"==typeof s||"boolean"==typeof s||"null"===s||null==s)this._content=s;else if(s instanceof u)this._content=s;else if(Array.isArray(s))this._content=s.map(this.refract);else{if("object"!=typeof s)throw new Error("Cannot set content to given value");this._content=Object.keys(s).map((o=>new this.MemberElement(o,s[o])))}}get meta(){if(!this._meta){if(this.isFrozen){const s=new this.ObjectElement;return s.freeze(),s}this._meta=new this.ObjectElement}return this._meta}set meta(s){s instanceof this.ObjectElement?this._meta=s:this.meta.set(s||{})}get attributes(){if(!this._attributes){if(this.isFrozen){const s=new this.ObjectElement;return s.freeze(),s}this._attributes=new this.ObjectElement}return this._attributes}set attributes(s){s instanceof this.ObjectElement?this._attributes=s:this.attributes.set(s||{})}get id(){return this.getMetaProperty("id","")}set id(s){this.setMetaProperty("id",s)}get classes(){return this.getMetaProperty("classes",[])}set classes(s){this.setMetaProperty("classes",s)}get title(){return this.getMetaProperty("title","")}set title(s){this.setMetaProperty("title",s)}get description(){return this.getMetaProperty("description","")}set description(s){this.setMetaProperty("description",s)}get links(){return this.getMetaProperty("links",[])}set links(s){this.setMetaProperty("links",s)}get isFrozen(){return Object.isFrozen(this)}get parents(){let{parent:s}=this;const o=new _;for(;s;)o.push(s),s=s.parent;return o}get children(){if(Array.isArray(this.content))return new _(this.content);if(this.content instanceof u){const s=new _([this.content.key]);return this.content.value&&s.push(this.content.value),s}return this.content instanceof Element?new _([this.content]):new _}get recursiveChildren(){const s=new _;return this.children.forEach((o=>{s.push(o),o.recursiveChildren.forEach((o=>{s.push(o)}))})),s}}s.exports=Element},10392:s=>{s.exports=function 
getValue(s,o){return null==s?void 0:s[o]}},10487:(s,o,i)=>{"use strict";var a=i(96897),u=i(30655),_=i(73126),w=i(12205);s.exports=function callBind(s){var o=_(arguments),i=s.length-(arguments.length-1);return a(o,1+(i>0?i:0),!0)},u?u(s.exports,"apply",{value:w}):s.exports.apply=w},10776:(s,o,i)=>{var a=i(30756),u=i(95950);s.exports=function getMatchData(s){for(var o=u(s),i=o.length;i--;){var _=o[i],w=s[_];o[i]=[_,w,a(w)]}return o}},10866:(s,o,i)=>{const a=i(6048),u=i(92340);class ObjectSlice extends u{map(s,o){return this.elements.map((i=>s.bind(o)(i.value,i.key,i)))}filter(s,o){return new ObjectSlice(this.elements.filter((i=>s.bind(o)(i.value,i.key,i))))}reject(s,o){return this.filter(a(s.bind(o)))}forEach(s,o){return this.elements.forEach(((i,a)=>{s.bind(o)(i.value,i.key,i,a)}))}keys(){return this.map(((s,o)=>o.toValue()))}values(){return this.map((s=>s.toValue()))}}s.exports=ObjectSlice},11002:s=>{"use strict";s.exports=Function.prototype.apply},11042:(s,o,i)=>{"use strict";var a=i(85582),u=i(1907),_=i(24443),w=i(87170),x=i(36624),C=u([].concat);s.exports=a("Reflect","ownKeys")||function ownKeys(s){var o=_.f(x(s)),i=w.f;return i?C(o,i(s)):o}},11091:(s,o,i)=>{"use strict";var a=i(45951),u=i(76024),_=i(92361),w=i(62250),x=i(13846).f,C=i(7463),j=i(92046),L=i(28311),B=i(61626),$=i(49724);i(36128);var wrapConstructor=function(s){var Wrapper=function(o,i,a){if(this instanceof Wrapper){switch(arguments.length){case 0:return new s;case 1:return new s(o);case 2:return new s(o,i)}return new s(o,i,a)}return u(s,this,arguments)};return Wrapper.prototype=s.prototype,Wrapper};s.exports=function(s,o){var i,u,U,V,z,Y,Z,ee,ie,ae=s.target,ce=s.global,le=s.stat,pe=s.proto,de=ce?a:le?a[ae]:a[ae]&&a[ae].prototype,fe=ce?j:j[ae]||B(j,ae,{})[ae],ye=fe.prototype;for(V in o)u=!(i=C(ce?V:ae+(le?".":"#")+V,s.forced))&&de&&$(de,V),Y=fe[V],u&&(Z=s.dontCallGetSet?(ie=x(de,V))&&ie.value:de[V]),z=u&&Z?Z:o[V],(i||pe||typeof Y!=typeof 
z)&&(ee=s.bind&&u?L(z,a):s.wrap&&u?wrapConstructor(z):pe&&w(z)?_(z):z,(s.sham||z&&z.sham||Y&&Y.sham)&&B(ee,"sham",!0),B(fe,V,ee),pe&&($(j,U=ae+"Prototype")||B(j,U,{}),B(j[U],V,z),s.real&&ye&&(i||!ye[V])&&B(ye,V,z)))}},11287:s=>{s.exports=function getHolder(s){return s.placeholder}},11331:(s,o,i)=>{var a=i(72552),u=i(28879),_=i(40346),w=Function.prototype,x=Object.prototype,C=w.toString,j=x.hasOwnProperty,L=C.call(Object);s.exports=function isPlainObject(s){if(!_(s)||"[object Object]"!=a(s))return!1;var o=u(s);if(null===o)return!0;var i=j.call(o,"constructor")&&o.constructor;return"function"==typeof i&&i instanceof i&&C.call(i)==L}},11470:(s,o,i)=>{"use strict";var a=i(1907),u=i(65482),_=i(90160),w=i(74239),x=a("".charAt),C=a("".charCodeAt),j=a("".slice),createMethod=function(s){return function(o,i){var a,L,B=_(w(o)),$=u(i),U=B.length;return $<0||$>=U?s?"":void 0:(a=C(B,$))<55296||a>56319||$+1===U||(L=C(B,$+1))<56320||L>57343?s?x(B,$):a:s?j(B,$,$+2):L-56320+(a-55296<<10)+65536}};s.exports={codeAt:createMethod(!1),charAt:createMethod(!0)}},11842:(s,o,i)=>{var a=i(82819),u=i(9325);s.exports=function createBind(s,o,i){var _=1&o,w=a(s);return function wrapper(){return(this&&this!==u&&this instanceof wrapper?w:s).apply(_?i:this,arguments)}}},12205:(s,o,i)=>{"use strict";var a=i(66743),u=i(11002),_=i(13144);s.exports=function applyBind(){return _(a,u,arguments)}},12242:(s,o,i)=>{const a=i(10316);s.exports=class BooleanElement extends a{constructor(s,o,i){super(s,o,i),this.element="boolean"}primitive(){return"boolean"}}},12507:(s,o,i)=>{var a=i(28754),u=i(49698),_=i(63912),w=i(13222);s.exports=function createCaseFirst(s){return function(o){o=w(o);var i=u(o)?_(o):void 0,x=i?i[0]:o.charAt(0),C=i?a(i,1).join(""):o.slice(1);return x[s]()+C}}},12560:(s,o,i)=>{"use strict";i(99363);var a=i(19287),u=i(45951),_=i(14840),w=i(93742);for(var x in a)_(u[x],x),w[x]=w.Array},12651:(s,o,i)=>{var a=i(74218);s.exports=function getMapData(s,o){var i=s.__data__;return a(o)?i["string"==typeof 
o?"string":"hash"]:i.map}},12749:(s,o,i)=>{var a=i(81042),u=Object.prototype.hasOwnProperty;s.exports=function hashHas(s){var o=this.__data__;return a?void 0!==o[s]:u.call(o,s)}},13144:(s,o,i)=>{"use strict";var a=i(66743),u=i(11002),_=i(10076),w=i(47119);s.exports=w||a.call(_,u)},13222:(s,o,i)=>{var a=i(77556);s.exports=function toString(s){return null==s?"":a(s)}},13846:(s,o,i)=>{"use strict";var a=i(39447),u=i(13930),_=i(22574),w=i(75817),x=i(4993),C=i(70470),j=i(49724),L=i(73648),B=Object.getOwnPropertyDescriptor;o.f=a?B:function getOwnPropertyDescriptor(s,o){if(s=x(s),o=C(o),L)try{return B(s,o)}catch(s){}if(j(s,o))return w(!u(_.f,s,o),s[o])}},13930:(s,o,i)=>{"use strict";var a=i(41505),u=Function.prototype.call;s.exports=a?u.bind(u):function(){return u.apply(u,arguments)}},14248:s=>{s.exports=function arraySome(s,o){for(var i=-1,a=null==s?0:s.length;++i{s.exports=function arrayPush(s,o){for(var i=-1,a=o.length,u=s.length;++i{const a=i(10316);s.exports=class RefElement extends a{constructor(s,o,i){super(s||[],o,i),this.element="ref",this.path||(this.path="element")}get path(){return this.attributes.get("path")}set path(s){this.attributes.set("path",s)}}},14744:s=>{"use strict";var o=function isMergeableObject(s){return function isNonNullObject(s){return!!s&&"object"==typeof s}(s)&&!function isSpecial(s){var o=Object.prototype.toString.call(s);return"[object RegExp]"===o||"[object Date]"===o||function isReactElement(s){return s.$$typeof===i}(s)}(s)};var i="function"==typeof Symbol&&Symbol.for?Symbol.for("react.element"):60103;function cloneUnlessOtherwiseSpecified(s,o){return!1!==o.clone&&o.isMergeableObject(s)?deepmerge(function emptyTarget(s){return Array.isArray(s)?[]:{}}(s),s,o):s}function defaultArrayMerge(s,o,i){return s.concat(o).map((function(s){return cloneUnlessOtherwiseSpecified(s,i)}))}function getKeys(s){return Object.keys(s).concat(function getEnumerableOwnPropertySymbols(s){return 
Object.getOwnPropertySymbols?Object.getOwnPropertySymbols(s).filter((function(o){return Object.propertyIsEnumerable.call(s,o)})):[]}(s))}function propertyIsOnObject(s,o){try{return o in s}catch(s){return!1}}function mergeObject(s,o,i){var a={};return i.isMergeableObject(s)&&getKeys(s).forEach((function(o){a[o]=cloneUnlessOtherwiseSpecified(s[o],i)})),getKeys(o).forEach((function(u){(function propertyIsUnsafe(s,o){return propertyIsOnObject(s,o)&&!(Object.hasOwnProperty.call(s,o)&&Object.propertyIsEnumerable.call(s,o))})(s,u)||(propertyIsOnObject(s,u)&&i.isMergeableObject(o[u])?a[u]=function getMergeFunction(s,o){if(!o.customMerge)return deepmerge;var i=o.customMerge(s);return"function"==typeof i?i:deepmerge}(u,i)(s[u],o[u],i):a[u]=cloneUnlessOtherwiseSpecified(o[u],i))})),a}function deepmerge(s,i,a){(a=a||{}).arrayMerge=a.arrayMerge||defaultArrayMerge,a.isMergeableObject=a.isMergeableObject||o,a.cloneUnlessOtherwiseSpecified=cloneUnlessOtherwiseSpecified;var u=Array.isArray(i);return u===Array.isArray(s)?u?a.arrayMerge(s,i,a):mergeObject(s,i,a):cloneUnlessOtherwiseSpecified(i,a)}deepmerge.all=function deepmergeAll(s,o){if(!Array.isArray(s))throw new Error("first argument should be an array");return s.reduce((function(s,i){return deepmerge(s,i,o)}),{})};var a=deepmerge;s.exports=a},14792:(s,o,i)=>{var a=i(13222),u=i(55808);s.exports=function capitalize(s){return u(a(s).toLowerCase())}},14840:(s,o,i)=>{"use strict";var a=i(52623),u=i(74284).f,_=i(61626),w=i(49724),x=i(54878),C=i(76264)("toStringTag");s.exports=function(s,o,i,j){var L=i?s:s&&s.prototype;L&&(w(L,C)||u(L,C,{configurable:!0,value:o}),j&&!a&&_(L,"toString",x))}},14974:s=>{s.exports=function safeGet(s,o){if(("constructor"!==o||"function"!=typeof s[o])&&"__proto__"!=o)return s[o]}},15287:(s,o)=>{"use strict";var 
i=Symbol.for("react.element"),a=Symbol.for("react.portal"),u=Symbol.for("react.fragment"),_=Symbol.for("react.strict_mode"),w=Symbol.for("react.profiler"),x=Symbol.for("react.provider"),C=Symbol.for("react.context"),j=Symbol.for("react.forward_ref"),L=Symbol.for("react.suspense"),B=Symbol.for("react.memo"),$=Symbol.for("react.lazy"),U=Symbol.iterator;var V={isMounted:function(){return!1},enqueueForceUpdate:function(){},enqueueReplaceState:function(){},enqueueSetState:function(){}},z=Object.assign,Y={};function E(s,o,i){this.props=s,this.context=o,this.refs=Y,this.updater=i||V}function F(){}function G(s,o,i){this.props=s,this.context=o,this.refs=Y,this.updater=i||V}E.prototype.isReactComponent={},E.prototype.setState=function(s,o){if("object"!=typeof s&&"function"!=typeof s&&null!=s)throw Error("setState(...): takes an object of state variables to update or a function which returns an object of state variables.");this.updater.enqueueSetState(this,s,o,"setState")},E.prototype.forceUpdate=function(s){this.updater.enqueueForceUpdate(this,s,"forceUpdate")},F.prototype=E.prototype;var Z=G.prototype=new F;Z.constructor=G,z(Z,E.prototype),Z.isPureReactComponent=!0;var ee=Array.isArray,ie=Object.prototype.hasOwnProperty,ae={current:null},ce={key:!0,ref:!0,__self:!0,__source:!0};function M(s,o,a){var u,_={},w=null,x=null;if(null!=o)for(u in void 0!==o.ref&&(x=o.ref),void 0!==o.key&&(w=""+o.key),o)ie.call(o,u)&&!ce.hasOwnProperty(u)&&(_[u]=o[u]);var C=arguments.length-2;if(1===C)_.children=a;else if(1{var a=i(96131);s.exports=function arrayIncludes(s,o){return!!(null==s?0:s.length)&&a(s,o,0)>-1}},15340:()=>{},15377:(s,o,i)=>{"use strict";var a=i(92861).Buffer,u=i(64634),_=i(74372),w=ArrayBuffer.isView||function isView(s){try{return _(s),!0}catch(s){return!1}},x="undefined"!=typeof Uint8Array,C="undefined"!=typeof ArrayBuffer&&"undefined"!=typeof Uint8Array,j=C&&(a.prototype instanceof Uint8Array||a.TYPED_ARRAY_SUPPORT);s.exports=function toBuffer(s,o){if(s instanceof a)return 
s;if("string"==typeof s)return a.from(s,o);if(C&&w(s)){if(0===s.byteLength)return a.alloc(0);if(j){var i=a.from(s.buffer,s.byteOffset,s.byteLength);if(i.byteLength===s.byteLength)return i}var _=s instanceof Uint8Array?s:new Uint8Array(s.buffer,s.byteOffset,s.byteLength),L=a.from(_);if(L.length===s.byteLength)return L}if(x&&s instanceof Uint8Array)return a.from(s);var B=u(s);if(B)for(var $=0;$255||~~U!==U)throw new RangeError("Array items must be numbers in the range 0-255.")}if(B||a.isBuffer(s)&&s.constructor&&"function"==typeof s.constructor.isBuffer&&s.constructor.isBuffer(s))return a.from(s);throw new TypeError('The "data" argument must be a string, an Array, a Buffer, a Uint8Array, or a DataView.')}},15389:(s,o,i)=>{var a=i(93663),u=i(87978),_=i(83488),w=i(56449),x=i(50583);s.exports=function baseIteratee(s){return"function"==typeof s?s:null==s?_:"object"==typeof s?w(s)?u(s[0],s[1]):a(s):x(s)}},15972:(s,o,i)=>{"use strict";var a=i(49724),u=i(62250),_=i(39298),w=i(92522),x=i(57382),C=w("IE_PROTO"),j=Object,L=j.prototype;s.exports=x?j.getPrototypeOf:function(s){var o=_(s);if(a(o,C))return o[C];var i=o.constructor;return u(i)&&o instanceof i?i.prototype:o instanceof j?L:null}},16038:(s,o,i)=>{var a=i(5861),u=i(40346);s.exports=function baseIsSet(s){return u(s)&&"[object Set]"==a(s)}},16426:s=>{s.exports=function(){var s=document.getSelection();if(!s.rangeCount)return function(){};for(var o=document.activeElement,i=[],a=0;a{var a=i(43360),u=i(75288),_=Object.prototype.hasOwnProperty;s.exports=function assignValue(s,o,i){var w=s[o];_.call(s,o)&&u(w,i)&&(void 0!==i||o in s)||a(s,o,i)}},16708:(s,o,i)=>{"use strict";var a,u=i(65606);function CorkedRequest(s){var o=this;this.next=null,this.entry=null,this.finish=function(){!function onCorkedFinish(s,o,i){var a=s.entry;s.entry=null;for(;a;){var u=a.callback;o.pendingcb--,u(i),a=a.next}o.corkedRequestsFree.next=s}(o,s)}}s.exports=Writable,Writable.WritableState=WritableState;var 
_={deprecate:i(94643)},w=i(40345),x=i(48287).Buffer,C=(void 0!==i.g?i.g:"undefined"!=typeof window?window:"undefined"!=typeof self?self:{}).Uint8Array||function(){};var j,L=i(75896),B=i(65291).getHighWaterMark,$=i(86048).F,U=$.ERR_INVALID_ARG_TYPE,V=$.ERR_METHOD_NOT_IMPLEMENTED,z=$.ERR_MULTIPLE_CALLBACK,Y=$.ERR_STREAM_CANNOT_PIPE,Z=$.ERR_STREAM_DESTROYED,ee=$.ERR_STREAM_NULL_VALUES,ie=$.ERR_STREAM_WRITE_AFTER_END,ae=$.ERR_UNKNOWN_ENCODING,ce=L.errorOrDestroy;function nop(){}function WritableState(s,o,_){a=a||i(25382),s=s||{},"boolean"!=typeof _&&(_=o instanceof a),this.objectMode=!!s.objectMode,_&&(this.objectMode=this.objectMode||!!s.writableObjectMode),this.highWaterMark=B(this,s,"writableHighWaterMark",_),this.finalCalled=!1,this.needDrain=!1,this.ending=!1,this.ended=!1,this.finished=!1,this.destroyed=!1;var w=!1===s.decodeStrings;this.decodeStrings=!w,this.defaultEncoding=s.defaultEncoding||"utf8",this.length=0,this.writing=!1,this.corked=0,this.sync=!0,this.bufferProcessing=!1,this.onwrite=function(s){!function onwrite(s,o){var i=s._writableState,a=i.sync,_=i.writecb;if("function"!=typeof _)throw new z;if(function onwriteStateUpdate(s){s.writing=!1,s.writecb=null,s.length-=s.writelen,s.writelen=0}(i),o)!function onwriteError(s,o,i,a,_){--o.pendingcb,i?(u.nextTick(_,a),u.nextTick(finishMaybe,s,o),s._writableState.errorEmitted=!0,ce(s,a)):(_(a),s._writableState.errorEmitted=!0,ce(s,a),finishMaybe(s,o))}(s,i,a,o,_);else{var w=needFinish(i)||s.destroyed;w||i.corked||i.bufferProcessing||!i.bufferedRequest||clearBuffer(s,i),a?u.nextTick(afterWrite,s,i,w,_):afterWrite(s,i,w,_)}}(o,s)},this.writecb=null,this.writelen=0,this.bufferedRequest=null,this.lastBufferedRequest=null,this.pendingcb=0,this.prefinished=!1,this.errorEmitted=!1,this.emitClose=!1!==s.emitClose,this.autoDestroy=!!s.autoDestroy,this.bufferedRequestCount=0,this.corkedRequestsFree=new CorkedRequest(this)}function Writable(s){var o=this instanceof(a=a||i(25382));if(!o&&!j.call(Writable,this))return new 
Writable(s);this._writableState=new WritableState(s,this,o),this.writable=!0,s&&("function"==typeof s.write&&(this._write=s.write),"function"==typeof s.writev&&(this._writev=s.writev),"function"==typeof s.destroy&&(this._destroy=s.destroy),"function"==typeof s.final&&(this._final=s.final)),w.call(this)}function doWrite(s,o,i,a,u,_,w){o.writelen=a,o.writecb=w,o.writing=!0,o.sync=!0,o.destroyed?o.onwrite(new Z("write")):i?s._writev(u,o.onwrite):s._write(u,_,o.onwrite),o.sync=!1}function afterWrite(s,o,i,a){i||function onwriteDrain(s,o){0===o.length&&o.needDrain&&(o.needDrain=!1,s.emit("drain"))}(s,o),o.pendingcb--,a(),finishMaybe(s,o)}function clearBuffer(s,o){o.bufferProcessing=!0;var i=o.bufferedRequest;if(s._writev&&i&&i.next){var a=o.bufferedRequestCount,u=new Array(a),_=o.corkedRequestsFree;_.entry=i;for(var w=0,x=!0;i;)u[w]=i,i.isBuf||(x=!1),i=i.next,w+=1;u.allBuffers=x,doWrite(s,o,!0,o.length,u,"",_.finish),o.pendingcb++,o.lastBufferedRequest=null,_.next?(o.corkedRequestsFree=_.next,_.next=null):o.corkedRequestsFree=new CorkedRequest(o),o.bufferedRequestCount=0}else{for(;i;){var C=i.chunk,j=i.encoding,L=i.callback;if(doWrite(s,o,!1,o.objectMode?1:C.length,C,j,L),i=i.next,o.bufferedRequestCount--,o.writing)break}null===i&&(o.lastBufferedRequest=null)}o.bufferedRequest=i,o.bufferProcessing=!1}function needFinish(s){return s.ending&&0===s.length&&null===s.bufferedRequest&&!s.finished&&!s.writing}function callFinal(s,o){s._final((function(i){o.pendingcb--,i&&ce(s,i),o.prefinished=!0,s.emit("prefinish"),finishMaybe(s,o)}))}function finishMaybe(s,o){var i=needFinish(o);if(i&&(function prefinish(s,o){o.prefinished||o.finalCalled||("function"!=typeof s._final||o.destroyed?(o.prefinished=!0,s.emit("prefinish")):(o.pendingcb++,o.finalCalled=!0,u.nextTick(callFinal,s,o)))}(s,o),0===o.pendingcb&&(o.finished=!0,s.emit("finish"),o.autoDestroy))){var a=s._readableState;(!a||a.autoDestroy&&a.endEmitted)&&s.destroy()}return 
i}i(56698)(Writable,w),WritableState.prototype.getBuffer=function getBuffer(){for(var s=this.bufferedRequest,o=[];s;)o.push(s),s=s.next;return o},function(){try{Object.defineProperty(WritableState.prototype,"buffer",{get:_.deprecate((function writableStateBufferGetter(){return this.getBuffer()}),"_writableState.buffer is deprecated. Use _writableState.getBuffer instead.","DEP0003")})}catch(s){}}(),"function"==typeof Symbol&&Symbol.hasInstance&&"function"==typeof Function.prototype[Symbol.hasInstance]?(j=Function.prototype[Symbol.hasInstance],Object.defineProperty(Writable,Symbol.hasInstance,{value:function value(s){return!!j.call(this,s)||this===Writable&&(s&&s._writableState instanceof WritableState)}})):j=function realHasInstance(s){return s instanceof this},Writable.prototype.pipe=function(){ce(this,new Y)},Writable.prototype.write=function(s,o,i){var a=this._writableState,_=!1,w=!a.objectMode&&function _isUint8Array(s){return x.isBuffer(s)||s instanceof C}(s);return w&&!x.isBuffer(s)&&(s=function _uint8ArrayToBuffer(s){return x.from(s)}(s)),"function"==typeof o&&(i=o,o=null),w?o="buffer":o||(o=a.defaultEncoding),"function"!=typeof i&&(i=nop),a.ending?function writeAfterEnd(s,o){var i=new ie;ce(s,i),u.nextTick(o,i)}(this,i):(w||function validChunk(s,o,i,a){var _;return null===i?_=new ee:"string"==typeof i||o.objectMode||(_=new U("chunk",["string","Buffer"],i)),!_||(ce(s,_),u.nextTick(a,_),!1)}(this,a,s,i))&&(a.pendingcb++,_=function writeOrBuffer(s,o,i,a,u,_){if(!i){var w=function decodeChunk(s,o,i){s.objectMode||!1===s.decodeStrings||"string"!=typeof o||(o=x.from(o,i));return o}(o,a,u);a!==w&&(i=!0,u="buffer",a=w)}var C=o.objectMode?1:a.length;o.length+=C;var j=o.length-1))throw new ae(s);return this._writableState.defaultEncoding=s,this},Object.defineProperty(Writable.prototype,"writableBuffer",{enumerable:!1,get:function get(){return 
this._writableState&&this._writableState.getBuffer()}}),Object.defineProperty(Writable.prototype,"writableHighWaterMark",{enumerable:!1,get:function get(){return this._writableState.highWaterMark}}),Writable.prototype._write=function(s,o,i){i(new V("_write()"))},Writable.prototype._writev=null,Writable.prototype.end=function(s,o,i){var a=this._writableState;return"function"==typeof s?(i=s,s=null,o=null):"function"==typeof o&&(i=o,o=null),null!=s&&this.write(s,o),a.corked&&(a.corked=1,this.uncork()),a.ending||function endWritable(s,o,i){o.ending=!0,finishMaybe(s,o),i&&(o.finished?u.nextTick(i):s.once("finish",i));o.ended=!0,s.writable=!1}(this,a,i),this},Object.defineProperty(Writable.prototype,"writableLength",{enumerable:!1,get:function get(){return this._writableState.length}}),Object.defineProperty(Writable.prototype,"destroyed",{enumerable:!1,get:function get(){return void 0!==this._writableState&&this._writableState.destroyed},set:function set(s){this._writableState&&(this._writableState.destroyed=s)}}),Writable.prototype.destroy=L.destroy,Writable.prototype._undestroy=L.undestroy,Writable.prototype._destroy=function(s,o){o(s)}},16946:(s,o,i)=>{"use strict";var 
a=i(1907),u=i(98828),_=i(45807),w=Object,x=a("".split);s.exports=u((function(){return!w("z").propertyIsEnumerable(0)}))?function(s){return"String"===_(s)?x(s,""):w(s)}:w},16962:(s,o)=>{o.aliasToReal={each:"forEach",eachRight:"forEachRight",entries:"toPairs",entriesIn:"toPairsIn",extend:"assignIn",extendAll:"assignInAll",extendAllWith:"assignInAllWith",extendWith:"assignInWith",first:"head",conforms:"conformsTo",matches:"isMatch",property:"get",__:"placeholder",F:"stubFalse",T:"stubTrue",all:"every",allPass:"overEvery",always:"constant",any:"some",anyPass:"overSome",apply:"spread",assoc:"set",assocPath:"set",complement:"negate",compose:"flowRight",contains:"includes",dissoc:"unset",dissocPath:"unset",dropLast:"dropRight",dropLastWhile:"dropRightWhile",equals:"isEqual",identical:"eq",indexBy:"keyBy",init:"initial",invertObj:"invert",juxt:"over",omitAll:"omit",nAry:"ary",path:"get",pathEq:"matchesProperty",pathOr:"getOr",paths:"at",pickAll:"pick",pipe:"flow",pluck:"map",prop:"get",propEq:"matchesProperty",propOr:"getOr",props:"at",symmetricDifference:"xor",symmetricDifferenceBy:"xorBy",symmetricDifferenceWith:"xorWith",takeLast:"takeRight",takeLastWhile:"takeRightWhile",unapply:"rest",unnest:"flatten",useWith:"overArgs",where:"conformsTo",whereEq:"isMatch",zipObj:"zipObject"},o.aryMethod={1:["assignAll","assignInAll","attempt","castArray","ceil","create","curry","curryRight","defaultsAll","defaultsDeepAll","floor","flow","flowRight","fromPairs","invert","iteratee","memoize","method","mergeAll","methodOf","mixin","nthArg","over","overEvery","overSome","rest","reverse","round","runInContext","spread","template","trim","trimEnd","trimStart","uniqueId","words","zipAll"],2:["add","after","ary","assign","assignAllWith","assignIn","assignInAllWith","at","before","bind","bindAll","bindKey","chunk","cloneDeepWith","cloneWith","concat","conformsTo","countBy","curryN","curryRightN","debounce","defaults","defaultsDeep","defaultTo","delay","difference","divide","drop","dropRight","
dropRightWhile","dropWhile","endsWith","eq","every","filter","find","findIndex","findKey","findLast","findLastIndex","findLastKey","flatMap","flatMapDeep","flattenDepth","forEach","forEachRight","forIn","forInRight","forOwn","forOwnRight","get","groupBy","gt","gte","has","hasIn","includes","indexOf","intersection","invertBy","invoke","invokeMap","isEqual","isMatch","join","keyBy","lastIndexOf","lt","lte","map","mapKeys","mapValues","matchesProperty","maxBy","meanBy","merge","mergeAllWith","minBy","multiply","nth","omit","omitBy","overArgs","pad","padEnd","padStart","parseInt","partial","partialRight","partition","pick","pickBy","propertyOf","pull","pullAll","pullAt","random","range","rangeRight","rearg","reject","remove","repeat","restFrom","result","sampleSize","some","sortBy","sortedIndex","sortedIndexOf","sortedLastIndex","sortedLastIndexOf","sortedUniqBy","split","spreadFrom","startsWith","subtract","sumBy","take","takeRight","takeRightWhile","takeWhile","tap","throttle","thru","times","trimChars","trimCharsEnd","trimCharsStart","truncate","union","uniqBy","uniqWith","unset","unzipWith","without","wrap","xor","zip","zipObject","zipObjectDeep"],3:["assignInWith","assignWith","clamp","differenceBy","differenceWith","findFrom","findIndexFrom","findLastFrom","findLastIndexFrom","getOr","includesFrom","indexOfFrom","inRange","intersectionBy","intersectionWith","invokeArgs","invokeArgsMap","isEqualWith","isMatchWith","flatMapDepth","lastIndexOfFrom","mergeWith","orderBy","padChars","padCharsEnd","padCharsStart","pullAllBy","pullAllWith","rangeStep","rangeStepRight","reduce","reduceRight","replace","set","slice","sortedIndexBy","sortedLastIndexBy","transform","unionBy","unionWith","update","xorBy","xorWith","zipWith"],4:["fill","setWith","updateWith"]},o.aryRearg={2:[1,0],3:[2,0,1],4:[3,2,0,1]},o.iterateeAry={dropRightWhile:1,dropWhile:1,every:1,filter:1,find:1,findFrom:1,findIndex:1,findIndexFrom:1,findKey:1,findLast:1,findLastFrom:1,findLastIndex:1,findLastIndexFrom:
1,findLastKey:1,flatMap:1,flatMapDeep:1,flatMapDepth:1,forEach:1,forEachRight:1,forIn:1,forInRight:1,forOwn:1,forOwnRight:1,map:1,mapKeys:1,mapValues:1,partition:1,reduce:2,reduceRight:2,reject:1,remove:1,some:1,takeRightWhile:1,takeWhile:1,times:1,transform:2},o.iterateeRearg={mapKeys:[1],reduceRight:[1,0]},o.methodRearg={assignInAllWith:[1,0],assignInWith:[1,2,0],assignAllWith:[1,0],assignWith:[1,2,0],differenceBy:[1,2,0],differenceWith:[1,2,0],getOr:[2,1,0],intersectionBy:[1,2,0],intersectionWith:[1,2,0],isEqualWith:[1,2,0],isMatchWith:[2,1,0],mergeAllWith:[1,0],mergeWith:[1,2,0],padChars:[2,1,0],padCharsEnd:[2,1,0],padCharsStart:[2,1,0],pullAllBy:[2,1,0],pullAllWith:[2,1,0],rangeStep:[1,2,0],rangeStepRight:[1,2,0],setWith:[3,1,2,0],sortedIndexBy:[2,1,0],sortedLastIndexBy:[2,1,0],unionBy:[1,2,0],unionWith:[1,2,0],updateWith:[3,1,2,0],xorBy:[1,2,0],xorWith:[1,2,0],zipWith:[1,2,0]},o.methodSpread={assignAll:{start:0},assignAllWith:{start:0},assignInAll:{start:0},assignInAllWith:{start:0},defaultsAll:{start:0},defaultsDeepAll:{start:0},invokeArgs:{start:2},invokeArgsMap:{start:2},mergeAll:{start:0},mergeAllWith:{start:0},partial:{start:1},partialRight:{start:1},without:{start:1},zipAll:{start:0}},o.mutate={array:{fill:!0,pull:!0,pullAll:!0,pullAllBy:!0,pullAllWith:!0,pullAt:!0,remove:!0,reverse:!0},object:{assign:!0,assignAll:!0,assignAllWith:!0,assignIn:!0,assignInAll:!0,assignInAllWith:!0,assignInWith:!0,assignWith:!0,defaults:!0,defaultsAll:!0,defaultsDeep:!0,defaultsDeepAll:!0,merge:!0,mergeAll:!0,mergeAllWith:!0,mergeWith:!0},set:{set:!0,setWith:!0,unset:!0,update:!0,updateWith:!0}},o.realToAlias=function(){var s=Object.prototype.hasOwnProperty,i=o.aliasToReal,a={};for(var u in i){var _=i[u];s.call(a,_)?a[_].push(u):a[_]=[u]}return 
a}(),o.remap={assignAll:"assign",assignAllWith:"assignWith",assignInAll:"assignIn",assignInAllWith:"assignInWith",curryN:"curry",curryRightN:"curryRight",defaultsAll:"defaults",defaultsDeepAll:"defaultsDeep",findFrom:"find",findIndexFrom:"findIndex",findLastFrom:"findLast",findLastIndexFrom:"findLastIndex",getOr:"get",includesFrom:"includes",indexOfFrom:"indexOf",invokeArgs:"invoke",invokeArgsMap:"invokeMap",lastIndexOfFrom:"lastIndexOf",mergeAll:"merge",mergeAllWith:"mergeWith",padChars:"pad",padCharsEnd:"padEnd",padCharsStart:"padStart",propertyOf:"get",rangeStep:"range",rangeStepRight:"rangeRight",restFrom:"rest",spreadFrom:"spread",trimChars:"trim",trimCharsEnd:"trimEnd",trimCharsStart:"trimStart",zipAll:"zip"},o.skipFixed={castArray:!0,flow:!0,flowRight:!0,iteratee:!0,mixin:!0,rearg:!0,runInContext:!0},o.skipRearg={add:!0,assign:!0,assignIn:!0,bind:!0,bindKey:!0,concat:!0,difference:!0,divide:!0,eq:!0,gt:!0,gte:!0,isEqual:!0,lt:!0,lte:!0,matchesProperty:!0,merge:!0,multiply:!0,overArgs:!0,partial:!0,partialRight:!0,propertyOf:!0,random:!0,range:!0,rangeRight:!0,subtract:!0,zip:!0,zipObject:!0,zipObjectDeep:!0}},17255:(s,o,i)=>{var a=i(47422);s.exports=function basePropertyDeep(s){return function(o){return a(o,s)}}},17285:s=>{function source(s){return s?"string"==typeof s?s:s.source:null}function lookahead(s){return concat("(?=",s,")")}function concat(...s){return s.map((s=>source(s))).join("")}function either(...s){return"("+s.map((s=>source(s))).join("|")+")"}s.exports=function xml(s){const o=concat(/[A-Z_]/,function optional(s){return concat("(",s,")?")}(/[A-Z0-9_.-]*:/),/[A-Z0-9_.-]*/),i={className:"symbol",begin:/&[a-z]+;|&#[0-9]+;|&#x[a-f0-9]+;/},a={begin:/\s/,contains:[{className:"meta-keyword",begin:/#?[a-z_][a-z1-9_-]+/,illegal:/\n/}]},u=s.inherit(a,{begin:/\(/,end:/\)/}),_=s.inherit(s.APOS_STRING_MODE,{className:"meta-string"}),w=s.inherit(s.QUOTE_STRING_MODE,{className:"meta-string"}),x={endsWithParent:!0,illegal:/`]+/}]}]}]};return{name:"HTML, 
XML",aliases:["html","xhtml","rss","atom","xjb","xsd","xsl","plist","wsf","svg"],case_insensitive:!0,contains:[{className:"meta",begin://,relevance:10,contains:[a,w,_,u,{begin:/\[/,end:/\]/,contains:[{className:"meta",begin://,contains:[a,u,w,_]}]}]},s.COMMENT(//,{relevance:10}),{begin://,relevance:10},i,{className:"meta",begin:/<\?xml/,end:/\?>/,relevance:10},{className:"tag",begin:/)/,end:/>/,keywords:{name:"style"},contains:[x],starts:{end:/<\/style>/,returnEnd:!0,subLanguage:["css","xml"]}},{className:"tag",begin:/)/,end:/>/,keywords:{name:"script"},contains:[x],starts:{end:/<\/script>/,returnEnd:!0,subLanguage:["javascript","handlebars","xml"]}},{className:"tag",begin:/<>|<\/>/},{className:"tag",begin:concat(//,/>/,/\s/)))),end:/\/?>/,contains:[{className:"name",begin:o,relevance:0,starts:x}]},{className:"tag",begin:concat(/<\//,lookahead(concat(o,/>/))),contains:[{className:"name",begin:o,relevance:0},{begin:/>/,relevance:0,endsParent:!0}]}]}}},17400:(s,o,i)=>{var a=i(99374),u=1/0;s.exports=function toFinite(s){return s?(s=a(s))===u||s===-1/0?17976931348623157e292*(s<0?-1:1):s==s?s:0:0===s?s:0}},17533:s=>{s.exports=function yaml(s){var o="true false yes no null",i="[\\w#;/?:@&=+$,.~*'()[\\]]+",a={className:"string",relevance:0,variants:[{begin:/'/,end:/'/},{begin:/"/,end:/"/},{begin:/\S+/}],contains:[s.BACKSLASH_ESCAPE,{className:"template-variable",variants:[{begin:/\{\{/,end:/\}\}/},{begin:/%\{/,end:/\}/}]}]},u=s.inherit(a,{variants:[{begin:/'/,end:/'/},{begin:/"/,end:/"/},{begin:/[^\s,{}[\]]+/}]}),_={className:"number",begin:"\\b[0-9]{4}(-[0-9][0-9]){0,2}([Tt \\t][0-9][0-9]?(:[0-9][0-9]){2})?(\\.[0-9]*)?([ \\t])*(Z|[-+][0-9][0-9]?(:[0-9][0-9])?)?\\b"},w={end:",",endsWithParent:!0,excludeEnd:!0,keywords:o,relevance:0},x={begin:/\{/,end:/\}/,contains:[w],illegal:"\\n",relevance:0},C={begin:"\\[",end:"\\]",contains:[w],illegal:"\\n",relevance:0},j=[{className:"attr",variants:[{begin:"\\w[\\w :\\/.-]*:(?=[ \t]|$)"},{begin:'"\\w[\\w :\\/.-]*":(?=[ 
\t]|$)'},{begin:"'\\w[\\w :\\/.-]*':(?=[ \t]|$)"}]},{className:"meta",begin:"^---\\s*$",relevance:10},{className:"string",begin:"[\\|>]([1-9]?[+-])?[ ]*\\n( +)[^ ][^\\n]*\\n(\\2[^\\n]+\\n?)*"},{begin:"<%[%=-]?",end:"[%-]?%>",subLanguage:"ruby",excludeBegin:!0,excludeEnd:!0,relevance:0},{className:"type",begin:"!\\w+!"+i},{className:"type",begin:"!<"+i+">"},{className:"type",begin:"!"+i},{className:"type",begin:"!!"+i},{className:"meta",begin:"&"+s.UNDERSCORE_IDENT_RE+"$"},{className:"meta",begin:"\\*"+s.UNDERSCORE_IDENT_RE+"$"},{className:"bullet",begin:"-(?=[ ]|$)",relevance:0},s.HASH_COMMENT_MODE,{beginKeywords:o,keywords:{literal:o}},_,{className:"number",begin:s.C_NUMBER_RE+"\\b",relevance:0},x,C,a],L=[...j];return L.pop(),L.push(u),w.contains=L,{name:"YAML",case_insensitive:!0,aliases:["yml"],contains:j}}},17670:(s,o,i)=>{var a=i(12651);s.exports=function mapCacheDelete(s){var o=a(this,s).delete(s);return this.size-=o?1:0,o}},17965:(s,o,i)=>{"use strict";var a=i(16426),u={"text/plain":"Text","text/html":"Url",default:"Text"};s.exports=function copy(s,o){var i,_,w,x,C,j,L=!1;o||(o={}),i=o.debug||!1;try{if(w=a(),x=document.createRange(),C=document.getSelection(),(j=document.createElement("span")).textContent=s,j.ariaHidden="true",j.style.all="unset",j.style.position="fixed",j.style.top=0,j.style.clip="rect(0, 0, 0, 0)",j.style.whiteSpace="pre",j.style.webkitUserSelect="text",j.style.MozUserSelect="text",j.style.msUserSelect="text",j.style.userSelect="text",j.addEventListener("copy",(function(a){if(a.stopPropagation(),o.format)if(a.preventDefault(),void 0===a.clipboardData){i&&console.warn("unable to use e.clipboardData"),i&&console.warn("trying IE specific stuff"),window.clipboardData.clearData();var _=u[o.format]||u.default;window.clipboardData.setData(_,s)}else 
a.clipboardData.clearData(),a.clipboardData.setData(o.format,s);o.onCopy&&(a.preventDefault(),o.onCopy(a.clipboardData))})),document.body.appendChild(j),x.selectNodeContents(j),C.addRange(x),!document.execCommand("copy"))throw new Error("copy command was unsuccessful");L=!0}catch(a){i&&console.error("unable to copy using execCommand: ",a),i&&console.warn("trying IE specific stuff");try{window.clipboardData.setData(o.format||"text",s),o.onCopy&&o.onCopy(window.clipboardData),L=!0}catch(a){i&&console.error("unable to copy using clipboardData: ",a),i&&console.error("falling back to prompt"),_=function format(s){var o=(/mac os x/i.test(navigator.userAgent)?"⌘":"Ctrl")+"+C";return s.replace(/#{\s*key\s*}/g,o)}("message"in o?o.message:"Copy to clipboard: #{key}, Enter"),window.prompt(_,s)}}finally{C&&("function"==typeof C.removeRange?C.removeRange(x):C.removeAllRanges()),j&&document.body.removeChild(j),w()}return L}},18073:(s,o,i)=>{var a=i(85087),u=i(54641),_=i(70981);s.exports=function createRecurry(s,o,i,w,x,C,j,L,B,$){var U=8&o;o|=U?32:64,4&(o&=~(U?64:32))||(o&=-4);var V=[s,o,x,U?C:void 0,U?j:void 0,U?void 0:C,U?void 0:j,L,B,$],z=i.apply(void 0,V);return a(s)&&u(z,V),z.placeholder=w,_(z,s,o)}},19123:(s,o,i)=>{var a=i(65606),u=i(31499),_=i(88310).Stream;function resolve(s,o,i){var a,_=function create_indent(s,o){return new Array(o||0).join(s||"")}(o,i=i||0),w=s;if("object"==typeof s&&((w=s[a=Object.keys(s)[0]])&&w._elem))return w._elem.name=a,w._elem.icount=i,w._elem.indent=o,w._elem.indents=_,w._elem.interrupt=w,w._elem;var x,C=[],j=[];function get_attributes(s){Object.keys(s).forEach((function(o){C.push(function attribute(s,o){return s+'="'+u(o)+'"'}(o,s[o]))}))}switch(typeof w){case"object":if(null===w)break;w._attr&&get_attributes(w._attr),w._cdata&&j.push(("/g,"]]]]>")+"]]>"),w.forEach&&(x=!1,j.push(""),w.forEach((function(s){"object"==typeof 
s?"_attr"==Object.keys(s)[0]?get_attributes(s._attr):j.push(resolve(s,o,i+1)):(j.pop(),x=!0,j.push(u(s)))})),x||j.push(""));break;default:j.push(u(w))}return{name:a,interrupt:!1,attributes:C,content:j,icount:i,indents:_,indent:o}}function format(s,o,i){if("object"!=typeof o)return s(!1,o);var a=o.interrupt?1:o.content.length;function proceed(){for(;o.content.length;){var u=o.content.shift();if(void 0!==u){if(interrupt(u))return;format(s,u)}}s(!1,(a>1?o.indents:"")+(o.name?"":"")+(o.indent&&!i?"\n":"")),i&&i()}function interrupt(o){return!!o.interrupt&&(o.interrupt.append=s,o.interrupt.end=proceed,o.interrupt=!1,s(!0),!0)}if(s(!1,o.indents+(o.name?"<"+o.name:"")+(o.attributes.length?" "+o.attributes.join(" "):"")+(a?o.name?">":"":o.name?"/>":"")+(o.indent&&a>1?"\n":"")),!a)return s(!1,o.indent?"\n":"");interrupt(o)||proceed()}s.exports=function xml(s,o){"object"!=typeof o&&(o={indent:o});var i=o.stream?new _:null,u="",w=!1,x=o.indent?!0===o.indent?" ":o.indent:"",C=!0;function delay(s){C?a.nextTick(s):s()}function append(s,o){if(void 0!==o&&(u+=o),s&&!w&&(i=i||new _,w=!0),s&&w){var a=u;delay((function(){i.emit("data",a)})),u=""}}function add(s,o){format(append,resolve(s,x,x?1:0),o)}function end(){if(i){var s=u;delay((function(){i.emit("data",s),i.emit("end"),i.readable=!1,i.emit("close")}))}}return delay((function(){C=!1})),o.declaration&&function addXmlDeclaration(s){var o={version:"1.0",encoding:s.encoding||"UTF-8"};s.standalone&&(o.standalone=s.standalone),add({"?xml":{_attr:o}}),u=u.replace("/>","?>")}(o.declaration),s&&s.forEach?s.forEach((function(o,i){var a;i+1===s.length&&(a=end),add(o,a)})):add(s,end),i?(i.readable=!0,i):u},s.exports.element=s.exports.Element=function element(){var s={_elem:resolve(Array.prototype.slice.call(arguments)),push:function(s){if(!this.append)throw new Error("not assigned to a parent!");var o=this,i=this._elem.indent;format(this.append,resolve(s,i,this._elem.icount+(i?1:0)),(function(){o.append(!0)}))},close:function(s){void 
0!==s&&this.push(s),this.end&&this.end()}};return s}},19219:s=>{s.exports=function cacheHas(s,o){return s.has(o)}},19287:s=>{"use strict";s.exports={CSSRuleList:0,CSSStyleDeclaration:0,CSSValueList:0,ClientRectList:0,DOMRectList:0,DOMStringList:0,DOMTokenList:1,DataTransferItemList:0,FileList:0,HTMLAllCollection:0,HTMLCollection:0,HTMLFormElement:0,HTMLSelectElement:0,MediaList:0,MimeTypeArray:0,NamedNodeMap:0,NodeList:1,PaintRequestList:0,Plugin:0,PluginArray:0,SVGLengthList:0,SVGNumberList:0,SVGPathSegList:0,SVGPointList:0,SVGStringList:0,SVGTransformList:0,SourceBufferList:0,StyleSheetList:0,TextTrackCueList:0,TextTrackList:0,TouchList:0}},19358:(s,o,i)=>{"use strict";var a=i(85582),u=i(49724),_=i(61626),w=i(88280),x=i(79192),C=i(19595),j=i(54829),L=i(34084),B=i(32096),$=i(39259),U=i(85884),V=i(39447),z=i(7376);s.exports=function(s,o,i,Y){var Z="stackTraceLimit",ee=Y?2:1,ie=s.split("."),ae=ie[ie.length-1],ce=a.apply(null,ie);if(ce){var le=ce.prototype;if(!z&&u(le,"cause")&&delete le.cause,!i)return ce;var pe=a("Error"),de=o((function(s,o){var i=B(Y?o:s,void 0),a=Y?new ce(s):new ce;return void 0!==i&&_(a,"message",i),U(a,de,a.stack,2),this&&w(le,this)&&L(a,this,de),arguments.length>ee&&$(a,arguments[ee]),a}));if(de.prototype=le,"Error"!==ae?x?x(de,pe):C(de,pe,{name:!0}):V&&Z in ce&&(j(de,ce,Z),j(de,ce,"prepareStackTrace")),C(de,ce),!z)try{le.name!==ae&&_(le,"name",ae),le.constructor=de}catch(s){}return de}}},19570:(s,o,i)=>{var a=i(37334),u=i(93243),_=i(83488),w=u?function(s,o){return u(s,"toString",{configurable:!0,enumerable:!1,value:a(o),writable:!0})}:_;s.exports=w},19595:(s,o,i)=>{"use strict";var a=i(49724),u=i(11042),_=i(13846),w=i(74284);s.exports=function(s,o,i){for(var x=u(o),C=w.f,j=_.f,L=0;L{"use strict";var a=i(23034);s.exports=a},19846:(s,o,i)=>{"use strict";var a=i(20798),u=i(98828),_=i(45951).String;s.exports=!!Object.getOwnPropertySymbols&&!u((function(){var s=Symbol("symbol detection");return!_(s)||!(Object(s)instanceof 
Symbol)||!Symbol.sham&&a&&a<41}))},19931:(s,o,i)=>{var a=i(31769),u=i(68090),_=i(68969),w=i(77797);s.exports=function baseUnset(s,o){return o=a(o,s),null==(s=_(s,o))||delete s[w(u(o))]}},20181:(s,o,i)=>{var a=/^\s+|\s+$/g,u=/^[-+]0x[0-9a-f]+$/i,_=/^0b[01]+$/i,w=/^0o[0-7]+$/i,x=parseInt,C="object"==typeof i.g&&i.g&&i.g.Object===Object&&i.g,j="object"==typeof self&&self&&self.Object===Object&&self,L=C||j||Function("return this")(),B=Object.prototype.toString,$=Math.max,U=Math.min,now=function(){return L.Date.now()};function isObject(s){var o=typeof s;return!!s&&("object"==o||"function"==o)}function toNumber(s){if("number"==typeof s)return s;if(function isSymbol(s){return"symbol"==typeof s||function isObjectLike(s){return!!s&&"object"==typeof s}(s)&&"[object Symbol]"==B.call(s)}(s))return NaN;if(isObject(s)){var o="function"==typeof s.valueOf?s.valueOf():s;s=isObject(o)?o+"":o}if("string"!=typeof s)return 0===s?s:+s;s=s.replace(a,"");var i=_.test(s);return i||w.test(s)?x(s.slice(2),i?2:8):u.test(s)?NaN:+s}s.exports=function debounce(s,o,i){var a,u,_,w,x,C,j=0,L=!1,B=!1,V=!0;if("function"!=typeof s)throw new TypeError("Expected a function");function invokeFunc(o){var i=a,_=u;return a=u=void 0,j=o,w=s.apply(_,i)}function shouldInvoke(s){var i=s-C;return void 0===C||i>=o||i<0||B&&s-j>=_}function timerExpired(){var s=now();if(shouldInvoke(s))return trailingEdge(s);x=setTimeout(timerExpired,function remainingWait(s){var i=o-(s-C);return B?U(i,_-(s-j)):i}(s))}function trailingEdge(s){return x=void 0,V&&a?invokeFunc(s):(a=u=void 0,w)}function debounced(){var s=now(),i=shouldInvoke(s);if(a=arguments,u=this,C=s,i){if(void 0===x)return function leadingEdge(s){return j=s,x=setTimeout(timerExpired,o),L?invokeFunc(s):w}(C);if(B)return x=setTimeout(timerExpired,o),invokeFunc(C)}return void 0===x&&(x=setTimeout(timerExpired,o)),w}return o=toNumber(o)||0,isObject(i)&&(L=!!i.leading,_=(B="maxWait"in i)?$(toNumber(i.maxWait)||0,o):_,V="trailing"in 
i?!!i.trailing:V),debounced.cancel=function cancel(){void 0!==x&&clearTimeout(x),j=0,a=C=u=x=void 0},debounced.flush=function flush(){return void 0===x?w:trailingEdge(now())},debounced}},20317:s=>{s.exports=function mapToArray(s){var o=-1,i=Array(s.size);return s.forEach((function(s,a){i[++o]=[a,s]})),i}},20334:(s,o,i)=>{"use strict";var a=i(48287).Buffer;class NonError extends Error{constructor(s){super(NonError._prepareSuperMessage(s)),Object.defineProperty(this,"name",{value:"NonError",configurable:!0,writable:!0}),Error.captureStackTrace&&Error.captureStackTrace(this,NonError)}static _prepareSuperMessage(s){try{return JSON.stringify(s)}catch{return String(s)}}}const u=[{property:"name",enumerable:!1},{property:"message",enumerable:!1},{property:"stack",enumerable:!1},{property:"code",enumerable:!0}],_=Symbol(".toJSON called"),destroyCircular=({from:s,seen:o,to_:i,forceEnumerable:w,maxDepth:x,depth:C})=>{const j=i||(Array.isArray(s)?[]:{});if(o.push(s),C>=x)return j;if("function"==typeof s.toJSON&&!0!==s[_])return(s=>{s[_]=!0;const o=s.toJSON();return delete s[_],o})(s);for(const[i,u]of Object.entries(s))"function"==typeof a&&a.isBuffer(u)?j[i]="[object Buffer]":"function"!=typeof u&&(u&&"object"==typeof u?o.includes(s[i])?j[i]="[Circular]":(C++,j[i]=destroyCircular({from:s[i],seen:o.slice(),forceEnumerable:w,maxDepth:x,depth:C})):j[i]=u);for(const{property:o,enumerable:i}of u)"string"==typeof s[o]&&Object.defineProperty(j,o,{value:s[o],enumerable:!!w||i,configurable:!0,writable:!0});return j};s.exports={serializeError:(s,o={})=>{const{maxDepth:i=Number.POSITIVE_INFINITY}=o;return"object"==typeof s&&null!==s?destroyCircular({from:s,seen:[],forceEnumerable:!0,maxDepth:i,depth:0}):"function"==typeof s?`[Function: ${s.name||"anonymous"}]`:s},deserializeError:(s,o={})=>{const{maxDepth:i=Number.POSITIVE_INFINITY}=o;if(s instanceof Error)return s;if("object"==typeof s&&null!==s&&!Array.isArray(s)){const o=new Error;return 
destroyCircular({from:s,seen:[],to_:o,maxDepth:i,depth:0}),o}return new NonError(s)}}},20426:s=>{var o=Object.prototype.hasOwnProperty;s.exports=function baseHas(s,i){return null!=s&&o.call(s,i)}},20575:(s,o,i)=>{"use strict";var a=i(3121);s.exports=function(s){return a(s.length)}},20798:(s,o,i)=>{"use strict";var a,u,_=i(45951),w=i(96794),x=_.process,C=_.Deno,j=x&&x.versions||C&&C.version,L=j&&j.v8;L&&(u=(a=L.split("."))[0]>0&&a[0]<4?1:+(a[0]+a[1])),!u&&w&&(!(a=w.match(/Edge\/(\d+)/))||a[1]>=74)&&(a=w.match(/Chrome\/(\d+)/))&&(u=+a[1]),s.exports=u},20850:(s,o,i)=>{"use strict";s.exports=i(46076)},20999:(s,o,i)=>{var a=i(69302),u=i(36800);s.exports=function createAssigner(s){return a((function(o,i){var a=-1,_=i.length,w=_>1?i[_-1]:void 0,x=_>2?i[2]:void 0;for(w=s.length>3&&"function"==typeof w?(_--,w):void 0,x&&u(i[0],i[1],x)&&(w=_<3?void 0:w,_=1),o=Object(o);++a<_;){var C=i[a];C&&s(o,C,a,w)}return o}))}},21549:(s,o,i)=>{var a=i(22032),u=i(63862),_=i(66721),w=i(12749),x=i(35749);function Hash(s){var o=-1,i=null==s?0:s.length;for(this.clear();++o{var a=i(16547),u=i(43360);s.exports=function copyObject(s,o,i,_){var w=!i;i||(i={});for(var x=-1,C=o.length;++x{var a=i(51873),u=i(37828),_=i(75288),w=i(25911),x=i(20317),C=i(84247),j=a?a.prototype:void 0,L=j?j.valueOf:void 0;s.exports=function equalByTag(s,o,i,a,j,B,$){switch(i){case"[object DataView]":if(s.byteLength!=o.byteLength||s.byteOffset!=o.byteOffset)return!1;s=s.buffer,o=o.buffer;case"[object ArrayBuffer]":return!(s.byteLength!=o.byteLength||!B(new u(s),new u(o)));case"[object Boolean]":case"[object Date]":case"[object Number]":return _(+s,+o);case"[object Error]":return s.name==o.name&&s.message==o.message;case"[object RegExp]":case"[object String]":return s==o+"";case"[object Map]":var U=x;case"[object Set]":var V=1&a;if(U||(U=C),s.size!=o.size&&!V)return!1;var z=$.get(s);if(z)return z==o;a|=2,$.set(s,o);var Y=w(U(s),U(o),a,j,B,$);return $.delete(s),Y;case"[object Symbol]":if(L)return 
L.call(s)==L.call(o)}return!1}},22032:(s,o,i)=>{var a=i(81042);s.exports=function hashClear(){this.__data__=a?a(null):{},this.size=0}},22225:s=>{var o="\\ud800-\\udfff",i="\\u2700-\\u27bf",a="a-z\\xdf-\\xf6\\xf8-\\xff",u="A-Z\\xc0-\\xd6\\xd8-\\xde",_="\\xac\\xb1\\xd7\\xf7\\x00-\\x2f\\x3a-\\x40\\x5b-\\x60\\x7b-\\xbf\\u2000-\\u206f \\t\\x0b\\f\\xa0\\ufeff\\n\\r\\u2028\\u2029\\u1680\\u180e\\u2000\\u2001\\u2002\\u2003\\u2004\\u2005\\u2006\\u2007\\u2008\\u2009\\u200a\\u202f\\u205f\\u3000",w="["+_+"]",x="\\d+",C="["+i+"]",j="["+a+"]",L="[^"+o+_+x+i+a+u+"]",B="(?:\\ud83c[\\udde6-\\uddff]){2}",$="[\\ud800-\\udbff][\\udc00-\\udfff]",U="["+u+"]",V="(?:"+j+"|"+L+")",z="(?:"+U+"|"+L+")",Y="(?:['’](?:d|ll|m|re|s|t|ve))?",Z="(?:['’](?:D|LL|M|RE|S|T|VE))?",ee="(?:[\\u0300-\\u036f\\ufe20-\\ufe2f\\u20d0-\\u20ff]|\\ud83c[\\udffb-\\udfff])?",ie="[\\ufe0e\\ufe0f]?",ae=ie+ee+("(?:\\u200d(?:"+["[^"+o+"]",B,$].join("|")+")"+ie+ee+")*"),ce="(?:"+[C,B,$].join("|")+")"+ae,le=RegExp([U+"?"+j+"+"+Y+"(?="+[w,U,"$"].join("|")+")",z+"+"+Z+"(?="+[w,U+V,"$"].join("|")+")",U+"?"+V+"+"+Y,U+"+"+Z,"\\d*(?:1ST|2ND|3RD|(?![123])\\dTH)(?=\\b|[a-z_])","\\d*(?:1st|2nd|3rd|(?![123])\\dth)(?=\\b|[A-Z_])",x,ce].join("|"),"g");s.exports=function unicodeWords(s){return s.match(le)||[]}},22551:(s,o,i)=>{"use strict";var a=i(96540),u=i(69982);function p(s){for(var o="https://reactjs.org/docs/error-decoder.html?invariant="+s,i=1;i